Added a ChatSummary object to Coder class and used it to summarize chat history.

This commit is contained in:
Paul Gauthier 2023-07-22 10:34:48 -03:00
parent 0d0ac4f61f
commit c26917851f
2 changed files with 5 additions and 3 deletions

View file

@ -20,6 +20,7 @@ from aider.commands import Commands
from aider.repo import GitRepo from aider.repo import GitRepo
from aider.repomap import RepoMap from aider.repomap import RepoMap
from aider.sendchat import send_with_retries from aider.sendchat import send_with_retries
from aider.history import ChatSummary
from ..dump import dump # noqa: F401 from ..dump import dump # noqa: F401
@ -199,6 +200,8 @@ class Coder:
if self.repo: if self.repo:
self.repo.add_new_files(fnames) self.repo.add_new_files(fnames)
self.summarizer = ChatSummary()
# validate the functions jsonschema # validate the functions jsonschema
if self.functions: if self.functions:
for function in self.functions: for function in self.functions:
@ -353,7 +356,7 @@ class Coder:
def move_back_cur_messages(self, message): def move_back_cur_messages(self, message):
self.done_messages += self.cur_messages self.done_messages += self.cur_messages
#self.done_messages = summarize_chat_history(self.done_messages) self.done_messages = self.summarizer.summarize(self.done_messages)
if message: if message:
self.done_messages += [ self.done_messages += [

View file

@ -9,7 +9,7 @@ from aider.sendchat import simple_send_with_retries
class ChatSummary: class ChatSummary:
def __init__(self, model, max_tokens=1024): def __init__(self, model=models.GPT35.name, max_tokens=1024):
self.tokenizer = tiktoken.encoding_for_model(model) self.tokenizer = tiktoken.encoding_for_model(model)
self.max_tokens = max_tokens self.max_tokens = max_tokens
@ -38,7 +38,6 @@ class ChatSummary:
head = messages[:num] head = messages[:num]
tail = messages[num:] tail = messages[num:]
print("=" * 20)
summary = self.summarize_all(head) summary = self.summarize_all(head)
tail_tokens = sum(tokens for tokens, msg in sized[num:]) tail_tokens = sum(tokens for tokens, msg in sized[num:])