mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-30 01:04:59 +00:00
Merge pull request #144 from paul-gauthier/chat-history
Automatically summarize earlier parts of the chat history
This commit is contained in:
commit
4207be9aa1
3 changed files with 182 additions and 0 deletions
|
@ -4,6 +4,7 @@ import hashlib
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from json.decoder import JSONDecodeError
|
||||
|
@ -17,6 +18,7 @@ from rich.markdown import Markdown
|
|||
|
||||
from aider import models, prompts, utils
|
||||
from aider.commands import Commands
|
||||
from aider.history import ChatSummary
|
||||
from aider.repo import GitRepo
|
||||
from aider.repomap import RepoMap
|
||||
from aider.sendchat import send_with_retries
|
||||
|
@ -203,6 +205,10 @@ class Coder:
|
|||
if self.repo:
|
||||
self.repo.add_new_files(fname for fname in fnames if not Path(fname).is_dir())
|
||||
|
||||
self.summarizer = ChatSummary()
|
||||
self.summarizer_thread = None
|
||||
self.summarized_done_messages = None
|
||||
|
||||
# validate the functions jsonschema
|
||||
if self.functions:
|
||||
for function in self.functions:
|
||||
|
@ -355,8 +361,37 @@ class Coder:
|
|||
|
||||
self.last_keyboard_interrupt = now
|
||||
|
||||
def summarize_start(self):
    """Launch a background thread to summarize the finished chat history.

    No-op while the history still fits the summarizer's token budget.
    The result is staged in self.summarized_done_messages and adopted
    later by summarize_end().
    """
    # Nothing to do while the history is still within budget.
    if not self.summarizer.too_big(self.done_messages):
        return

    # Invariant: never two summarization passes in flight at once.
    assert self.summarizer_thread is None
    assert self.summarized_done_messages is None

    if self.verbose:
        self.io.tool_output("Starting to summarize chat history.")

    worker = threading.Thread(target=self.summarize_worker)
    self.summarizer_thread = worker
    worker.start()
|
||||
|
||||
def summarize_worker(self):
    """Thread target: compute the condensed history.

    Stores the summarizer's output in self.summarized_done_messages;
    summarize_end() swaps it into self.done_messages on the main thread.
    """
    summarized = self.summarizer.summarize(self.done_messages)
    self.summarized_done_messages = summarized
    if self.verbose:
        self.io.tool_output("Finished summarizing chat history.")
|
||||
|
||||
def summarize_end(self):
    """Adopt the summarizer thread's result, blocking until it finishes.

    No-op when no summarization is in flight. Afterwards the staged
    summary replaces self.done_messages and both staging slots are
    cleared for the next pass.
    """
    thread = self.summarizer_thread
    if thread is None:
        return

    thread.join()
    self.summarizer_thread = None

    # Swap the condensed history in and clear the staging slot.
    self.done_messages = self.summarized_done_messages
    self.summarized_done_messages = None
|
||||
|
||||
def move_back_cur_messages(self, message):
|
||||
self.done_messages += self.cur_messages
|
||||
self.summarize_start()
|
||||
|
||||
if message:
|
||||
self.done_messages += [
|
||||
dict(role="user", content=message),
|
||||
|
@ -407,6 +442,7 @@ class Coder:
|
|||
dict(role="system", content=main_sys),
|
||||
]
|
||||
|
||||
self.summarize_end()
|
||||
messages += self.done_messages
|
||||
messages += self.get_files_messages()
|
||||
messages += self.cur_messages
|
||||
|
|
128
aider/history.py
Normal file
128
aider/history.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
import argparse
|
||||
import json
|
||||
|
||||
import tiktoken
|
||||
|
||||
from aider import models, prompts
|
||||
from aider.dump import dump # noqa: F401
|
||||
from aider.sendchat import simple_send_with_retries
|
||||
|
||||
|
||||
class ChatSummary:
    """Condense a chat-message list so it fits within a token budget.

    Messages are token-counted with tiktoken; when the total exceeds
    max_tokens, an older "head" of the conversation is replaced by an
    LLM-generated summary while a recent "tail" is kept verbatim.
    """

    def __init__(self, model=models.GPT35.name, max_tokens=1024):
        # Tokenizer matching the summarization model; budget in tokens.
        self.tokenizer = tiktoken.encoding_for_model(model)
        self.max_tokens = max_tokens

    def too_big(self, messages):
        """Return True when the messages exceed the token budget."""
        return sum(count for count, _ in self.tokenize(messages)) > self.max_tokens

    def tokenize(self, messages):
        """Pair every message with its token count, as (count, message)."""
        # Token count is taken over the JSON form so role/overhead is included.
        return [
            (len(self.tokenizer.encode(json.dumps(msg))), msg) for msg in messages
        ]

    def summarize(self, messages):
        """Return messages, summarized as needed to fit within max_tokens."""
        # Very short conversations are summarized wholesale.
        if len(messages) <= 4:
            return self.summarize_all(messages)

        weighted = self.tokenize(messages)
        if sum(count for count, _ in weighted) <= self.max_tokens:
            return messages

        # Walk backwards, reserving up to half the budget for a verbatim tail.
        budget = self.max_tokens // 2
        kept = 0
        cut = len(messages)
        for idx in reversed(range(len(weighted))):
            count, _ = weighted[idx]
            if kept + count >= budget:
                break
            kept += count
            cut = idx

        # Pull the cut back so the summarized head ends on an assistant turn.
        while messages[cut - 1]["role"] != "assistant" and cut > 1:
            cut -= 1

        tail = messages[cut:]
        summary = self.summarize_all(messages[:cut])

        tail_size = sum(count for count, _ in weighted[cut:])
        summary_size = len(self.tokenizer.encode(json.dumps(summary)))

        combined = summary + tail
        if summary_size + tail_size < self.max_tokens:
            return combined

        # Still over budget: recurse on the already-shortened list.
        return self.summarize(combined)

    def summarize_all(self, messages):
        """Summarize all of messages into a single user message via the LLM."""
        transcript = ""
        for msg in messages:
            role = msg["role"].upper()
            # Only user/assistant turns are worth summarizing.
            if role not in ("USER", "ASSISTANT"):
                continue
            chunk = f"# {role}\n" + msg["content"]
            if not chunk.endswith("\n"):
                chunk += "\n"
            transcript += chunk

        request = [
            dict(role="system", content=prompts.summarize),
            dict(role="user", content=transcript),
        ]

        reply = simple_send_with_retries(model=models.GPT35.name, messages=request)
        return [dict(role="user", content=prompts.summary_prefix + reply)]
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("filename", help="Markdown file to parse")
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.filename, "r") as f:
|
||||
text = f.read()
|
||||
|
||||
messages = []
|
||||
assistant = []
|
||||
for line in text.splitlines(keepends=True):
|
||||
if line.startswith("# "):
|
||||
continue
|
||||
if line.startswith(">"):
|
||||
continue
|
||||
if line.startswith("#### /"):
|
||||
continue
|
||||
|
||||
if line.startswith("#### "):
|
||||
if assistant:
|
||||
assistant = "".join(assistant)
|
||||
if assistant.strip():
|
||||
messages.append(dict(role="assistant", content=assistant))
|
||||
assistant = []
|
||||
|
||||
content = line[5:]
|
||||
if content.strip() and content.strip() != "<blank>":
|
||||
messages.append(dict(role="user", content=line[5:]))
|
||||
continue
|
||||
|
||||
assistant.append(line)
|
||||
|
||||
summarizer = ChatSummary(models.GPT35.name)
|
||||
summary = summarizer.summarize(messages[-40:])
|
||||
dump(summary)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -24,3 +24,21 @@ And got this output:
|
|||
|
||||
{output}
|
||||
"""
|
||||
|
||||
# CHAT HISTORY
# System prompt used by ChatSummary.summarize_all: instructs the LLM to
# compress a partial conversation while keeping recent detail, function
# names, and filenames intact.
summarize = """*Briefly* summarize this partial conversation about programming.
Include less detail about older parts and more detail about the most recent messages.
Start a new paragraph every time the topic changes!

This is only part of a longer conversation so *DO NOT* conclude the summary with language like "Finally, ...". Because the conversation continues after the summary.
The summary *MUST* include the function names, libraries, packages that are being discussed.
The summary *MUST* include the filenames that are being referenced by the assistant inside the ```...``` fenced code blocks!
The summaries *MUST NOT* include ```...``` fenced code blocks!

Phrase the summary with the USER in first person, telling the ASSISTANT about the conversation.
Write *as* the user.
The user should refer to the assistant as *you*.
Start the summary with "I asked you...".
"""

# Prepended to every generated summary so it reads naturally as a prior
# user message when spliced back into the chat history.
summary_prefix = "I spoke to you previously about a number of things.\n"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue