def calculate_and_show_tokens_and_cost(self, messages, completion=None):
    """Accumulate token counts and dollar cost for one send/receive round.

    Updates ``self.message_tokens_sent``, ``self.message_tokens_received``,
    ``self.message_cost`` and ``self.total_cost``, then stores a
    human-readable summary string in ``self.usage_report``.

    :param messages: chat messages that were sent; used to count tokens
        locally when the provider returned no usage data.
    :param completion: provider response.  When it carries a ``usage``
        attribute, its exact token counts are preferred over local counting.
    """
    prompt_tokens = 0
    completion_tokens = 0
    cache_hit_tokens = 0
    cache_write_tokens = 0

    if completion and hasattr(completion, "usage") and completion.usage is not None:
        dump(completion.usage)
        prompt_tokens = completion.usage.prompt_tokens
        completion_tokens = completion.usage.completion_tokens
        # deepseek reports cache hits as prompt_cache_hit_tokens;
        # Anthropic reports them as cache_read_input_tokens.
        cache_hit_tokens = getattr(completion.usage, "prompt_cache_hit_tokens", 0) or getattr(
            completion.usage, "cache_read_input_tokens", 0
        )
        cache_write_tokens = getattr(completion.usage, "cache_creation_input_tokens", 0)

        if hasattr(completion.usage, "cache_read_input_tokens") or hasattr(
            completion.usage, "cache_creation_input_tokens"
        ):
            # Anthropic style: prompt_tokens excludes the cached portions,
            # so the true "sent" total is the sum of all three counts.
            self.message_tokens_sent += prompt_tokens
            self.message_tokens_sent += cache_hit_tokens
            self.message_tokens_sent += cache_write_tokens
        else:
            # deepseek style: prompt_tokens already includes cache hits.
            self.message_tokens_sent += prompt_tokens

    else:
        # No usage data from the provider; estimate with the local tokenizer.
        prompt_tokens = self.main_model.token_count(messages)
        completion_tokens = self.main_model.token_count(self.partial_response_content)
        self.message_tokens_sent += prompt_tokens

    self.message_tokens_received += completion_tokens

    tokens_report = f"Tokens: {self.message_tokens_sent:,} sent"

    if cache_write_tokens:
        tokens_report += f", {cache_write_tokens:,} cache write"
    if cache_hit_tokens:
        tokens_report += f", {cache_hit_tokens:,} cache hit"
    tokens_report += f", {self.message_tokens_received:,} received."

    if not self.main_model.info.get("input_cost_per_token"):
        # No pricing known for this model: report token counts only.
        self.usage_report = tokens_report
        return

    cost = 0

    input_cost_per_token = self.main_model.info.get("input_cost_per_token") or 0
    output_cost_per_token = self.main_model.info.get("output_cost_per_token") or 0
    input_cost_per_token_cache_hit = (
        self.main_model.info.get("input_cost_per_token_cache_hit") or 0
    )

    # deepseek
    # prompt_cache_hit_tokens + prompt_cache_miss_tokens
    #    == prompt_tokens == total tokens that were sent
    #
    # Anthropic
    # cache_creation_input_tokens + cache_read_input_tokens + prompt_tokens
    #    == total tokens that were sent

    if input_cost_per_token_cache_hit:
        # must be deepseek: cache hits are billed at the discounted rate,
        # and the remaining (cache-miss) tokens at the full input rate.
        # BUGFIX: the miss count is prompt_tokens - cache_hit_tokens; the
        # original subtracted the per-token *price* from the token count.
        cost += input_cost_per_token_cache_hit * cache_hit_tokens
        cost += (prompt_tokens - cache_hit_tokens) * input_cost_per_token
    else:
        # hard code the anthropic adjustments, no-ops for other models since cache_x_tokens==0
        cost += cache_write_tokens * input_cost_per_token * 1.25
        cost += cache_hit_tokens * input_cost_per_token * 0.10
        cost += prompt_tokens * input_cost_per_token

    cost += completion_tokens * output_cost_per_token

    self.total_cost += cost
    self.message_cost += cost

    def format_cost(value):
        # Always show at least two decimals; for sub-cent values widen the
        # precision so the leading significant digits remain visible.
        if value == 0:
            return "0.00"
        magnitude = abs(value)
        if magnitude >= 0.01:
            return f"{value:.2f}"
        else:
            return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"

    cost_report = (
        f"Cost: ${format_cost(self.message_cost)} message,"
        f" ${format_cost(self.total_cost)} session."
    )

    if self.add_cache_headers and self.stream:
        # Streamed responses don't report cache token counts, so the cost
        # shown here is only an estimate.
        warning = " Use --no-stream for accurate caching costs."
        self.usage_report = tokens_report + "\n" + cost_report + warning
        return

    # Put the cost on its own line only when the token report is long
    # (both cache-write and cache-hit figures present).
    if cache_hit_tokens and cache_write_tokens:
        sep = "\n"
    else:
        sep = " "

    self.usage_report = tokens_report + sep + cost_report