Handle all the token/cost corner cases

This commit is contained in:
Paul Gauthier 2024-08-19 14:19:35 -07:00
parent b200bde319
commit 4c2c0ac871

View file

@ -1475,57 +1475,102 @@ class Coder:
def calculate_and_show_tokens_and_cost(self, messages, completion=None):
prompt_tokens = 0
completion_tokens = 0
cached_tokens = 0
cost = 0
cache_hit_tokens = 0
cache_write_tokens = 0
if completion and hasattr(completion, "usage") and completion.usage is not None:
dump(completion.usage)
prompt_tokens = completion.usage.prompt_tokens
completion_tokens = completion.usage.completion_tokens
cached_tokens = getattr(completion.usage, "prompt_cache_hit_tokens", 0) or getattr(
cache_hit_tokens = getattr(completion.usage, "prompt_cache_hit_tokens", 0) or getattr(
completion.usage, "cache_read_input_tokens", 0
)
cache_write_tokens = getattr(completion.usage, "cache_creation_input_tokens", 0)
if hasattr(completion.usage, "cache_read_input_tokens") or hasattr(
completion.usage, "cache_creation_input_tokens"
):
self.message_tokens_sent += prompt_tokens
self.message_tokens_sent += cache_hit_tokens
self.message_tokens_sent += cache_write_tokens
else:
self.message_tokens_sent += prompt_tokens
else:
prompt_tokens = self.main_model.token_count(messages)
completion_tokens = self.main_model.token_count(self.partial_response_content)
self.message_tokens_sent += prompt_tokens
self.message_tokens_sent += prompt_tokens
self.message_tokens_received += completion_tokens
if cached_tokens:
tokens_report = (
f"Tokens: {self.message_tokens_sent:,} sent, {cached_tokens:,} cached, "
f"{self.message_tokens_received:,} received."
)
else:
tokens_report = (
f"Tokens: {self.message_tokens_sent:,} sent,"
f" {self.message_tokens_received:,} received."
)
tokens_report = f"Tokens: {self.message_tokens_sent:,} sent"
if self.main_model.info.get("input_cost_per_token"):
cost += prompt_tokens * self.main_model.info.get("input_cost_per_token")
if self.main_model.info.get("output_cost_per_token"):
cost += completion_tokens * self.main_model.info.get("output_cost_per_token")
self.total_cost += cost
self.message_cost += cost
if cache_write_tokens:
tokens_report += f", {cache_write_tokens:,} cache write"
if cache_hit_tokens:
tokens_report += f", {cache_hit_tokens:,} cache hit"
tokens_report += f", {self.message_tokens_received:,} received."
def format_cost(value):
if value == 0:
return "0.00"
magnitude = abs(value)
if magnitude >= 0.01:
return f"{value:.2f}"
else:
return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"
cost_report = (
f" Cost: ${format_cost(self.message_cost)} message,"
f" ${format_cost(self.total_cost)} session."
)
self.usage_report = tokens_report + cost_report
else:
if not self.main_model.info.get("input_cost_per_token"):
self.usage_report = tokens_report
return
cost = 0
input_cost_per_token = self.main_model.info.get("input_cost_per_token") or 0
output_cost_per_token = self.main_model.info.get("output_cost_per_token") or 0
input_cost_per_token_cache_hit = (
self.main_model.info.get("input_cost_per_token_cache_hit") or 0
)
# deepseek
# prompt_cache_hit_tokens + prompt_cache_miss_tokens
# == prompt_tokens == total tokens that were sent
#
# Anthropic
# cache_creation_input_tokens + cache_read_input_tokens + prompt
# == total tokens that were
if input_cost_per_token_cache_hit:
# must be deepseek
cost += input_cost_per_token_cache_hit * cache_hit_tokens
cost += (prompt_tokens - input_cost_per_token_cache_hit) * input_cost_per_token
else:
# hard code the anthropic adjustments, no-ops for other models since cache_x_tokens==0
cost += cache_write_tokens * input_cost_per_token * 1.25
cost += cache_hit_tokens * input_cost_per_token * 0.10
cost += prompt_tokens * input_cost_per_token
cost += completion_tokens * output_cost_per_token
self.total_cost += cost
self.message_cost += cost
def format_cost(value):
if value == 0:
return "0.00"
magnitude = abs(value)
if magnitude >= 0.01:
return f"{value:.2f}"
else:
return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"
cost_report = (
f"Cost: ${format_cost(self.message_cost)} message,"
f" ${format_cost(self.total_cost)} session."
)
if self.add_cache_headers and self.stream:
warning = " Use --no-stream for accurate caching costs."
self.usage_report = tokens_report + "\n" + cost_report + warning
return
if cache_hit_tokens and cache_write_tokens:
sep = "\n"
else:
sep = " "
self.usage_report = tokens_report + sep + cost_report
def show_usage_report(self):
if self.usage_report: