summarize as many messages as will fit into the summarizer context

This commit is contained in:
Paul Gauthier 2024-05-06 09:47:14 -07:00
parent e51e0219ee
commit e61857ef09
2 changed files with 23 additions and 10 deletions

View file

@@ -295,13 +295,14 @@ class Coder:
max_chat_history_tokens,
)
self.summarizer_thread = None
self.summarized_done_messages = []
if not self.done_messages:
history_md = self.io.read_text(self.io.chat_history_file)
if history_md:
self.done_messages = self.summarizer.summarize_chat_history_markdown(history_md)
self.summarizer_thread = None
self.summarized_done_messages = []
self.done_messages = self.summarizer.split_chat_history_markdown(history_md)
self.summarize_start()
# validate the functions jsonschema
if self.functions:

View file

@@ -4,7 +4,6 @@ from aider import models, prompts
from aider.dump import dump # noqa: F401
from aider.sendchat import simple_send_with_retries
from tqdm import tqdm
class ChatSummary:
def __init__(self, model=None, max_tokens=1024):
@@ -57,7 +56,21 @@ class ChatSummary:
head = messages[:split_index]
tail = messages[split_index:]
summary = self.summarize_all(head)
sized = sized[:split_index]
head.reverse()
sized.reverse()
keep = []
total = 0
model_max_input_tokens = self.model.info.get("max_input_tokens", 4096) - 512
for i in range(split_index):
total += sized[i][0]
if total > model_max_input_tokens:
break
keep.append(head[i])
keep.reverse()
summary = self.summarize_all(keep)
tail_tokens = sum(tokens for tokens, msg in sized[split_index:])
summary_tokens = self.token_count(summary)
@@ -91,11 +104,11 @@ class ChatSummary:
return [dict(role="user", content=summary)]
def summarize_chat_history_markdown(self, text):
def split_chat_history_markdown(self, text):
messages = []
assistant = []
lines = text.splitlines(keepends=True)
for line in tqdm(lines, desc="Summarizing chat history"):
for line in lines:
if line.startswith("# "):
continue
if line.startswith(">"):
@@ -117,8 +130,7 @@ class ChatSummary:
assistant.append(line)
summary = self.summarize(messages[-40:])
return summary
return messages
def main():