Merge pull request #591 from paul-gauthier/restore-chat-history

Restore prior chat history on launch
This commit is contained in:
paul-gauthier 2024-05-11 07:49:43 -07:00 committed by GitHub
commit 45b2ba8a10
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 68 additions and 29 deletions

View file

@@ -151,6 +151,12 @@ def get_parser(default_config_files, git_root):
default=1024,
help="Max number of tokens to use for repo map, use 0 to disable (default: 1024)",
)
group.add_argument(
"--max-chat-history-tokens",
type=int,
default=None,
help="Maximum number of tokens to use for chat history. If not specified, uses the model's max_chat_history_tokens.",
)
default_env_file = os.path.join(git_root, ".env") if git_root else ".env"
group.add_argument(
"--env-file",

View file

@@ -165,6 +165,9 @@ class Coder:
for fname in self.get_inchat_relative_files():
lines.append(f"Added {fname} to the chat.")
if self.done_messages:
lines.append("Restored previous conversation history.")
return lines
def __init__(
@@ -188,6 +191,7 @@ class Coder:
aider_ignore_file=None,
cur_messages=None,
done_messages=None,
max_chat_history_tokens=None,
):
if not fnames:
fnames = []
@@ -282,14 +286,22 @@
self.verbose,
)
if max_chat_history_tokens is None:
max_chat_history_tokens = self.main_model.max_chat_history_tokens
self.summarizer = ChatSummary(
self.main_model.weak_model,
self.main_model.max_chat_history_tokens,
max_chat_history_tokens,
)
self.summarizer_thread = None
self.summarized_done_messages = []
if not self.done_messages:
history_md = self.io.read_text(self.io.chat_history_file)
if history_md:
self.done_messages = self.summarizer.split_chat_history_markdown(history_md)
self.summarize_start()
# validate the functions jsonschema
if self.functions:
for function in self.functions:

View file

@@ -1,6 +1,6 @@
import argparse
from aider import prompts
from aider import models, prompts
from aider.dump import dump # noqa: F401
from aider.sendchat import simple_send_with_retries
@@ -56,7 +56,21 @@ class ChatSummary:
head = messages[:split_index]
tail = messages[split_index:]
summary = self.summarize_all(head)
sized = sized[:split_index]
head.reverse()
sized.reverse()
keep = []
total = 0
model_max_input_tokens = self.model.info.get("max_input_tokens", 4096) - 512
for i in range(split_index):
total += sized[i][0]
if total > model_max_input_tokens:
break
keep.append(head[i])
keep.reverse()
summary = self.summarize_all(keep)
tail_tokens = sum(tokens for tokens, msg in sized[split_index:])
summary_tokens = self.token_count(summary)
@@ -90,18 +104,11 @@ class ChatSummary:
return [dict(role="user", content=summary)]
def main():
parser = argparse.ArgumentParser()
parser.add_argument("filename", help="Markdown file to parse")
args = parser.parse_args()
with open(args.filename, "r") as f:
text = f.read()
def split_chat_history_markdown(self, text):
messages = []
assistant = []
for line in text.splitlines(keepends=True):
lines = text.splitlines(keepends=True)
for line in lines:
if line.startswith("# "):
continue
if line.startswith(">"):
@@ -123,8 +130,21 @@ def main():
assistant.append(line)
summarizer = ChatSummary("gpt-3.5-turbo", weak_model=False)
summary = summarizer.summarize(messages[-40:])
return messages
def main():
    """CLI entry point: summarize a markdown chat-history file.

    Reads one positional ``filename`` argument, feeds the file's text to a
    ChatSummary built around the "gpt-3.5-turbo" model, and dumps the
    resulting summary for inspection.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("filename", help="Markdown file to parse")
    args = parser.parse_args()
    # Replaces the old hard-coded ChatSummary("gpt-3.5-turbo", weak_model=False)
    # construction: ChatSummary now takes a models.Model instance.
    model = models.Model("gpt-3.5-turbo")
    summarizer = ChatSummary(model)
    with open(args.filename, "r") as f:
        text = f.read()
    # NOTE(review): only split_chat_history_markdown is defined in the visible
    # hunks of this diff — confirm summarize_chat_history_markdown exists on
    # ChatSummary elsewhere in history.py.
    summary = summarizer.summarize_chat_history_markdown(text)
    dump(summary)

View file

@ -330,6 +330,7 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
use_git=args.git,
voice_language=args.voice_language,
aider_ignore_file=args.aiderignore,
max_chat_history_tokens=args.max_chat_history_tokens,
)
except ValueError as err: