get_weak_model

This commit is contained in:
Paul Gauthier 2024-04-18 13:55:43 -07:00
parent 93f4a46996
commit cf2a48b21f
3 changed files with 10 additions and 11 deletions

View file

@ -212,7 +212,7 @@ class Coder:
self.io.tool_output(f"Added {fname} to the chat.")
self.summarizer = ChatSummary(
self.main_model.weak_model(),
self.main_model.get_weak_model(),
self.main_model.max_chat_history_tokens,
)

View file

@ -124,7 +124,7 @@ def main():
assistant.append(line)
summarizer = ChatSummary(models.Model.weak_model())
summarizer = ChatSummary(models.Model.get_weak_model())
summary = summarizer.summarize(messages[-40:])
dump(summary)

View file

@ -6,7 +6,7 @@ from typing import Optional
import litellm
from PIL import Image
from aider.dump import dump
from aider.dump import dump # noqa: F401
DEFAULT_MODEL_NAME = "gpt-4-1106-preview"
@ -114,11 +114,13 @@ class Model:
name = None
weak_model_name = "gpt-3.5-turbo-0125"
edit_format = "whole"
use_repo_map = False
send_undo_reply = False
max_chat_history_tokens = 1024
weak_model = None
def __init__(self, model):
self.name = model
@ -128,8 +130,6 @@ class Model:
except Exception as err:
raise NoModelInfo(str(err))
dump(model, self.info)
if self.info.get("max_input_tokens", 0) < 32 * 1024:
self.max_chat_history_tokens = 1024
else:
@ -159,14 +159,13 @@ class Model:
def __str__(self):
return self.name
def weak_model(self):
if self.name == self.weak_model_name:
return self
return Model(self.weak_model_name)
def get_weak_model(self):
if not self.weak_model:
self.weak_model = Model(self.weak_model_name)
return self.weak_model
def commit_message_models(self):
return [self.weak_model()]
return [self.get_weak_model()]
def tokenizer(self, text):
return litellm.encode(model=self.name, text=text)