mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-24 22:34:59 +00:00
parent
b107db98fa
commit
57ab2cc9da
8 changed files with 13 additions and 21 deletions
|
@ -589,7 +589,7 @@ class Coder:
|
||||||
|
|
||||||
def send(self, messages, model=None, functions=None):
|
def send(self, messages, model=None, functions=None):
|
||||||
if not model:
|
if not model:
|
||||||
model = self.main_model
|
model = self.main_model.name
|
||||||
|
|
||||||
self.partial_response_content = ""
|
self.partial_response_content = ""
|
||||||
self.partial_response_function_call = dict()
|
self.partial_response_function_call = dict()
|
||||||
|
|
|
@ -85,7 +85,7 @@ class ChatSummary:
|
||||||
dict(role="user", content=content),
|
dict(role="user", content=content),
|
||||||
]
|
]
|
||||||
|
|
||||||
summary = simple_send_with_retries(self.client, self.model, messages)
|
summary = simple_send_with_retries(self.client, self.model.name, messages)
|
||||||
if summary is None:
|
if summary is None:
|
||||||
raise ValueError(f"summarizer unexpectedly failed for {self.model.name}")
|
raise ValueError(f"summarizer unexpectedly failed for {self.model.name}")
|
||||||
summary = prompts.summary_prefix + summary
|
summary = prompts.summary_prefix + summary
|
||||||
|
|
|
@ -189,6 +189,7 @@ def main(argv=None, input=None, output=None, force_git_root=None):
|
||||||
metavar="OPENAI_API_VERSION",
|
metavar="OPENAI_API_VERSION",
|
||||||
help="Specify the api_version",
|
help="Specify the api_version",
|
||||||
)
|
)
|
||||||
|
# TODO: use deployment_id
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--openai-api-deployment-id",
|
"--openai-api-deployment-id",
|
||||||
metavar="OPENAI_API_DEPLOYMENT_ID",
|
metavar="OPENAI_API_DEPLOYMENT_ID",
|
||||||
|
@ -506,9 +507,7 @@ def main(argv=None, input=None, output=None, force_git_root=None):
|
||||||
|
|
||||||
client = openai.OpenAI(api_key=args.openai_api_key, **kwargs)
|
client = openai.OpenAI(api_key=args.openai_api_key, **kwargs)
|
||||||
|
|
||||||
main_model = models.Model.create(
|
main_model = models.Model.create(args.model, client)
|
||||||
args.model, client, deployment_id=args.openai_api_deployment_id
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
coder = Coder.create(
|
coder = Coder.create(
|
||||||
|
|
|
@ -16,13 +16,13 @@ class Model:
|
||||||
completion_price = None
|
completion_price = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, name, client=None, deployment_id=None):
|
def create(cls, name, client=None):
|
||||||
from .openai import OpenAIModel
|
from .openai import OpenAIModel
|
||||||
from .openrouter import OpenRouterModel
|
from .openrouter import OpenRouterModel
|
||||||
|
|
||||||
if client and client.base_url.host == "openrouter.ai":
|
if client and client.base_url.host == "openrouter.ai":
|
||||||
return OpenRouterModel(client, name)
|
return OpenRouterModel(client, name)
|
||||||
return OpenAIModel(name, deployment_id=deployment_id)
|
return OpenAIModel(name)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.name
|
return self.name
|
||||||
|
|
|
@ -13,9 +13,8 @@ known_tokens = {
|
||||||
|
|
||||||
|
|
||||||
class OpenAIModel(Model):
|
class OpenAIModel(Model):
|
||||||
def __init__(self, name, deployment_id=None):
|
def __init__(self, name):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.deployment_id = deployment_id
|
|
||||||
|
|
||||||
tokens = None
|
tokens = None
|
||||||
|
|
||||||
|
|
|
@ -119,7 +119,7 @@ class GitRepo:
|
||||||
]
|
]
|
||||||
|
|
||||||
for model in models.Model.commit_message_models():
|
for model in models.Model.commit_message_models():
|
||||||
commit_message = simple_send_with_retries(self.client, model, messages)
|
commit_message = simple_send_with_retries(self.client, model.name, messages)
|
||||||
if commit_message:
|
if commit_message:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
|
@ -28,15 +28,10 @@ CACHE = None
|
||||||
f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
|
f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
def send_with_retries(client, model, messages, functions, stream):
|
def send_with_retries(client, model_name, messages, functions, stream):
|
||||||
if not client:
|
if not client:
|
||||||
raise ValueError("No openai client provided")
|
raise ValueError("No openai client provided")
|
||||||
|
|
||||||
if model.deployment_id:
|
|
||||||
model_name = model.deployment_id
|
|
||||||
else:
|
|
||||||
model_name = model.name
|
|
||||||
|
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
@ -62,11 +57,11 @@ def send_with_retries(client, model, messages, functions, stream):
|
||||||
return hash_object, res
|
return hash_object, res
|
||||||
|
|
||||||
|
|
||||||
def simple_send_with_retries(client, model, messages):
|
def simple_send_with_retries(client, model_name, messages):
|
||||||
try:
|
try:
|
||||||
_hash, response = send_with_retries(
|
_hash, response = send_with_retries(
|
||||||
client=client,
|
client=client,
|
||||||
model=model,
|
model_name=model_name,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
functions=None,
|
functions=None,
|
||||||
stream=False,
|
stream=False,
|
||||||
|
|
|
@ -4,7 +4,6 @@ from unittest.mock import MagicMock, patch
|
||||||
import httpx
|
import httpx
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
from aider.models import Model
|
|
||||||
from aider.sendchat import send_with_retries
|
from aider.sendchat import send_with_retries
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,7 +27,7 @@ class TestSendChat(unittest.TestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
# Call the send_with_retries method
|
# Call the send_with_retries method
|
||||||
send_with_retries(mock_client, Model.weak_model(), ["message"], None, False)
|
send_with_retries(mock_client, "model", ["message"], None, False)
|
||||||
mock_print.assert_called_once()
|
mock_print.assert_called_once()
|
||||||
|
|
||||||
@patch("aider.sendchat.openai.ChatCompletion.create")
|
@patch("aider.sendchat.openai.ChatCompletion.create")
|
||||||
|
@ -43,5 +42,5 @@ class TestSendChat(unittest.TestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
# Call the send_with_retries method
|
# Call the send_with_retries method
|
||||||
send_with_retries(mock_client, Model.weak_model(), ["message"], None, False)
|
send_with_retries(mock_client, "model", ["message"], None, False)
|
||||||
mock_print.assert_called_once()
|
mock_print.assert_called_once()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue