Merge branch 'main' into mixpanel

commit f52265362f
Paul Gauthier, 2024-08-13 07:09:34 -07:00

20 changed files with 135 additions and 64 deletions

@@ -73,6 +73,9 @@ class Coder:
     multi_response_content = ""
     partial_response_content = ""
     commit_before_message = []
+    message_cost = 0.0
+    message_tokens_sent = 0
+    message_tokens_received = 0
 
     @classmethod
     def create(
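
The three new class attributes give each user message its own running totals alongside the session-wide `total_cost`. A minimal sketch of the pattern outside aider; the `record_exchange` helper is hypothetical, added only to show the accumulate-then-reset cycle:

```python
class Coder:
    # Class-level zeros serve as defaults; `+=` on an instance reads the
    # class attribute, then stores the sum as an instance attribute.
    message_cost = 0.0
    message_tokens_sent = 0
    message_tokens_received = 0

    def record_exchange(self, sent, received, cost):
        # Accumulate across every request made for one user message,
        # including retries and prefill continuations.
        self.message_tokens_sent += sent
        self.message_tokens_received += received
        self.message_cost += cost


coder = Coder()
coder.record_exchange(sent=1200, received=300, cost=0.0051)
coder.record_exchange(sent=1500, received=450, cost=0.0068)  # continuation
print(coder.message_tokens_sent)  # 2700
```
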
@@ -149,7 +152,10 @@
         main_model = self.main_model
         weak_model = main_model.weak_model
         prefix = "Model:"
-        output = f" {main_model.name} with {self.edit_format} edit format"
+        output = f" {main_model.name} with"
+        if main_model.info.get("supports_assistant_prefill"):
+            output += " ♾️"
+        output += f" {self.edit_format} edit format"
         if weak_model is not main_model:
             prefix = "Models:"
             output += f", weak model {weak_model.name}"
@@ -993,7 +999,7 @@
                 return
             except FinishReasonLength:
                 # We hit the output limit!
-                if not self.main_model.can_prefill:
+                if not self.main_model.info.get("supports_assistant_prefill"):
                     exhausted = True
                     break
 
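
Rather than the hand-maintained `can_prefill` flag, the capability is now read from the model's metadata dict (`main_model.info`), so `.get()` quietly treats models without the key as unable to continue. A small sketch of the lookup with an assumed metadata dict:

```python
def can_continue_after_length(model_info):
    # A missing key yields None (falsy): unknown models are marked
    # exhausted rather than risking a malformed continuation request.
    return bool(model_info.get("supports_assistant_prefill"))

print(can_continue_after_length({"supports_assistant_prefill": True}))  # True
print(can_continue_after_length({}))                                    # False
```
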
@@ -1002,7 +1008,9 @@
                 if messages[-1]["role"] == "assistant":
                     messages[-1]["content"] = self.multi_response_content
                 else:
-                    messages.append(dict(role="assistant", content=self.multi_response_content))
+                    messages.append(
+                        dict(role="assistant", content=self.multi_response_content, prefix=True)
+                    )
             except Exception as err:
                 self.io.tool_error(f"Unexpected error: {err}")
                 traceback.print_exc()
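
When the model stops for length and does support prefill, the partial response is re-sent as a trailing assistant message so the provider continues it in place; the new `prefix=True` key marks that message as a prefill to be continued rather than a completed turn. A sketch of just the message-list manipulation, with hypothetical data:

```python
def prepare_continuation(messages, partial_content):
    # On the second and later continuations the trailing assistant
    # message already exists, so only its content grows.
    if messages[-1]["role"] == "assistant":
        messages[-1]["content"] = partial_content
    else:
        messages.append(dict(role="assistant", content=partial_content, prefix=True))
    return messages

msgs = [{"role": "user", "content": "write a long file"}]
msgs = prepare_continuation(msgs, "def main():\n    ...")
print(msgs[-1]["prefix"])  # True
```
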
@@ -1017,8 +1025,7 @@
 
         self.io.tool_output()
 
-        if self.usage_report:
-            self.io.tool_output(self.usage_report)
+        self.show_usage_report()
 
         if exhausted:
             self.show_exhausted_error()
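
The inline report printing collapses into a single `show_usage_report()` call, which (per the later hunk) also resets the per-message counters, so the reset cannot be forgotten at any call site. A condensed, runnable sketch of that consolidation with a stand-in for `self.io`:

```python
class UsageReporter:
    usage_report = None

    def tool_output(self, text=""):
        print(text)

    def show_usage_report(self):
        # One place to print the report; the later hunk adds the
        # per-message counter resets here as well.
        if self.usage_report:
            self.tool_output(self.usage_report)

r = UsageReporter()
r.usage_report = "Tokens: 2,700 sent, 750 received."
r.show_usage_report()
```
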
@@ -1241,7 +1248,6 @@
 
         self.io.log_llm_history("TO LLM", format_messages(messages))
 
-        interrupted = False
         try:
             hash_object, completion = send_completion(
                 model.name,
@@ -1258,9 +1264,9 @@
                 yield from self.show_send_output_stream(completion)
             else:
                 self.show_send_output(completion)
-        except KeyboardInterrupt:
+        except KeyboardInterrupt as kbi:
             self.keyboard_interrupt()
-            interrupted = True
+            raise kbi
         finally:
             self.io.log_llm_history(
                 "LLM RESPONSE",
@@ -1275,10 +1281,7 @@
                 if args:
                     self.io.ai_output(json.dumps(args, indent=4))
 
-        if interrupted:
-            raise KeyboardInterrupt
-
-        self.calculate_and_show_tokens_and_cost(messages, completion)
+            self.calculate_and_show_tokens_and_cost(messages, completion)
 
     def show_send_output(self, completion):
         if self.verbose:
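
These three hunks retire the `interrupted` flag: the `KeyboardInterrupt` is re-raised immediately instead of being recorded and re-raised after the `finally` block, and the accounting call moves inside `finally` so tokens and cost are recorded even on an interrupt. A minimal, runnable sketch of that control flow with stand-in callbacks:

```python
def send_with_cleanup(do_send, on_interrupt, account):
    try:
        do_send()
    except KeyboardInterrupt as kbi:
        on_interrupt()  # e.g. print a notice, update state
        raise kbi       # propagate now; the finally below still runs
    finally:
        account()       # logging + token/cost accounting always happen

def raise_interrupt():
    raise KeyboardInterrupt

try:
    send_with_cleanup(
        do_send=raise_interrupt,
        on_interrupt=lambda: print("interrupted"),
        account=lambda: print("accounted"),
    )
except KeyboardInterrupt:
    print("caller sees the interrupt")
# Output order: interrupted, accounted, caller sees the interrupt
```
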
@@ -1390,13 +1393,19 @@
             prompt_tokens = self.main_model.token_count(messages)
             completion_tokens = self.main_model.token_count(self.partial_response_content)
 
-        self.usage_report = f"Tokens: {prompt_tokens:,} sent, {completion_tokens:,} received."
+        self.message_tokens_sent += prompt_tokens
+        self.message_tokens_received += completion_tokens
+
+        tokens_report = (
+            f"Tokens: {self.message_tokens_sent:,} sent, {self.message_tokens_received:,} received."
+        )
 
         if self.main_model.info.get("input_cost_per_token"):
             cost += prompt_tokens * self.main_model.info.get("input_cost_per_token")
             if self.main_model.info.get("output_cost_per_token"):
                 cost += completion_tokens * self.main_model.info.get("output_cost_per_token")
             self.total_cost += cost
+            self.message_cost += cost
 
             def format_cost(value):
                 if value == 0:
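
The cost math is a straight per-token multiply-and-accumulate against litellm-style metadata. A worked example with made-up prices, roughly the shape `main_model.info` provides:

```python
info = {"input_cost_per_token": 3e-06, "output_cost_per_token": 1.5e-05}
prompt_tokens, completion_tokens = 1200, 350

cost = 0.0
if info.get("input_cost_per_token"):
    cost += prompt_tokens * info["input_cost_per_token"]            # 0.00360
    if info.get("output_cost_per_token"):
        cost += completion_tokens * info["output_cost_per_token"]   # +0.00525
print(f"${cost:.5f}")  # $0.00885
```
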
@@ -1407,9 +1416,20 @@
                 else:
                     return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"
 
-            self.usage_report += (
-                f" Cost: ${format_cost(cost)} request, ${format_cost(self.total_cost)} session."
-            )
+            cost_report = (
+                f" Cost: ${format_cost(self.message_cost)} message,"
+                f" ${format_cost(self.total_cost)} session."
+            )
+            self.usage_report = tokens_report + cost_report
+        else:
+            self.usage_report = tokens_report
+
+    def show_usage_report(self):
+        if self.usage_report:
+            self.io.tool_output(self.usage_report)
+            self.message_cost = 0.0
+            self.message_tokens_sent = 0
+            self.message_tokens_received = 0
 
         self.event(
             "message_send",