mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-29 16:54:59 +00:00
aider.repo.simple_send_with_retries
This commit is contained in:
parent
e2a32fec7e
commit
661a521693
3 changed files with 23 additions and 32 deletions
|
@ -2,10 +2,9 @@ import os
|
||||||
from pathlib import Path, PurePosixPath
|
from pathlib import Path, PurePosixPath
|
||||||
|
|
||||||
import git
|
import git
|
||||||
import openai
|
|
||||||
|
|
||||||
from aider import models, prompts, utils
|
from aider import models, prompts, utils
|
||||||
from aider.sendchat import send_with_retries
|
from aider.sendchat import simple_send_with_retries
|
||||||
|
|
||||||
from .dump import dump # noqa: F401
|
from .dump import dump # noqa: F401
|
||||||
|
|
||||||
|
@ -104,19 +103,10 @@ class GitRepo:
|
||||||
dict(role="user", content=content),
|
dict(role="user", content=content),
|
||||||
]
|
]
|
||||||
|
|
||||||
commit_message = None
|
|
||||||
for model in [models.GPT35.name, models.GPT35_16k.name]:
|
for model in [models.GPT35.name, models.GPT35_16k.name]:
|
||||||
try:
|
commit_message = simple_send_with_retries(model, messages)
|
||||||
_hash, response = send_with_retries(
|
if commit_message:
|
||||||
model=models.GPT35.name,
|
|
||||||
messages=messages,
|
|
||||||
functions=None,
|
|
||||||
stream=False,
|
|
||||||
)
|
|
||||||
commit_message = response.choices[0].message.content
|
|
||||||
break
|
break
|
||||||
except (AttributeError, openai.error.InvalidRequestError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not commit_message:
|
if not commit_message:
|
||||||
self.io.tool_error("Failed to generate commit message!")
|
self.io.tool_error("Failed to generate commit message!")
|
||||||
|
|
|
@ -42,3 +42,16 @@ def send_with_retries(model, messages, functions, stream):
|
||||||
|
|
||||||
res = openai.ChatCompletion.create(**kwargs)
|
res = openai.ChatCompletion.create(**kwargs)
|
||||||
return hash_object, res
|
return hash_object, res
|
||||||
|
|
||||||
|
|
||||||
|
def simple_send_with_retries(model, messages):
|
||||||
|
try:
|
||||||
|
_hash, response = send_with_retries(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
functions=None,
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
return response.choices[0].message.content
|
||||||
|
except (AttributeError, openai.error.InvalidRequestError):
|
||||||
|
return
|
||||||
|
|
|
@ -2,7 +2,7 @@ import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import git
|
import git
|
||||||
|
|
||||||
|
@ -12,13 +12,9 @@ from aider.repo import GitRepo
|
||||||
|
|
||||||
|
|
||||||
class TestRepo(unittest.TestCase):
|
class TestRepo(unittest.TestCase):
|
||||||
@patch("aider.repo.send_with_retries")
|
@patch("aider.repo.simple_send_with_retries")
|
||||||
def test_get_commit_message(self, mock_send):
|
def test_get_commit_message(self, mock_send):
|
||||||
# Set the return value of the mocked function
|
mock_send.return_value = "a good commit message"
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.choices = [MagicMock()]
|
|
||||||
mock_response.choices[0].message.content = "a good commit message"
|
|
||||||
mock_send.return_value = (None, mock_response)
|
|
||||||
|
|
||||||
repo = GitRepo(InputOutput(), None)
|
repo = GitRepo(InputOutput(), None)
|
||||||
# Call the get_commit_message method with dummy diff and context
|
# Call the get_commit_message method with dummy diff and context
|
||||||
|
@ -27,13 +23,9 @@ class TestRepo(unittest.TestCase):
|
||||||
# Assert that the returned message is the expected one
|
# Assert that the returned message is the expected one
|
||||||
self.assertEqual(result, "a good commit message")
|
self.assertEqual(result, "a good commit message")
|
||||||
|
|
||||||
@patch("aider.repo.send_with_retries")
|
@patch("aider.repo.simple_send_with_retries")
|
||||||
def test_get_commit_message_strip_quotes(self, mock_send):
|
def test_get_commit_message_strip_quotes(self, mock_send):
|
||||||
# Set the return value of the mocked function
|
mock_send.return_value = '"a good commit message"'
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.choices = [MagicMock()]
|
|
||||||
mock_response.choices[0].message.content = '"a good commit message"'
|
|
||||||
mock_send.return_value = (None, mock_response)
|
|
||||||
|
|
||||||
repo = GitRepo(InputOutput(), None)
|
repo = GitRepo(InputOutput(), None)
|
||||||
# Call the get_commit_message method with dummy diff and context
|
# Call the get_commit_message method with dummy diff and context
|
||||||
|
@ -42,13 +34,9 @@ class TestRepo(unittest.TestCase):
|
||||||
# Assert that the returned message is the expected one
|
# Assert that the returned message is the expected one
|
||||||
self.assertEqual(result, "a good commit message")
|
self.assertEqual(result, "a good commit message")
|
||||||
|
|
||||||
@patch("aider.repo.send_with_retries")
|
@patch("aider.repo.simple_send_with_retries")
|
||||||
def test_get_commit_message_no_strip_unmatched_quotes(self, mock_send):
|
def test_get_commit_message_no_strip_unmatched_quotes(self, mock_send):
|
||||||
# Set the return value of the mocked function
|
mock_send.return_value = 'a good "commit message"'
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.choices = [MagicMock()]
|
|
||||||
mock_response.choices[0].message.content = 'a good "commit message"'
|
|
||||||
mock_send.return_value = (None, mock_response)
|
|
||||||
|
|
||||||
repo = GitRepo(InputOutput(), None)
|
repo = GitRepo(InputOutput(), None)
|
||||||
# Call the get_commit_message method with dummy diff and context
|
# Call the get_commit_message method with dummy diff and context
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue