mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-24 22:34:59 +00:00
parent
b107db98fa
commit
57ab2cc9da
8 changed files with 13 additions and 21 deletions
|
@ -589,7 +589,7 @@ class Coder:
|
|||
|
||||
def send(self, messages, model=None, functions=None):
|
||||
if not model:
|
||||
model = self.main_model
|
||||
model = self.main_model.name
|
||||
|
||||
self.partial_response_content = ""
|
||||
self.partial_response_function_call = dict()
|
||||
|
|
|
@ -85,7 +85,7 @@ class ChatSummary:
|
|||
dict(role="user", content=content),
|
||||
]
|
||||
|
||||
summary = simple_send_with_retries(self.client, self.model, messages)
|
||||
summary = simple_send_with_retries(self.client, self.model.name, messages)
|
||||
if summary is None:
|
||||
raise ValueError(f"summarizer unexpectedly failed for {self.model.name}")
|
||||
summary = prompts.summary_prefix + summary
|
||||
|
|
|
@ -189,6 +189,7 @@ def main(argv=None, input=None, output=None, force_git_root=None):
|
|||
metavar="OPENAI_API_VERSION",
|
||||
help="Specify the api_version",
|
||||
)
|
||||
# TODO: use deployment_id
|
||||
model_group.add_argument(
|
||||
"--openai-api-deployment-id",
|
||||
metavar="OPENAI_API_DEPLOYMENT_ID",
|
||||
|
@ -506,9 +507,7 @@ def main(argv=None, input=None, output=None, force_git_root=None):
|
|||
|
||||
client = openai.OpenAI(api_key=args.openai_api_key, **kwargs)
|
||||
|
||||
main_model = models.Model.create(
|
||||
args.model, client, deployment_id=args.openai_api_deployment_id
|
||||
)
|
||||
main_model = models.Model.create(args.model, client)
|
||||
|
||||
try:
|
||||
coder = Coder.create(
|
||||
|
|
|
@ -16,13 +16,13 @@ class Model:
|
|||
completion_price = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, name, client=None, deployment_id=None):
|
||||
def create(cls, name, client=None):
|
||||
from .openai import OpenAIModel
|
||||
from .openrouter import OpenRouterModel
|
||||
|
||||
if client and client.base_url.host == "openrouter.ai":
|
||||
return OpenRouterModel(client, name)
|
||||
return OpenAIModel(name, deployment_id=deployment_id)
|
||||
return OpenAIModel(name)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
|
|
@ -13,9 +13,8 @@ known_tokens = {
|
|||
|
||||
|
||||
class OpenAIModel(Model):
|
||||
def __init__(self, name, deployment_id=None):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.deployment_id = deployment_id
|
||||
|
||||
tokens = None
|
||||
|
||||
|
|
|
@ -119,7 +119,7 @@ class GitRepo:
|
|||
]
|
||||
|
||||
for model in models.Model.commit_message_models():
|
||||
commit_message = simple_send_with_retries(self.client, model, messages)
|
||||
commit_message = simple_send_with_retries(self.client, model.name, messages)
|
||||
if commit_message:
|
||||
break
|
||||
|
||||
|
|
|
@ -28,15 +28,10 @@ CACHE = None
|
|||
f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
|
||||
),
|
||||
)
|
||||
def send_with_retries(client, model, messages, functions, stream):
|
||||
def send_with_retries(client, model_name, messages, functions, stream):
|
||||
if not client:
|
||||
raise ValueError("No openai client provided")
|
||||
|
||||
if model.deployment_id:
|
||||
model_name = model.deployment_id
|
||||
else:
|
||||
model_name = model.name
|
||||
|
||||
kwargs = dict(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
|
@ -62,11 +57,11 @@ def send_with_retries(client, model, messages, functions, stream):
|
|||
return hash_object, res
|
||||
|
||||
|
||||
def simple_send_with_retries(client, model, messages):
|
||||
def simple_send_with_retries(client, model_name, messages):
|
||||
try:
|
||||
_hash, response = send_with_retries(
|
||||
client=client,
|
||||
model=model,
|
||||
model_name=model_name,
|
||||
messages=messages,
|
||||
functions=None,
|
||||
stream=False,
|
||||
|
|
|
@ -4,7 +4,6 @@ from unittest.mock import MagicMock, patch
|
|||
import httpx
|
||||
import openai
|
||||
|
||||
from aider.models import Model
|
||||
from aider.sendchat import send_with_retries
|
||||
|
||||
|
||||
|
@ -28,7 +27,7 @@ class TestSendChat(unittest.TestCase):
|
|||
]
|
||||
|
||||
# Call the send_with_retries method
|
||||
send_with_retries(mock_client, Model.weak_model(), ["message"], None, False)
|
||||
send_with_retries(mock_client, "model", ["message"], None, False)
|
||||
mock_print.assert_called_once()
|
||||
|
||||
@patch("aider.sendchat.openai.ChatCompletion.create")
|
||||
|
@ -43,5 +42,5 @@ class TestSendChat(unittest.TestCase):
|
|||
]
|
||||
|
||||
# Call the send_with_retries method
|
||||
send_with_retries(mock_client, Model.weak_model(), ["message"], None, False)
|
||||
send_with_retries(mock_client, "model", ["message"], None, False)
|
||||
mock_print.assert_called_once()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue