refactor into Model.token_count()

Paul Gauthier 2023-11-14 10:11:13 -08:00
parent 90f57664a6
commit c24a4a4392
3 changed files with 24 additions and 10 deletions

View file

@@ -420,15 +420,9 @@ class Coder:
         prompt = prompt.format(fence=self.fence)
         return prompt
 
-    def send_new_user_message(self, inp):
+    def format_messages(self):
         self.choose_fence()
-        self.cur_messages += [
-            dict(role="user", content=inp),
-        ]
-
         main_sys = self.gpt_prompts.main_system
-        # if self.main_model.max_context_tokens > 4 * 1024:
         main_sys += "\n" + self.fmt_system_reminder()
 
         messages = [
@@ -440,6 +434,15 @@ class Coder:
         messages += self.get_files_messages()
         messages += self.cur_messages
 
+        return messages
+
+    def send_new_user_message(self, inp):
+        self.cur_messages += [
+            dict(role="user", content=inp),
+        ]
+
+        messages = self.format_messages()
+
         if self.verbose:
             utils.show_messages(messages, functions=self.functions)
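
Splitting format_messages() out of send_new_user_message() means the full prompt can now be assembled without sending it. A minimal sketch of the kind of caller this enables (preview_token_cost is a hypothetical helper, not part of this commit):

    def preview_token_cost(coder):
        # Build the same system + file + chat messages that send_new_user_message() uses,
        # then measure them with the new Model.token_count() instead of sending a request.
        messages = coder.format_messages()
        return coder.main_model.token_count(messages)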

View file

@@ -1,4 +1,3 @@
-import json
 import re
 import subprocess
 import sys
@@ -109,14 +108,13 @@ class Commands:
             dict(role="system", content=self.coder.gpt_prompts.main_system),
             dict(role="system", content=self.coder.gpt_prompts.system_reminder),
         ]
-        tokens = self.coder.main_model.token_count(json.dumps(msgs))
+        tokens = self.coder.main_model.token_count(msgs)
         res.append((tokens, "system messages", ""))
 
         # chat history
        msgs = self.coder.done_messages + self.coder.cur_messages
         if msgs:
             msgs = [dict(role="dummy", content=msg) for msg in msgs]
-            msgs = json.dumps(msgs)
             tokens = self.coder.main_model.token_count(msgs)
             res.append((tokens, "chat history", "use /clear to clear"))
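
With this change the token report no longer serializes messages itself; it hands the raw dicts to Model.token_count(), which does the json.dumps internally. A stand-alone version of the chat-history calculation, assuming the Model class is importable as shown (the history values are made up):

    from aider.models import Model  # import path assumed, not shown in this diff

    model = Model.create("gpt-3.5-turbo")
    history = [dict(role="user", content="add a test"), dict(role="assistant", content="done")]
    wrapped = [dict(role="dummy", content=msg) for msg in history]  # same wrapping the command uses
    print(model.token_count(wrapped))  # the list is json.dumps'd inside token_count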

View file

@@ -1,3 +1,5 @@
+import json
+
 import openai
@@ -37,3 +39,14 @@ class Model:
     @staticmethod
     def commit_message_models():
         return [Model.create("gpt-3.5-turbo"), Model.create("gpt-3.5-turbo-16k")]
+
+    def token_count(self, messages):
+        if not self.tokenizer:
+            return
+
+        if type(messages) is str:
+            msgs = messages
+        else:
+            msgs = json.dumps(messages)
+
+        return len(self.tokenizer.encode(msgs))
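
A quick usage sketch of the new helper; the import path and model name are assumptions, not shown in this diff:

    from aider.models import Model  # assumed import path

    model = Model.create("gpt-3.5-turbo")
    print(model.token_count("hello world"))                      # str input is encoded as-is
    print(model.token_count([dict(role="user", content="hi")]))  # lists are json.dumps'd first
    # A model without a tokenizer returns None instead of a count.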