Revert "implement deployment id"

This reverts commit b107db98fa.
Paul Gauthier 2023-12-06 09:20:53 -08:00
parent b107db98fa
commit 57ab2cc9da
8 changed files with 13 additions and 21 deletions
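For orientation, every hunk below makes the same kind of change: the send helpers go back to accepting a plain model-name string, and the deployment_id plumbing is stripped from the model classes, so callers pass model.name instead of a Model object. A minimal sketch of the post-revert call path, assuming the aider.models / aider.sendchat layout implied by the hunks; the API key, model name, and message content are placeholders, and a real key is needed for the request to actually succeed.

import openai

from aider import models
from aider.sendchat import simple_send_with_retries

# Mirrors main(): build an OpenAI client, create the model, then send.
client = openai.OpenAI(api_key="sk-placeholder")    # placeholder key
main_model = models.Model.create("gpt-4", client)   # no deployment_id kwarg any more

messages = [dict(role="user", content="hello")]
# The helpers now take a model-name string, so callers pass model.name:
reply = simple_send_with_retries(client, main_model.name, messages)
print(reply)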

View file

@@ -589,7 +589,7 @@ class Coder:
     def send(self, messages, model=None, functions=None):
         if not model:
-            model = self.main_model
+            model = self.main_model.name

         self.partial_response_content = ""
         self.partial_response_function_call = dict()

View file

@@ -85,7 +85,7 @@ class ChatSummary:
             dict(role="user", content=content),
         ]

-        summary = simple_send_with_retries(self.client, self.model, messages)
+        summary = simple_send_with_retries(self.client, self.model.name, messages)
         if summary is None:
             raise ValueError(f"summarizer unexpectedly failed for {self.model.name}")
         summary = prompts.summary_prefix + summary

View file

@@ -189,6 +189,7 @@ def main(argv=None, input=None, output=None, force_git_root=None):
         metavar="OPENAI_API_VERSION",
         help="Specify the api_version",
     )
+    # TODO: use deployment_id
     model_group.add_argument(
         "--openai-api-deployment-id",
         metavar="OPENAI_API_DEPLOYMENT_ID",
@@ -506,9 +507,7 @@ def main(argv=None, input=None, output=None, force_git_root=None):
     client = openai.OpenAI(api_key=args.openai_api_key, **kwargs)

-    main_model = models.Model.create(
-        args.model, client, deployment_id=args.openai_api_deployment_id
-    )
+    main_model = models.Model.create(args.model, client)

     try:
         coder = Coder.create(

View file

@@ -16,13 +16,13 @@ class Model:
     completion_price = None

     @classmethod
-    def create(cls, name, client=None, deployment_id=None):
+    def create(cls, name, client=None):
         from .openai import OpenAIModel
         from .openrouter import OpenRouterModel

         if client and client.base_url.host == "openrouter.ai":
             return OpenRouterModel(client, name)
-        return OpenAIModel(name, deployment_id=deployment_id)
+        return OpenAIModel(name)

     def __str__(self):
         return self.name
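The factory's dispatch now hinges only on the host of the client's base URL. A small sketch of how that check evaluates, assuming openai>=1.0 (where client.base_url is an httpx.URL exposing .host); the key and the OpenRouter-style base URL are placeholders, not taken from this commit.

import openai

# Constructing the client makes no request, so placeholder credentials are fine here.
client = openai.OpenAI(api_key="or-placeholder", base_url="https://openrouter.ai/api/v1")

print(client.base_url.host)                      # openrouter.ai
print(client.base_url.host == "openrouter.ai")   # True -> Model.create returns OpenRouterModel
# With a default client (no base_url override) the check is False and
# Model.create falls through to OpenAIModel(name).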

View file

@@ -13,9 +13,8 @@ known_tokens = {
 class OpenAIModel(Model):
-    def __init__(self, name, deployment_id=None):
+    def __init__(self, name):
         self.name = name
-        self.deployment_id = deployment_id

         tokens = None

View file

@@ -119,7 +119,7 @@ class GitRepo:
         ]

         for model in models.Model.commit_message_models():
-            commit_message = simple_send_with_retries(self.client, model, messages)
+            commit_message = simple_send_with_retries(self.client, model.name, messages)
             if commit_message:
                 break

View file

@@ -28,15 +28,10 @@ CACHE = None
         f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
     ),
 )
-def send_with_retries(client, model, messages, functions, stream):
+def send_with_retries(client, model_name, messages, functions, stream):
     if not client:
         raise ValueError("No openai client provided")

-    if model.deployment_id:
-        model_name = model.deployment_id
-    else:
-        model_name = model.name
-
     kwargs = dict(
         model=model_name,
         messages=messages,
@@ -62,11 +57,11 @@ def send_with_retries(client, model, messages, functions, stream):
     return hash_object, res


-def simple_send_with_retries(client, model, messages):
+def simple_send_with_retries(client, model_name, messages):
     try:
         _hash, response = send_with_retries(
             client=client,
-            model=model,
+            model_name=model_name,
             messages=messages,
             functions=None,
             stream=False,
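Put back together, the post-revert helpers expose a purely string-based interface; the keyword form below mirrors the call inside simple_send_with_retries. A sketch only: the client, key, and model name are placeholders, and a real key is needed for the request to succeed.

import openai

from aider.sendchat import send_with_retries

client = openai.OpenAI(api_key="sk-placeholder")   # placeholder key

# send_with_retries returns (hash_object, response), as in the hunk above.
hash_object, response = send_with_retries(
    client=client,
    model_name="gpt-3.5-turbo",                    # plain string; no deployment_id lookup
    messages=[dict(role="user", content="hello")],
    functions=None,
    stream=False,
)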

View file

@@ -4,7 +4,6 @@ from unittest.mock import MagicMock, patch
 import httpx
 import openai

-from aider.models import Model
 from aider.sendchat import send_with_retries
@@ -28,7 +27,7 @@ class TestSendChat(unittest.TestCase):
         ]

         # Call the send_with_retries method
-        send_with_retries(mock_client, Model.weak_model(), ["message"], None, False)
+        send_with_retries(mock_client, "model", ["message"], None, False)
         mock_print.assert_called_once()

     @patch("aider.sendchat.openai.ChatCompletion.create")
@@ -43,5 +42,5 @@ class TestSendChat(unittest.TestCase):
         ]

         # Call the send_with_retries method
-        send_with_retries(mock_client, Model.weak_model(), ["message"], None, False)
+        send_with_retries(mock_client, "model", ["message"], None, False)
         mock_print.assert_called_once()
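Correspondingly, the tests no longer need a Model object at all. A condensed sketch of the post-revert call shape, with the mock side effects, patched retry timing, and assertions on the printed retry message omitted:

from unittest.mock import MagicMock

from aider.sendchat import send_with_retries

# A bare string stands in for the model; Model.weak_model() is no longer required.
mock_client = MagicMock()
send_with_retries(mock_client, "model", ["message"], None, False)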