From 2e4de1a0d3e30672b31f7cc9729b7c591621641d Mon Sep 17 00:00:00 2001
From: Paul Gauthier
Date: Tue, 30 Jul 2024 11:57:34 -0300
Subject: [PATCH] Add support for multiple models in ChatSummary class

---
 aider/history.py | 35 ++++++++++++++++++++++++-----------
 1 file changed, 24 insertions(+), 11 deletions(-)

diff --git a/aider/history.py b/aider/history.py
index 3615235c4..e2fd422bb 100644
--- a/aider/history.py
+++ b/aider/history.py
@@ -6,10 +6,17 @@ from aider.sendchat import simple_send_with_retries
 
 
 class ChatSummary:
-    def __init__(self, model=None, max_tokens=1024):
-        self.token_count = model.token_count
+    def __init__(self, models=None, max_tokens=1024):
+        if not models:
+            raise ValueError("At least one model must be provided")
+        self.models = models if isinstance(models, list) else [models]
         self.max_tokens = max_tokens
-        self.model = model
+        self.current_model = None
+        self.set_current_model(self.models[0])
+
+    def set_current_model(self, model):
+        self.current_model = model
+        self.token_count = model.token_count
 
     def too_big(self, messages):
         sized = self.tokenize(messages)
@@ -96,17 +103,22 @@ class ChatSummary:
         if not content.endswith("\n"):
             content += "\n"
 
-        messages = [
+        summarize_messages = [
             dict(role="system", content=prompts.summarize),
             dict(role="user", content=content),
         ]
 
-        summary = simple_send_with_retries(self.model.name, messages)
-        if summary is None:
-            raise ValueError(f"summarizer unexpectedly failed for {self.model.name}")
-        summary = prompts.summary_prefix + summary
+        for model in self.models:
+            self.set_current_model(model)
+            try:
+                summary = simple_send_with_retries(model.name, summarize_messages)
+                if summary is not None:
+                    summary = prompts.summary_prefix + summary
+                    return [dict(role="user", content=summary)]
+            except Exception as e:
+                print(f"Summarization failed for model {model.name}: {str(e)}")
 
-        return [dict(role="user", content=summary)]
+        raise ValueError(f"summarizer unexpectedly failed for all models")
 
 
 def main():
@@ -114,8 +126,9 @@ def main():
     parser.add_argument("filename", help="Markdown file to parse")
     args = parser.parse_args()
 
-    model = models.Model("gpt-3.5-turbo")
-    summarizer = ChatSummary(model)
+    model_names = ["gpt-3.5-turbo", "gpt-4"]  # Add more model names as needed
+    model_list = [models.Model(name) for name in model_names]
+    summarizer = ChatSummary(model_list)
 
     with open(args.filename, "r") as f:
         text = f.read()
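
Usage sketch (illustrative, not part of the patch): with this change, ChatSummary accepts either a single model or a list of models and falls back to the next model when a summarization attempt fails. The snippet below assumes the import paths implied by the module itself (aider.models, aider.history) and uses placeholder model names; it relies only on the constructor and too_big() shown in the diff.

    from aider import models
    from aider.history import ChatSummary

    # Primary model first, fallback second; a bare (non-list) model also works,
    # since __init__ wraps a single model in a list.
    summarizer = ChatSummary(
        models=[models.Model("gpt-3.5-turbo"), models.Model("gpt-4")],
        max_tokens=1024,
    )

    chat_messages = [
        dict(role="user", content="a long conversation..."),
        dict(role="assistant", content="...and its replies"),
    ]

    # too_big() is unchanged by the patch: it sizes the messages with the
    # current model's tokenizer and compares against max_tokens.
    if summarizer.too_big(chat_messages):
        print("history exceeds max_tokens; summarization will try each model in order")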