get_weak_model

This commit is contained in:
Paul Gauthier 2024-04-18 13:55:43 -07:00
parent 93f4a46996
commit cf2a48b21f
3 changed files with 10 additions and 11 deletions

View file

@ -212,7 +212,7 @@ class Coder:
self.io.tool_output(f"Added {fname} to the chat.") self.io.tool_output(f"Added {fname} to the chat.")
self.summarizer = ChatSummary( self.summarizer = ChatSummary(
self.main_model.weak_model(), self.main_model.get_weak_model(),
self.main_model.max_chat_history_tokens, self.main_model.max_chat_history_tokens,
) )

View file

@ -124,7 +124,7 @@ def main():
assistant.append(line) assistant.append(line)
summarizer = ChatSummary(models.Model.weak_model()) summarizer = ChatSummary(models.Model.get_weak_model())
summary = summarizer.summarize(messages[-40:]) summary = summarizer.summarize(messages[-40:])
dump(summary) dump(summary)

View file

@ -6,7 +6,7 @@ from typing import Optional
import litellm import litellm
from PIL import Image from PIL import Image
from aider.dump import dump from aider.dump import dump # noqa: F401
DEFAULT_MODEL_NAME = "gpt-4-1106-preview" DEFAULT_MODEL_NAME = "gpt-4-1106-preview"
@ -114,11 +114,13 @@ class Model:
name = None name = None
weak_model_name = "gpt-3.5-turbo-0125" weak_model_name = "gpt-3.5-turbo-0125"
edit_format = "whole" edit_format = "whole"
use_repo_map = False use_repo_map = False
send_undo_reply = False send_undo_reply = False
max_chat_history_tokens = 1024 max_chat_history_tokens = 1024
weak_model = None
def __init__(self, model): def __init__(self, model):
self.name = model self.name = model
@ -128,8 +130,6 @@ class Model:
except Exception as err: except Exception as err:
raise NoModelInfo(str(err)) raise NoModelInfo(str(err))
dump(model, self.info)
if self.info.get("max_input_tokens", 0) < 32 * 1024: if self.info.get("max_input_tokens", 0) < 32 * 1024:
self.max_chat_history_tokens = 1024 self.max_chat_history_tokens = 1024
else: else:
@ -159,14 +159,13 @@ class Model:
def __str__(self): def __str__(self):
return self.name return self.name
def weak_model(self): def get_weak_model(self):
if self.name == self.weak_model_name: if not self.weak_model:
return self self.weak_model = Model(self.weak_model_name)
return self.weak_model
return Model(self.weak_model_name)
def commit_message_models(self): def commit_message_models(self):
return [self.weak_model()] return [self.get_weak_model()]
def tokenizer(self, text): def tokenizer(self, text):
return litellm.encode(model=self.name, text=text) return litellm.encode(model=self.name, text=text)