refactor into Model.token_count()

Paul Gauthier 2023-11-14 10:11:13 -08:00
parent 90f57664a6
commit c24a4a4392
3 changed files with 24 additions and 10 deletions

@@ -420,15 +420,9 @@ class Coder:
         prompt = prompt.format(fence=self.fence)
         return prompt
-    def send_new_user_message(self, inp):
+    def format_messages(self):
         self.choose_fence()
-        self.cur_messages += [
-            dict(role="user", content=inp),
-        ]
         main_sys = self.gpt_prompts.main_system
-        # if self.main_model.max_context_tokens > 4 * 1024:
         main_sys += "\n" + self.fmt_system_reminder()
         messages = [
@@ -440,6 +434,15 @@ class Coder:
         messages += self.get_files_messages()
         messages += self.cur_messages
+        return messages
+    def send_new_user_message(self, inp):
+        self.cur_messages += [
+            dict(role="user", content=inp),
+        ]
+        messages = self.format_messages()
         if self.verbose:
             utils.show_messages(messages, functions=self.functions)
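For orientation, the net effect of the two hunks above is that prompt assembly moves out of send_new_user_message() into a reusable format_messages() helper. A minimal, self-contained sketch of the same pattern follows; ChatCoder is a simplified stand-in, not aider's actual Coder, which also handles fences, file content messages, and system reminders.

# Simplified stand-in for the refactor above: message construction is
# extracted into format_messages() so other callers can reuse it.
class ChatCoder:
    def __init__(self, system_prompt):
        self.system_prompt = system_prompt
        self.done_messages = []
        self.cur_messages = []

    def format_messages(self):
        # Build the full prompt: system message, prior history, current turn.
        messages = [dict(role="system", content=self.system_prompt)]
        messages += self.done_messages
        messages += self.cur_messages
        return messages

    def send_new_user_message(self, inp):
        # Record the user's input, then delegate prompt assembly.
        self.cur_messages.append(dict(role="user", content=inp))
        return self.format_messages()

Factoring the assembly into its own method gives the send path and any future callers a single source of truth for the prompt layout.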


@@ -1,4 +1,3 @@
-import json
 import re
 import subprocess
 import sys
@@ -109,14 +108,13 @@ class Commands:
             dict(role="system", content=self.coder.gpt_prompts.main_system),
             dict(role="system", content=self.coder.gpt_prompts.system_reminder),
         ]
-        tokens = self.coder.main_model.token_count(json.dumps(msgs))
+        tokens = self.coder.main_model.token_count(msgs)
        res.append((tokens, "system messages", ""))
         # chat history
         msgs = self.coder.done_messages + self.coder.cur_messages
         if msgs:
             msgs = [dict(role="dummy", content=msg) for msg in msgs]
-            msgs = json.dumps(msgs)
             tokens = self.coder.main_model.token_count(msgs)
             res.append((tokens, "chat history", "use /clear to clear"))


@@ -1,3 +1,5 @@
+import json
 import openai
@@ -37,3 +39,14 @@ class Model:
     @staticmethod
     def commit_message_models():
         return [Model.create("gpt-3.5-turbo"), Model.create("gpt-3.5-turbo-16k")]
+    def token_count(self, messages):
+        if not self.tokenizer:
+            return
+        if type(messages) is str:
+            msgs = messages
+        else:
+            msgs = json.dumps(messages)
+        return len(self.tokenizer.encode(msgs))
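The new method accepts either a plain string, which is tokenized directly, or a list of chat message dicts, which it approximates by JSON-serializing before encoding; the JSON keys and punctuation are counted too, so this is an approximation of the provider's real chat token accounting rather than an exact figure. Below is a self-contained sketch of the same technique, assuming a tiktoken encoding as the tokenizer; the surrounding Model class, which keeps its tokenizer on self.tokenizer, is not reproduced here.

# Standalone sketch of the token-counting technique introduced above.
# Assumes tiktoken is installed; the tokenizer is passed in explicitly
# rather than read from a Model instance.
import json

import tiktoken


def token_count(tokenizer, messages):
    if tokenizer is None:
        # No tokenizer known for this model; counting is skipped.
        return None

    if isinstance(messages, str):
        text = messages
    else:
        # Approximate a list of chat messages by serializing it to JSON.
        text = json.dumps(messages)

    return len(tokenizer.encode(text))


enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
msgs = [
    dict(role="system", content="You are a helpful assistant."),
    dict(role="user", content="Count my tokens, please."),
]
print(token_count(enc, msgs))        # rough count for a message list
print(token_count(enc, "hi there"))  # exact count for a plain string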