mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-30 01:04:59 +00:00
try to use litellm.completion_cost
This commit is contained in:
parent
edbfec0ce4
commit
d27bb56cf3
1 changed files with 39 additions and 24 deletions
|
@ -1986,6 +1986,44 @@ class Coder:
|
|||
self.usage_report = tokens_report
|
||||
return
|
||||
|
||||
try:
|
||||
# Try and use litellm's built in cost calculator. Seems to work for non-streaming only?
|
||||
cost = litellm.completion_cost(completion_response=completion)
|
||||
except ValueError:
|
||||
cost = 0
|
||||
|
||||
if not cost:
|
||||
cost = self.compute_costs_from_tokens(
|
||||
prompt_tokens, completion_tokens, cache_write_tokens, cache_hit_tokens
|
||||
)
|
||||
|
||||
self.total_cost += cost
|
||||
self.message_cost += cost
|
||||
|
||||
def format_cost(value):
|
||||
if value == 0:
|
||||
return "0.00"
|
||||
magnitude = abs(value)
|
||||
if magnitude >= 0.01:
|
||||
return f"{value:.2f}"
|
||||
else:
|
||||
return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"
|
||||
|
||||
cost_report = (
|
||||
f"Cost: ${format_cost(self.message_cost)} message,"
|
||||
f" ${format_cost(self.total_cost)} session."
|
||||
)
|
||||
|
||||
if cache_hit_tokens and cache_write_tokens:
|
||||
sep = "\n"
|
||||
else:
|
||||
sep = " "
|
||||
|
||||
self.usage_report = tokens_report + sep + cost_report
|
||||
|
||||
def compute_costs_from_tokens(
|
||||
self, prompt_tokens, completion_tokens, cache_write_tokens, cache_hit_tokens
|
||||
):
|
||||
cost = 0
|
||||
|
||||
input_cost_per_token = self.main_model.info.get("input_cost_per_token") or 0
|
||||
|
@ -2013,30 +2051,7 @@ class Coder:
|
|||
cost += prompt_tokens * input_cost_per_token
|
||||
|
||||
cost += completion_tokens * output_cost_per_token
|
||||
|
||||
self.total_cost += cost
|
||||
self.message_cost += cost
|
||||
|
||||
def format_cost(value):
|
||||
if value == 0:
|
||||
return "0.00"
|
||||
magnitude = abs(value)
|
||||
if magnitude >= 0.01:
|
||||
return f"{value:.2f}"
|
||||
else:
|
||||
return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"
|
||||
|
||||
cost_report = (
|
||||
f"Cost: ${format_cost(self.message_cost)} message,"
|
||||
f" ${format_cost(self.total_cost)} session."
|
||||
)
|
||||
|
||||
if cache_hit_tokens and cache_write_tokens:
|
||||
sep = "\n"
|
||||
else:
|
||||
sep = " "
|
||||
|
||||
self.usage_report = tokens_report + sep + cost_report
|
||||
return cost
|
||||
|
||||
def show_usage_report(self):
|
||||
if not self.usage_report:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue