Mirror of https://github.com/Aider-AI/aider.git (synced 2025-05-28 08:14:59 +00:00)
Commit 2e4de1a0d3 (parent 3c04850339): Add support for multiple models in ChatSummary class
1 changed file with 24 additions and 11 deletions
@@ -6,10 +6,17 @@ from aider.sendchat import simple_send_with_retries
 
 
 class ChatSummary:
-    def __init__(self, model=None, max_tokens=1024):
-        self.token_count = model.token_count
+    def __init__(self, models=None, max_tokens=1024):
+        if not models:
+            raise ValueError("At least one model must be provided")
+        self.models = models if isinstance(models, list) else [models]
         self.max_tokens = max_tokens
-        self.model = model
+        self.current_model = None
+        self.set_current_model(self.models[0])
+
+    def set_current_model(self, model):
+        self.current_model = model
+        self.token_count = model.token_count
 
     def too_big(self, messages):
         sized = self.tokenize(messages)
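The constructor now accepts either a single model or a list, wrapping a bare model in a one-element list. A minimal usage sketch, assuming ChatSummary is importable from aider.history and that models.Model behaves as in main() further down:

    from aider import models
    from aider.history import ChatSummary  # import path assumed

    # A single model is wrapped into a one-element list internally.
    summarizer = ChatSummary(models.Model("gpt-3.5-turbo"))

    # An explicit list puts the preferred model first; later entries act as fallbacks.
    summarizer = ChatSummary([models.Model("gpt-3.5-turbo"), models.Model("gpt-4")])

    # Passing no model at all raises ValueError("At least one model must be provided").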
@@ -96,17 +103,22 @@ class ChatSummary:
         if not content.endswith("\n"):
             content += "\n"
 
-        messages = [
+        summarize_messages = [
             dict(role="system", content=prompts.summarize),
             dict(role="user", content=content),
         ]
 
-        summary = simple_send_with_retries(self.model.name, messages)
-        if summary is None:
-            raise ValueError(f"summarizer unexpectedly failed for {self.model.name}")
-        summary = prompts.summary_prefix + summary
+        for model in self.models:
+            self.set_current_model(model)
+            try:
+                summary = simple_send_with_retries(model.name, summarize_messages)
+                if summary is not None:
+                    summary = prompts.summary_prefix + summary
+                    return [dict(role="user", content=summary)]
+            except Exception as e:
+                print(f"Summarization failed for model {model.name}: {str(e)}")
 
-        return [dict(role="user", content=summary)]
+        raise ValueError(f"summarizer unexpectedly failed for all models")
 
 
 def main():
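The single send-and-check is replaced by a per-model fallback loop: each model is tried in order, the first non-None summary is returned, and an exception only skips to the next model. A standalone sketch of the same pattern, where send() is a hypothetical stand-in for simple_send_with_retries rather than aider's actual API:

    # Try each model in preference order; return the first successful summary.
    def summarize_with_fallback(models, messages, send):
        for model in models:
            try:
                summary = send(model.name, messages)  # may return None or raise
                if summary is not None:
                    return [dict(role="user", content=summary)]
            except Exception as e:
                print(f"Summarization failed for model {model.name}: {e}")
        raise ValueError("summarizer unexpectedly failed for all models")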
@@ -114,8 +126,9 @@ def main():
     parser.add_argument("filename", help="Markdown file to parse")
     args = parser.parse_args()
 
-    model = models.Model("gpt-3.5-turbo")
-    summarizer = ChatSummary(model)
+    model_names = ["gpt-3.5-turbo", "gpt-4"]  # Add more model names as needed
+    model_list = [models.Model(name) for name in model_names]
+    summarizer = ChatSummary(model_list)
 
     with open(args.filename, "r") as f:
         text = f.read()
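main() now builds the summarizer from a list of model names, so extending the fallback chain is a one-line change to model_names. A sketch of the equivalent setup outside the script (the import path and file name are assumptions, not part of the commit):

    from aider import models
    from aider.history import ChatSummary  # import path assumed

    model_names = ["gpt-3.5-turbo", "gpt-4"]  # order is preference order
    summarizer = ChatSummary([models.Model(name) for name in model_names])

    with open("chat_history.md", "r") as f:  # placeholder filename
        text = f.read()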