diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index 4c9f8eca9..7821ac202 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -589,7 +589,7 @@ class Coder: def send(self, messages, model=None, functions=None): if not model: - model = self.main_model.name + model = self.main_model self.partial_response_content = "" self.partial_response_function_call = dict() diff --git a/aider/history.py b/aider/history.py index d1ee70ede..9fdaf9c14 100644 --- a/aider/history.py +++ b/aider/history.py @@ -85,7 +85,7 @@ class ChatSummary: dict(role="user", content=content), ] - summary = simple_send_with_retries(self.client, self.model.name, messages) + summary = simple_send_with_retries(self.client, self.model, messages) if summary is None: raise ValueError(f"summarizer unexpectedly failed for {self.model.name}") summary = prompts.summary_prefix + summary diff --git a/aider/main.py b/aider/main.py index d4c8cb5c7..63ed0753d 100644 --- a/aider/main.py +++ b/aider/main.py @@ -189,7 +189,6 @@ def main(argv=None, input=None, output=None, force_git_root=None): metavar="OPENAI_API_VERSION", help="Specify the api_version", ) - # TODO: use deployment_id model_group.add_argument( "--openai-api-deployment-id", metavar="OPENAI_API_DEPLOYMENT_ID", @@ -507,7 +506,9 @@ def main(argv=None, input=None, output=None, force_git_root=None): client = openai.OpenAI(api_key=args.openai_api_key, **kwargs) - main_model = models.Model.create(args.model, client) + main_model = models.Model.create( + args.model, client, deployment_id=args.openai_api_deployment_id + ) try: coder = Coder.create( diff --git a/aider/models/model.py b/aider/models/model.py index 70f09d313..7eb3be88c 100644 --- a/aider/models/model.py +++ b/aider/models/model.py @@ -16,13 +16,13 @@ class Model: completion_price = None @classmethod - def create(cls, name, client=None): + def create(cls, name, client=None, deployment_id=None): from .openai import OpenAIModel from .openrouter import OpenRouterModel if client and client.base_url.host == "openrouter.ai": return OpenRouterModel(client, name) - return OpenAIModel(name) + return OpenAIModel(name, deployment_id=deployment_id) def __str__(self): return self.name diff --git a/aider/models/openai.py b/aider/models/openai.py index 1c6286d63..435048868 100644 --- a/aider/models/openai.py +++ b/aider/models/openai.py @@ -13,8 +13,9 @@ known_tokens = { class OpenAIModel(Model): - def __init__(self, name): + def __init__(self, name, deployment_id=None): self.name = name + self.deployment_id = deployment_id tokens = None diff --git a/aider/repo.py b/aider/repo.py index 6943c5568..7fd096984 100644 --- a/aider/repo.py +++ b/aider/repo.py @@ -119,7 +119,7 @@ class GitRepo: ] for model in models.Model.commit_message_models(): - commit_message = simple_send_with_retries(self.client, model.name, messages) + commit_message = simple_send_with_retries(self.client, model, messages) if commit_message: break diff --git a/aider/sendchat.py b/aider/sendchat.py index c770ef087..baba6e682 100644 --- a/aider/sendchat.py +++ b/aider/sendchat.py @@ -28,10 +28,15 @@ CACHE = None f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds." ), ) -def send_with_retries(client, model_name, messages, functions, stream): +def send_with_retries(client, model, messages, functions, stream): if not client: raise ValueError("No openai client provided") + if model.deployment_id: + model_name = model.deployment_id + else: + model_name = model.name + kwargs = dict( model=model_name, messages=messages, @@ -57,11 +62,11 @@ def send_with_retries(client, model_name, messages, functions, stream): return hash_object, res -def simple_send_with_retries(client, model_name, messages): +def simple_send_with_retries(client, model, messages): try: _hash, response = send_with_retries( client=client, - model_name=model_name, + model=model, messages=messages, functions=None, stream=False, diff --git a/tests/test_sendchat.py b/tests/test_sendchat.py index 7bb8fcfab..2613d1f3a 100644 --- a/tests/test_sendchat.py +++ b/tests/test_sendchat.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch import httpx import openai +from aider.models import Model from aider.sendchat import send_with_retries @@ -27,7 +28,7 @@ class TestSendChat(unittest.TestCase): ] # Call the send_with_retries method - send_with_retries(mock_client, "model", ["message"], None, False) + send_with_retries(mock_client, Model.weak_model(), ["message"], None, False) mock_print.assert_called_once() @patch("aider.sendchat.openai.ChatCompletion.create") @@ -42,5 +43,5 @@ class TestSendChat(unittest.TestCase): ] # Call the send_with_retries method - send_with_retries(mock_client, "model", ["message"], None, False) + send_with_retries(mock_client, Model.weak_model(), ["message"], None, False) mock_print.assert_called_once()