implement deployment id

This commit is contained in:
Paul Gauthier 2023-12-05 11:31:17 -08:00
parent bf03f43b44
commit b107db98fa
8 changed files with 21 additions and 13 deletions

View file

@@ -589,7 +589,7 @@ class Coder:
def send(self, messages, model=None, functions=None): def send(self, messages, model=None, functions=None):
if not model: if not model:
model = self.main_model.name model = self.main_model
self.partial_response_content = "" self.partial_response_content = ""
self.partial_response_function_call = dict() self.partial_response_function_call = dict()

View file

@@ -85,7 +85,7 @@ class ChatSummary:
dict(role="user", content=content), dict(role="user", content=content),
] ]
summary = simple_send_with_retries(self.client, self.model.name, messages) summary = simple_send_with_retries(self.client, self.model, messages)
if summary is None: if summary is None:
raise ValueError(f"summarizer unexpectedly failed for {self.model.name}") raise ValueError(f"summarizer unexpectedly failed for {self.model.name}")
summary = prompts.summary_prefix + summary summary = prompts.summary_prefix + summary

View file

@@ -189,7 +189,6 @@ def main(argv=None, input=None, output=None, force_git_root=None):
metavar="OPENAI_API_VERSION", metavar="OPENAI_API_VERSION",
help="Specify the api_version", help="Specify the api_version",
) )
# TODO: use deployment_id
model_group.add_argument( model_group.add_argument(
"--openai-api-deployment-id", "--openai-api-deployment-id",
metavar="OPENAI_API_DEPLOYMENT_ID", metavar="OPENAI_API_DEPLOYMENT_ID",
@@ -507,7 +506,9 @@ def main(argv=None, input=None, output=None, force_git_root=None):
client = openai.OpenAI(api_key=args.openai_api_key, **kwargs) client = openai.OpenAI(api_key=args.openai_api_key, **kwargs)
main_model = models.Model.create(args.model, client) main_model = models.Model.create(
args.model, client, deployment_id=args.openai_api_deployment_id
)
try: try:
coder = Coder.create( coder = Coder.create(

View file

@@ -16,13 +16,13 @@ class Model:
completion_price = None completion_price = None
@classmethod @classmethod
def create(cls, name, client=None): def create(cls, name, client=None, deployment_id=None):
from .openai import OpenAIModel from .openai import OpenAIModel
from .openrouter import OpenRouterModel from .openrouter import OpenRouterModel
if client and client.base_url.host == "openrouter.ai": if client and client.base_url.host == "openrouter.ai":
return OpenRouterModel(client, name) return OpenRouterModel(client, name)
return OpenAIModel(name) return OpenAIModel(name, deployment_id=deployment_id)
def __str__(self): def __str__(self):
return self.name return self.name

View file

@@ -13,8 +13,9 @@ known_tokens = {
class OpenAIModel(Model): class OpenAIModel(Model):
def __init__(self, name): def __init__(self, name, deployment_id=None):
self.name = name self.name = name
self.deployment_id = deployment_id
tokens = None tokens = None

View file

@@ -119,7 +119,7 @@ class GitRepo:
] ]
for model in models.Model.commit_message_models(): for model in models.Model.commit_message_models():
commit_message = simple_send_with_retries(self.client, model.name, messages) commit_message = simple_send_with_retries(self.client, model, messages)
if commit_message: if commit_message:
break break

View file

@@ -28,10 +28,15 @@ CACHE = None
f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds." f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
), ),
) )
def send_with_retries(client, model_name, messages, functions, stream): def send_with_retries(client, model, messages, functions, stream):
if not client: if not client:
raise ValueError("No openai client provided") raise ValueError("No openai client provided")
if model.deployment_id:
model_name = model.deployment_id
else:
model_name = model.name
kwargs = dict( kwargs = dict(
model=model_name, model=model_name,
messages=messages, messages=messages,
@@ -57,11 +62,11 @@ def send_with_retries(client, model_name, messages, functions, stream):
return hash_object, res return hash_object, res
def simple_send_with_retries(client, model_name, messages): def simple_send_with_retries(client, model, messages):
try: try:
_hash, response = send_with_retries( _hash, response = send_with_retries(
client=client, client=client,
model_name=model_name, model=model,
messages=messages, messages=messages,
functions=None, functions=None,
stream=False, stream=False,

View file

@@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
import httpx import httpx
import openai import openai
from aider.models import Model
from aider.sendchat import send_with_retries from aider.sendchat import send_with_retries
@@ -27,7 +28,7 @@ class TestSendChat(unittest.TestCase):
] ]
# Call the send_with_retries method # Call the send_with_retries method
send_with_retries(mock_client, "model", ["message"], None, False) send_with_retries(mock_client, Model.weak_model(), ["message"], None, False)
mock_print.assert_called_once() mock_print.assert_called_once()
@patch("aider.sendchat.openai.ChatCompletion.create") @patch("aider.sendchat.openai.ChatCompletion.create")
@@ -42,5 +43,5 @@ class TestSendChat(unittest.TestCase):
] ]
# Call the send_with_retries method # Call the send_with_retries method
send_with_retries(mock_client, "model", ["message"], None, False) send_with_retries(mock_client, Model.weak_model(), ["message"], None, False)
mock_print.assert_called_once() mock_print.assert_called_once()