diff --git a/aider/args.py b/aider/args.py index f77831679..9f9c3067f 100644 --- a/aider/args.py +++ b/aider/args.py @@ -151,6 +151,12 @@ def get_parser(default_config_files, git_root): default=1024, help="Max number of tokens to use for repo map, use 0 to disable (default: 1024)", ) + group.add_argument( + "--max-chat-history-tokens", + type=int, + default=None, + help="Maximum number of tokens to use for chat history. If not specified, uses the model's max_chat_history_tokens.", + ) default_env_file = os.path.join(git_root, ".env") if git_root else ".env" group.add_argument( "--env-file", diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index 42e9f3531..d336b63a2 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -165,6 +165,9 @@ class Coder: for fname in self.get_inchat_relative_files(): lines.append(f"Added {fname} to the chat.") + if self.done_messages: + lines.append("Restored previous conversation history.") + return lines def __init__( @@ -188,6 +191,7 @@ class Coder: aider_ignore_file=None, cur_messages=None, done_messages=None, + max_chat_history_tokens=None, ): if not fnames: fnames = [] @@ -282,14 +286,22 @@ class Coder: self.verbose, ) + if max_chat_history_tokens is None: + max_chat_history_tokens = self.main_model.max_chat_history_tokens self.summarizer = ChatSummary( self.main_model.weak_model, - self.main_model.max_chat_history_tokens, + max_chat_history_tokens, ) self.summarizer_thread = None self.summarized_done_messages = [] + if not self.done_messages: + history_md = self.io.read_text(self.io.chat_history_file) + if history_md: + self.done_messages = self.summarizer.split_chat_history_markdown(history_md) + self.summarize_start() + # validate the functions jsonschema if self.functions: for function in self.functions: diff --git a/aider/history.py b/aider/history.py index f8ea7cdd4..a93e3141c 100644 --- a/aider/history.py +++ b/aider/history.py @@ -1,6 +1,6 @@ import argparse -from aider import prompts +from aider import models, prompts from aider.dump import dump # noqa: F401 from aider.sendchat import simple_send_with_retries @@ -56,7 +56,21 @@ class ChatSummary: head = messages[:split_index] tail = messages[split_index:] - summary = self.summarize_all(head) + sized = sized[:split_index] + head.reverse() + sized.reverse() + keep = [] + total = 0 + model_max_input_tokens = self.model.info.get("max_input_tokens", 4096) - 512 + for i in range(split_index): + total += sized[i][0] + if total > model_max_input_tokens: + break + keep.append(head[i]) + + keep.reverse() + + summary = self.summarize_all(keep) tail_tokens = sum(tokens for tokens, msg in sized[split_index:]) summary_tokens = self.token_count(summary) @@ -90,41 +104,47 @@ class ChatSummary: return [dict(role="user", content=summary)] + def split_chat_history_markdown(self, text): + messages = [] + assistant = [] + lines = text.splitlines(keepends=True) + for line in lines: + if line.startswith("# "): + continue + if line.startswith(">"): + continue + if line.startswith("#### /"): + continue + + if line.startswith("#### "): + if assistant: + assistant = "".join(assistant) + if assistant.strip(): + messages.append(dict(role="assistant", content=assistant)) + assistant = [] + + content = line[5:] + if content.strip() and content.strip() != "": + messages.append(dict(role="user", content=line[5:])) + continue + + assistant.append(line) + + return messages + def main(): parser = argparse.ArgumentParser() parser.add_argument("filename", help="Markdown file to parse") args = parser.parse_args() + model = models.Model("gpt-3.5-turbo") + summarizer = ChatSummary(model) + with open(args.filename, "r") as f: text = f.read() - messages = [] - assistant = [] - for line in text.splitlines(keepends=True): - if line.startswith("# "): - continue - if line.startswith(">"): - continue - if line.startswith("#### /"): - continue - - if line.startswith("#### "): - if assistant: - assistant = "".join(assistant) - if assistant.strip(): - messages.append(dict(role="assistant", content=assistant)) - assistant = [] - - content = line[5:] - if content.strip() and content.strip() != "": - messages.append(dict(role="user", content=line[5:])) - continue - - assistant.append(line) - - summarizer = ChatSummary("gpt-3.5-turbo", weak_model=False) - summary = summarizer.summarize(messages[-40:]) + summary = summarizer.summarize_chat_history_markdown(text) dump(summary) diff --git a/aider/main.py b/aider/main.py index 48ac549ed..9c7bedab4 100644 --- a/aider/main.py +++ b/aider/main.py @@ -330,6 +330,7 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F use_git=args.git, voice_language=args.voice_language, aider_ignore_file=args.aiderignore, + max_chat_history_tokens=args.max_chat_history_tokens, ) except ValueError as err: