Merge pull request #591 from paul-gauthier/restore-chat-history

Restore prior chat history on launch
This commit is contained in:
paul-gauthier 2024-05-11 07:49:43 -07:00 committed by GitHub
commit 45b2ba8a10
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 68 additions and 29 deletions

View file

@ -151,6 +151,12 @@ def get_parser(default_config_files, git_root):
default=1024, default=1024,
help="Max number of tokens to use for repo map, use 0 to disable (default: 1024)", help="Max number of tokens to use for repo map, use 0 to disable (default: 1024)",
) )
group.add_argument(
"--max-chat-history-tokens",
type=int,
default=None,
help="Maximum number of tokens to use for chat history. If not specified, uses the model's max_chat_history_tokens.",
)
default_env_file = os.path.join(git_root, ".env") if git_root else ".env" default_env_file = os.path.join(git_root, ".env") if git_root else ".env"
group.add_argument( group.add_argument(
"--env-file", "--env-file",

View file

@ -165,6 +165,9 @@ class Coder:
for fname in self.get_inchat_relative_files(): for fname in self.get_inchat_relative_files():
lines.append(f"Added {fname} to the chat.") lines.append(f"Added {fname} to the chat.")
if self.done_messages:
lines.append("Restored previous conversation history.")
return lines return lines
def __init__( def __init__(
@ -188,6 +191,7 @@ class Coder:
aider_ignore_file=None, aider_ignore_file=None,
cur_messages=None, cur_messages=None,
done_messages=None, done_messages=None,
max_chat_history_tokens=None,
): ):
if not fnames: if not fnames:
fnames = [] fnames = []
@ -282,14 +286,22 @@ class Coder:
self.verbose, self.verbose,
) )
if max_chat_history_tokens is None:
max_chat_history_tokens = self.main_model.max_chat_history_tokens
self.summarizer = ChatSummary( self.summarizer = ChatSummary(
self.main_model.weak_model, self.main_model.weak_model,
self.main_model.max_chat_history_tokens, max_chat_history_tokens,
) )
self.summarizer_thread = None self.summarizer_thread = None
self.summarized_done_messages = [] self.summarized_done_messages = []
if not self.done_messages:
history_md = self.io.read_text(self.io.chat_history_file)
if history_md:
self.done_messages = self.summarizer.split_chat_history_markdown(history_md)
self.summarize_start()
# validate the functions jsonschema # validate the functions jsonschema
if self.functions: if self.functions:
for function in self.functions: for function in self.functions:

View file

@ -1,6 +1,6 @@
import argparse import argparse
from aider import prompts from aider import models, prompts
from aider.dump import dump # noqa: F401 from aider.dump import dump # noqa: F401
from aider.sendchat import simple_send_with_retries from aider.sendchat import simple_send_with_retries
@ -56,7 +56,21 @@ class ChatSummary:
head = messages[:split_index] head = messages[:split_index]
tail = messages[split_index:] tail = messages[split_index:]
summary = self.summarize_all(head) sized = sized[:split_index]
head.reverse()
sized.reverse()
keep = []
total = 0
model_max_input_tokens = self.model.info.get("max_input_tokens", 4096) - 512
for i in range(split_index):
total += sized[i][0]
if total > model_max_input_tokens:
break
keep.append(head[i])
keep.reverse()
summary = self.summarize_all(keep)
tail_tokens = sum(tokens for tokens, msg in sized[split_index:]) tail_tokens = sum(tokens for tokens, msg in sized[split_index:])
summary_tokens = self.token_count(summary) summary_tokens = self.token_count(summary)
@ -90,41 +104,47 @@ class ChatSummary:
return [dict(role="user", content=summary)] return [dict(role="user", content=summary)]
def split_chat_history_markdown(self, text):
    """Split a chat-history markdown transcript into chat messages.

    The transcript is scanned line by line:

    * lines starting with ``"# "``, ``">"`` or ``"#### /"`` are ignored
      (presumably session headings, tool output and slash-commands --
      confirm against the code that writes the history file);
    * any other line starting with ``"#### "`` is a user message; its
      content is the rest of the line, and it is dropped when that
      content is blank or the literal ``<blank>`` placeholder;
    * every remaining line is accumulated as assistant text, which is
      emitted as one assistant message when the next user line is seen
      or when the transcript ends.

    Args:
        text: Full markdown text of the chat history.

    Returns:
        A list of ``dict(role=..., content=...)`` messages in transcript
        order, with ``role`` being ``"user"`` or ``"assistant"``.
    """
    messages = []
    assistant = []

    def flush_assistant():
        # Emit the accumulated assistant lines as a single message,
        # skipping runs that contain only whitespace.
        content = "".join(assistant)
        if content.strip():
            messages.append(dict(role="assistant", content=content))
        assistant.clear()

    for line in text.splitlines(keepends=True):
        # "#### /" must be tested before the generic "#### " user prefix.
        if line.startswith(("# ", ">", "#### /")):
            continue

        if line.startswith("#### "):
            flush_assistant()
            content = line[5:]
            if content.strip() and content.strip() != "<blank>":
                messages.append(dict(role="user", content=content))
            continue

        assistant.append(line)

    # Without this final flush the last assistant reply (the usual end of
    # a transcript) would be silently discarded.
    flush_assistant()
    return messages
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("filename", help="Markdown file to parse") parser.add_argument("filename", help="Markdown file to parse")
args = parser.parse_args() args = parser.parse_args()
model = models.Model("gpt-3.5-turbo")
summarizer = ChatSummary(model)
with open(args.filename, "r") as f: with open(args.filename, "r") as f:
text = f.read() text = f.read()
messages = [] summary = summarizer.summarize_chat_history_markdown(text)
assistant = []
for line in text.splitlines(keepends=True):
if line.startswith("# "):
continue
if line.startswith(">"):
continue
if line.startswith("#### /"):
continue
if line.startswith("#### "):
if assistant:
assistant = "".join(assistant)
if assistant.strip():
messages.append(dict(role="assistant", content=assistant))
assistant = []
content = line[5:]
if content.strip() and content.strip() != "<blank>":
messages.append(dict(role="user", content=line[5:]))
continue
assistant.append(line)
summarizer = ChatSummary("gpt-3.5-turbo", weak_model=False)
summary = summarizer.summarize(messages[-40:])
dump(summary) dump(summary)

View file

@ -330,6 +330,7 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
use_git=args.git, use_git=args.git,
voice_language=args.voice_language, voice_language=args.voice_language,
aider_ignore_file=args.aiderignore, aider_ignore_file=args.aiderignore,
max_chat_history_tokens=args.max_chat_history_tokens,
) )
except ValueError as err: except ValueError as err: