Add support for multiple models in ChatSummary class

This commit is contained in:
Paul Gauthier 2024-07-30 11:57:34 -03:00 committed by Paul Gauthier (aider)
parent 3c04850339
commit 2e4de1a0d3

View file

@ -6,10 +6,17 @@ from aider.sendchat import simple_send_with_retries
class ChatSummary:
def __init__(self, model=None, max_tokens=1024):
self.token_count = model.token_count
def __init__(self, models=None, max_tokens=1024):
if not models:
raise ValueError("At least one model must be provided")
self.models = models if isinstance(models, list) else [models]
self.max_tokens = max_tokens
self.model = model
self.current_model = None
self.set_current_model(self.models[0])
def set_current_model(self, model):
self.current_model = model
self.token_count = model.token_count
def too_big(self, messages):
sized = self.tokenize(messages)
@ -96,17 +103,22 @@ class ChatSummary:
if not content.endswith("\n"):
content += "\n"
messages = [
summarize_messages = [
dict(role="system", content=prompts.summarize),
dict(role="user", content=content),
]
summary = simple_send_with_retries(self.model.name, messages)
if summary is None:
raise ValueError(f"summarizer unexpectedly failed for {self.model.name}")
summary = prompts.summary_prefix + summary
for model in self.models:
self.set_current_model(model)
try:
summary = simple_send_with_retries(model.name, summarize_messages)
if summary is not None:
summary = prompts.summary_prefix + summary
return [dict(role="user", content=summary)]
except Exception as e:
print(f"Summarization failed for model {model.name}: {str(e)}")
return [dict(role="user", content=summary)]
raise ValueError(f"summarizer unexpectedly failed for all models")
def main():
@ -114,8 +126,9 @@ def main():
parser.add_argument("filename", help="Markdown file to parse")
args = parser.parse_args()
model = models.Model("gpt-3.5-turbo")
summarizer = ChatSummary(model)
model_names = ["gpt-3.5-turbo", "gpt-4"] # Add more model names as needed
model_list = [models.Model(name) for name in model_names]
summarizer = ChatSummary(model_list)
with open(args.filename, "r") as f:
text = f.read()