Mirror of https://github.com/Aider-AI/aider.git
Handle all the token/cost corner cases
parent b200bde319
commit 4c2c0ac871

1 changed file with 80 additions and 35 deletions
@@ -1475,57 +1475,102 @@ class Coder:
     def calculate_and_show_tokens_and_cost(self, messages, completion=None):
         prompt_tokens = 0
         completion_tokens = 0
-        cached_tokens = 0
-        cost = 0
+        cache_hit_tokens = 0
+        cache_write_tokens = 0

         if completion and hasattr(completion, "usage") and completion.usage is not None:
-            dump(completion.usage)
             prompt_tokens = completion.usage.prompt_tokens
             completion_tokens = completion.usage.completion_tokens
-            cached_tokens = getattr(completion.usage, "prompt_cache_hit_tokens", 0) or getattr(
+            cache_hit_tokens = getattr(completion.usage, "prompt_cache_hit_tokens", 0) or getattr(
                 completion.usage, "cache_read_input_tokens", 0
             )
+            cache_write_tokens = getattr(completion.usage, "cache_creation_input_tokens", 0)
+
+            if hasattr(completion.usage, "cache_read_input_tokens") or hasattr(
+                completion.usage, "cache_creation_input_tokens"
+            ):
+                self.message_tokens_sent += prompt_tokens
+                self.message_tokens_sent += cache_hit_tokens
+                self.message_tokens_sent += cache_write_tokens
+            else:
+                self.message_tokens_sent += prompt_tokens
+
         else:
             prompt_tokens = self.main_model.token_count(messages)
             completion_tokens = self.main_model.token_count(self.partial_response_content)
+            self.message_tokens_sent += prompt_tokens

-        self.message_tokens_sent += prompt_tokens
         self.message_tokens_received += completion_tokens

-        if cached_tokens:
-            tokens_report = (
-                f"Tokens: {self.message_tokens_sent:,} sent, {cached_tokens:,} cached, "
-                f"{self.message_tokens_received:,} received."
-            )
-        else:
-            tokens_report = (
-                f"Tokens: {self.message_tokens_sent:,} sent,"
-                f" {self.message_tokens_received:,} received."
-            )
+        tokens_report = f"Tokens: {self.message_tokens_sent:,} sent"

-        if self.main_model.info.get("input_cost_per_token"):
-            cost += prompt_tokens * self.main_model.info.get("input_cost_per_token")
-            if self.main_model.info.get("output_cost_per_token"):
-                cost += completion_tokens * self.main_model.info.get("output_cost_per_token")
-            self.total_cost += cost
-            self.message_cost += cost
+        if cache_write_tokens:
+            tokens_report += f", {cache_write_tokens:,} cache write"
+        if cache_hit_tokens:
+            tokens_report += f", {cache_hit_tokens:,} cache hit"
+        tokens_report += f", {self.message_tokens_received:,} received."

-            def format_cost(value):
-                if value == 0:
-                    return "0.00"
-                magnitude = abs(value)
-                if magnitude >= 0.01:
-                    return f"{value:.2f}"
-                else:
-                    return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"
-
-            cost_report = (
-                f" Cost: ${format_cost(self.message_cost)} message,"
-                f" ${format_cost(self.total_cost)} session."
-            )
-            self.usage_report = tokens_report + cost_report
-        else:
+        if not self.main_model.info.get("input_cost_per_token"):
             self.usage_report = tokens_report
+            return
+
+        cost = 0
+
+        input_cost_per_token = self.main_model.info.get("input_cost_per_token") or 0
+        output_cost_per_token = self.main_model.info.get("output_cost_per_token") or 0
+        input_cost_per_token_cache_hit = (
+            self.main_model.info.get("input_cost_per_token_cache_hit") or 0
+        )
+
+        # deepseek:
+        # prompt_cache_hit_tokens + prompt_cache_miss_tokens
+        #     == prompt_tokens == total tokens that were sent
+        #
+        # Anthropic:
+        # cache_creation_input_tokens + cache_read_input_tokens + prompt_tokens
+        #     == total tokens that were sent
+
+        if input_cost_per_token_cache_hit:
+            # must be deepseek
+            cost += input_cost_per_token_cache_hit * cache_hit_tokens
+            cost += (prompt_tokens - cache_hit_tokens) * input_cost_per_token
+        else:
+            # hard-code the Anthropic adjustments; these are no-ops for other
+            # models, whose cache hit/write token counts are 0
+            cost += cache_write_tokens * input_cost_per_token * 1.25
+            cost += cache_hit_tokens * input_cost_per_token * 0.10
+            cost += prompt_tokens * input_cost_per_token
+
+        cost += completion_tokens * output_cost_per_token
+
+        self.total_cost += cost
+        self.message_cost += cost
+
+        def format_cost(value):
+            if value == 0:
+                return "0.00"
+            magnitude = abs(value)
+            if magnitude >= 0.01:
+                return f"{value:.2f}"
+            else:
+                return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"
+
+        cost_report = (
+            f"Cost: ${format_cost(self.message_cost)} message,"
+            f" ${format_cost(self.total_cost)} session."
+        )
+
+        if self.add_cache_headers and self.stream:
+            warning = " Use --no-stream for accurate caching costs."
+            self.usage_report = tokens_report + "\n" + cost_report + warning
+            return
+
+        if cache_hit_tokens and cache_write_tokens:
+            sep = "\n"
+        else:
+            sep = " "
+
+        self.usage_report = tokens_report + sep + cost_report

     def show_usage_report(self):
         if self.usage_report:
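
Note on the new hasattr() branch: providers disagree about whether usage.prompt_tokens already includes cached tokens, which is the corner case the commit handles. Below is a minimal sketch (not part of the commit) of the two conventions, using made-up SimpleNamespace stand-ins for the litellm usage objects; the field names match the diff, the numbers are illustrative:

from types import SimpleNamespace


def tokens_sent(usage):
    prompt_tokens = usage.prompt_tokens
    cache_hit = getattr(usage, "prompt_cache_hit_tokens", 0) or getattr(
        usage, "cache_read_input_tokens", 0
    )
    cache_write = getattr(usage, "cache_creation_input_tokens", 0)
    if hasattr(usage, "cache_read_input_tokens") or hasattr(
        usage, "cache_creation_input_tokens"
    ):
        # Anthropic-style: prompt_tokens excludes cached tokens, so add them
        return prompt_tokens + cache_hit + cache_write
    # deepseek-style: prompt_tokens already includes hit + miss tokens
    return prompt_tokens


# Anthropic-style usage: 61,000 tokens sent, reported across three fields
anthropic_like = SimpleNamespace(
    prompt_tokens=1_000,
    cache_read_input_tokens=50_000,
    cache_creation_input_tokens=10_000,
)

# deepseek-style usage: the same 61,000 tokens, but prompt_tokens is the total
deepseek_like = SimpleNamespace(
    prompt_tokens=61_000,
    prompt_cache_hit_tokens=50_000,
    prompt_cache_miss_tokens=11_000,
)

assert tokens_sent(anthropic_like) == 61_000
assert tokens_sent(deepseek_like) == 61_000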
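The 1.25 and 0.10 multipliers in the else branch track Anthropic's cache pricing: cache writes bill at 125% of the base input price, cache reads at 10%. A worked example of the cost path, under assumed prices (the $3/M input and $15/M output rates are illustrative stand-ins for model.info values, not from the commit):

# Assumed per-token prices, for illustration only
input_cost_per_token = 3e-06    # $3 per million input tokens
output_cost_per_token = 15e-06  # $15 per million output tokens

prompt_tokens = 1_000        # uncached prompt tokens
cache_write_tokens = 10_000  # written to the provider cache on this call
cache_hit_tokens = 50_000    # read back from the provider cache
completion_tokens = 500

cost = 0
cost += cache_write_tokens * input_cost_per_token * 1.25  # $0.0375
cost += cache_hit_tokens * input_cost_per_token * 0.10    # $0.0150
cost += prompt_tokens * input_cost_per_token              # $0.0030
cost += completion_tokens * output_cost_per_token         # $0.0075

assert abs(cost - 0.0630) < 1e-12  # vs. $0.1905 with nothing cached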
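The nested format_cost helper widens the precision for sub-cent values so a cheap message does not display as $0.00. A standalone copy of it, with the behavior pinned down by asserts:

import math


def format_cost(value):
    if value == 0:
        return "0.00"
    magnitude = abs(value)
    if magnitude >= 0.01:
        # At least a cent: plain two-decimal dollars
        return f"{value:.2f}"
    # Below a cent: log10 picks enough decimals to keep ~2 significant digits
    return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"


assert format_cost(0) == "0.00"
assert format_cost(1.2345) == "1.23"
assert format_cost(0.063) == "0.06"
assert format_cost(0.0042) == "0.0042"
assert format_cost(0.00012) == "0.00012"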