implement deployment id

This commit is contained in:
Paul Gauthier 2023-12-05 11:31:17 -08:00
parent bf03f43b44
commit b107db98fa
8 changed files with 21 additions and 13 deletions

View file

@ -589,7 +589,7 @@ class Coder:
def send(self, messages, model=None, functions=None):
if not model:
model = self.main_model.name
model = self.main_model
self.partial_response_content = ""
self.partial_response_function_call = dict()

View file

@ -85,7 +85,7 @@ class ChatSummary:
dict(role="user", content=content),
]
summary = simple_send_with_retries(self.client, self.model.name, messages)
summary = simple_send_with_retries(self.client, self.model, messages)
if summary is None:
raise ValueError(f"summarizer unexpectedly failed for {self.model.name}")
summary = prompts.summary_prefix + summary

View file

@ -189,7 +189,6 @@ def main(argv=None, input=None, output=None, force_git_root=None):
metavar="OPENAI_API_VERSION",
help="Specify the api_version",
)
# TODO: use deployment_id
model_group.add_argument(
"--openai-api-deployment-id",
metavar="OPENAI_API_DEPLOYMENT_ID",
@ -507,7 +506,9 @@ def main(argv=None, input=None, output=None, force_git_root=None):
client = openai.OpenAI(api_key=args.openai_api_key, **kwargs)
main_model = models.Model.create(args.model, client)
main_model = models.Model.create(
args.model, client, deployment_id=args.openai_api_deployment_id
)
try:
coder = Coder.create(

View file

@ -16,13 +16,13 @@ class Model:
completion_price = None
@classmethod
def create(cls, name, client=None):
def create(cls, name, client=None, deployment_id=None):
from .openai import OpenAIModel
from .openrouter import OpenRouterModel
if client and client.base_url.host == "openrouter.ai":
return OpenRouterModel(client, name)
return OpenAIModel(name)
return OpenAIModel(name, deployment_id=deployment_id)
def __str__(self):
return self.name

View file

@ -13,8 +13,9 @@ known_tokens = {
class OpenAIModel(Model):
def __init__(self, name):
def __init__(self, name, deployment_id=None):
self.name = name
self.deployment_id = deployment_id
tokens = None

View file

@ -119,7 +119,7 @@ class GitRepo:
]
for model in models.Model.commit_message_models():
commit_message = simple_send_with_retries(self.client, model.name, messages)
commit_message = simple_send_with_retries(self.client, model, messages)
if commit_message:
break

View file

@ -28,10 +28,15 @@ CACHE = None
f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
),
)
def send_with_retries(client, model_name, messages, functions, stream):
def send_with_retries(client, model, messages, functions, stream):
if not client:
raise ValueError("No openai client provided")
if model.deployment_id:
model_name = model.deployment_id
else:
model_name = model.name
kwargs = dict(
model=model_name,
messages=messages,
@ -57,11 +62,11 @@ def send_with_retries(client, model_name, messages, functions, stream):
return hash_object, res
def simple_send_with_retries(client, model_name, messages):
def simple_send_with_retries(client, model, messages):
try:
_hash, response = send_with_retries(
client=client,
model_name=model_name,
model=model,
messages=messages,
functions=None,
stream=False,

View file

@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
import httpx
import openai
from aider.models import Model
from aider.sendchat import send_with_retries
@ -27,7 +28,7 @@ class TestSendChat(unittest.TestCase):
]
# Call the send_with_retries method
send_with_retries(mock_client, "model", ["message"], None, False)
send_with_retries(mock_client, Model.weak_model(), ["message"], None, False)
mock_print.assert_called_once()
@patch("aider.sendchat.openai.ChatCompletion.create")
@ -42,5 +43,5 @@ class TestSendChat(unittest.TestCase):
]
# Call the send_with_retries method
send_with_retries(mock_client, "model", ["message"], None, False)
send_with_retries(mock_client, Model.weak_model(), ["message"], None, False)
mock_print.assert_called_once()