Mirror of https://github.com/Aider-AI/aider.git (synced 2025-05-30 17:24:59 +00:00)
Handle all the token/cost corner cases
commit 4c2c0ac871 (parent b200bde319)
1 changed file with 80 additions and 35 deletions
@@ -1475,57 +1475,102 @@ class Coder:
     def calculate_and_show_tokens_and_cost(self, messages, completion=None):
         prompt_tokens = 0
         completion_tokens = 0
-        cached_tokens = 0
-        cost = 0
+        cache_hit_tokens = 0
+        cache_write_tokens = 0
 
         if completion and hasattr(completion, "usage") and completion.usage is not None:
             dump(completion.usage)
             prompt_tokens = completion.usage.prompt_tokens
             completion_tokens = completion.usage.completion_tokens
-            cached_tokens = getattr(completion.usage, "prompt_cache_hit_tokens", 0) or getattr(
+            cache_hit_tokens = getattr(completion.usage, "prompt_cache_hit_tokens", 0) or getattr(
                 completion.usage, "cache_read_input_tokens", 0
             )
+            cache_write_tokens = getattr(completion.usage, "cache_creation_input_tokens", 0)
+
+            if hasattr(completion.usage, "cache_read_input_tokens") or hasattr(
+                completion.usage, "cache_creation_input_tokens"
+            ):
+                self.message_tokens_sent += prompt_tokens
+                self.message_tokens_sent += cache_hit_tokens
+                self.message_tokens_sent += cache_write_tokens
+            else:
+                self.message_tokens_sent += prompt_tokens
+
         else:
             prompt_tokens = self.main_model.token_count(messages)
             completion_tokens = self.main_model.token_count(self.partial_response_content)
+            self.message_tokens_sent += prompt_tokens
 
-        self.message_tokens_sent += prompt_tokens
         self.message_tokens_received += completion_tokens
 
-        if cached_tokens:
-            tokens_report = (
-                f"Tokens: {self.message_tokens_sent:,} sent, {cached_tokens:,} cached, "
-                f"{self.message_tokens_received:,} received."
-            )
-        else:
-            tokens_report = (
-                f"Tokens: {self.message_tokens_sent:,} sent,"
-                f" {self.message_tokens_received:,} received."
-            )
+        tokens_report = f"Tokens: {self.message_tokens_sent:,} sent"
+
+        if cache_write_tokens:
+            tokens_report += f", {cache_write_tokens:,} cache write"
+        if cache_hit_tokens:
+            tokens_report += f", {cache_hit_tokens:,} cache hit"
+        tokens_report += f", {self.message_tokens_received:,} received."
 
-        if self.main_model.info.get("input_cost_per_token"):
-            cost += prompt_tokens * self.main_model.info.get("input_cost_per_token")
-            if self.main_model.info.get("output_cost_per_token"):
-                cost += completion_tokens * self.main_model.info.get("output_cost_per_token")
-            self.total_cost += cost
-            self.message_cost += cost
+        if not self.main_model.info.get("input_cost_per_token"):
+            self.usage_report = tokens_report
+            return
+
+        cost = 0
+
+        input_cost_per_token = self.main_model.info.get("input_cost_per_token") or 0
+        output_cost_per_token = self.main_model.info.get("output_cost_per_token") or 0
+        input_cost_per_token_cache_hit = (
+            self.main_model.info.get("input_cost_per_token_cache_hit") or 0
+        )
+
+        # deepseek
+        # prompt_cache_hit_tokens + prompt_cache_miss_tokens
+        #    == prompt_tokens == total tokens that were sent
+        #
+        # Anthropic
+        # cache_creation_input_tokens + cache_read_input_tokens + prompt
+        #    == total tokens that were
+
+        if input_cost_per_token_cache_hit:
+            # must be deepseek
+            cost += input_cost_per_token_cache_hit * cache_hit_tokens
+            cost += (prompt_tokens - input_cost_per_token_cache_hit) * input_cost_per_token
+        else:
+            # hard code the anthropic adjustments, no-ops for other models since cache_x_tokens==0
+            cost += cache_write_tokens * input_cost_per_token * 1.25
+            cost += cache_hit_tokens * input_cost_per_token * 0.10
+            cost += prompt_tokens * input_cost_per_token
+
+        cost += completion_tokens * output_cost_per_token
+
+        self.total_cost += cost
+        self.message_cost += cost
 
-            def format_cost(value):
-                if value == 0:
-                    return "0.00"
-                magnitude = abs(value)
-                if magnitude >= 0.01:
-                    return f"{value:.2f}"
-                else:
-                    return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"
+        def format_cost(value):
+            if value == 0:
+                return "0.00"
+            magnitude = abs(value)
+            if magnitude >= 0.01:
+                return f"{value:.2f}"
+            else:
+                return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"
 
-            cost_report = (
-                f" Cost: ${format_cost(self.message_cost)} message,"
-                f" ${format_cost(self.total_cost)} session."
-            )
-            self.usage_report = tokens_report + cost_report
-        else:
-            self.usage_report = tokens_report
+        cost_report = (
+            f"Cost: ${format_cost(self.message_cost)} message,"
+            f" ${format_cost(self.total_cost)} session."
+        )
+
+        if self.add_cache_headers and self.stream:
+            warning = " Use --no-stream for accurate caching costs."
+            self.usage_report = tokens_report + "\n" + cost_report + warning
+            return
+
+        if cache_hit_tokens and cache_write_tokens:
+            sep = "\n"
+        else:
+            sep = " "
+
+        self.usage_report = tokens_report + sep + cost_report
 
     def show_usage_report(self):
         if self.usage_report:
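For reference, the pricing rules above can be distilled into a standalone function. This is a minimal sketch, not aider's API: the name estimate_cost, its signature, and the sample rates are illustrative inventions; only the 1.25x cache-write and 0.10x cache-read multipliers and the deepseek cache-hit branch come from the diff. Note that the diff's deepseek branch subtracts input_cost_per_token_cache_hit (a price) from prompt_tokens (a token count); the sketch assumes the intended operand was cache_hit_tokens.

# Minimal sketch of the cost logic above; names and prices are hypothetical.
def estimate_cost(
    prompt_tokens,
    completion_tokens,
    cache_hit_tokens,
    cache_write_tokens,
    input_price,        # $/token
    output_price,       # $/token
    cache_hit_price=0,  # $/token; nonzero for deepseek-style pricing
):
    if cache_hit_price:
        # deepseek-style: prompt_tokens already includes the cached portion,
        # so bill hits at the discounted rate and the rest at the full rate.
        cost = cache_hit_price * cache_hit_tokens
        cost += (prompt_tokens - cache_hit_tokens) * input_price
    else:
        # Anthropic-style: cache writes bill at 1.25x and cache reads at
        # 0.10x the input rate; both counts are zero for other models, so
        # this degrades to plain prompt_tokens * input_price.
        cost = cache_write_tokens * input_price * 1.25
        cost += cache_hit_tokens * input_price * 0.10
        cost += prompt_tokens * input_price
    return cost + completion_tokens * output_price

# 2,000 uncached prompt tokens, 8,000 cache-write tokens, 500 completion
# tokens at $3/M input and $15/M output (hypothetical rates):
print(estimate_cost(2000, 500, 0, 8000, 3e-06, 15e-06))  # ≈ 0.0435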
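The nested format_cost helper picks its precision dynamically: amounts of at least one cent print with two decimals, while sub-cent amounts widen to keep their leading significant digits visible. A quick worked example of the max(2, 2 - int(math.log10(magnitude))) expression, using hypothetical values:

import math

def format_cost(value):
    # Same logic as the helper in the diff above.
    if value == 0:
        return "0.00"
    magnitude = abs(value)
    if magnitude >= 0.01:
        return f"{value:.2f}"
    # e.g. magnitude 0.0043: log10 ≈ -2.37, int() → -2,
    # so precision = max(2, 2 - (-2)) = 4 decimal places.
    return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"

print(format_cost(0.25))     # 0.25
print(format_cost(0.0043))   # 0.0043
print(format_cost(0.00021))  # 0.00021  (precision widens to 5)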
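Putting it together: cache activity now folds into a single "Tokens:" line, and the cost report moves to its own line when a message both wrote to and read from the cache. With hypothetical counts and dollar amounts, the assembled report resembles:

# Hypothetical counts/costs; the format strings mirror the diff above.
sent, cache_write, cache_hit, received = 14_218, 8_000, 3_200, 512

tokens_report = f"Tokens: {sent:,} sent"
if cache_write:
    tokens_report += f", {cache_write:,} cache write"
if cache_hit:
    tokens_report += f", {cache_hit:,} cache hit"
tokens_report += f", {received:,} received."

cost_report = "Cost: $0.05 message, $0.21 session."  # hypothetical dollars
sep = "\n" if (cache_hit and cache_write) else " "
print(tokens_report + sep + cost_report)
# Tokens: 14,218 sent, 8,000 cache write, 3,200 cache hit, 512 received.
# Cost: $0.05 message, $0.21 session.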