summarize as many messages as will fit into the summarizer context

This commit is contained in:
Paul Gauthier 2024-05-06 09:47:14 -07:00
parent e51e0219ee
commit e61857ef09
2 changed files with 23 additions and 10 deletions

View file

@ -295,13 +295,14 @@ class Coder:
max_chat_history_tokens, max_chat_history_tokens,
) )
self.summarizer_thread = None
self.summarized_done_messages = []
if not self.done_messages: if not self.done_messages:
history_md = self.io.read_text(self.io.chat_history_file) history_md = self.io.read_text(self.io.chat_history_file)
if history_md: if history_md:
self.done_messages = self.summarizer.summarize_chat_history_markdown(history_md) self.done_messages = self.summarizer.split_chat_history_markdown(history_md)
self.summarize_start()
self.summarizer_thread = None
self.summarized_done_messages = []
# validate the functions jsonschema # validate the functions jsonschema
if self.functions: if self.functions:

View file

@ -4,7 +4,6 @@ from aider import models, prompts
from aider.dump import dump # noqa: F401 from aider.dump import dump # noqa: F401
from aider.sendchat import simple_send_with_retries from aider.sendchat import simple_send_with_retries
from tqdm import tqdm
class ChatSummary: class ChatSummary:
def __init__(self, model=None, max_tokens=1024): def __init__(self, model=None, max_tokens=1024):
@ -57,7 +56,21 @@ class ChatSummary:
head = messages[:split_index] head = messages[:split_index]
tail = messages[split_index:] tail = messages[split_index:]
summary = self.summarize_all(head) sized = sized[:split_index]
head.reverse()
sized.reverse()
keep = []
total = 0
model_max_input_tokens = self.model.info.get("max_input_tokens", 4096) - 512
for i in range(split_index):
total += sized[i][0]
if total > model_max_input_tokens:
break
keep.append(head[i])
keep.reverse()
summary = self.summarize_all(keep)
tail_tokens = sum(tokens for tokens, msg in sized[split_index:]) tail_tokens = sum(tokens for tokens, msg in sized[split_index:])
summary_tokens = self.token_count(summary) summary_tokens = self.token_count(summary)
@ -91,11 +104,11 @@ class ChatSummary:
return [dict(role="user", content=summary)] return [dict(role="user", content=summary)]
def summarize_chat_history_markdown(self, text): def split_chat_history_markdown(self, text):
messages = [] messages = []
assistant = [] assistant = []
lines = text.splitlines(keepends=True) lines = text.splitlines(keepends=True)
for line in tqdm(lines, desc="Summarizing chat history"): for line in lines:
if line.startswith("# "): if line.startswith("# "):
continue continue
if line.startswith(">"): if line.startswith(">"):
@ -117,8 +130,7 @@ class ChatSummary:
assistant.append(line) assistant.append(line)
summary = self.summarize(messages[-40:]) return messages
return summary
def main(): def main():