aider.repo.simple_send_with_retries

This commit is contained in:
Paul Gauthier 2023-07-21 16:20:27 -03:00
parent e2a32fec7e
commit 661a521693
3 changed files with 23 additions and 32 deletions

View file

@ -2,10 +2,9 @@ import os
from pathlib import Path, PurePosixPath from pathlib import Path, PurePosixPath
import git import git
import openai
from aider import models, prompts, utils from aider import models, prompts, utils
from aider.sendchat import send_with_retries from aider.sendchat import simple_send_with_retries
from .dump import dump # noqa: F401 from .dump import dump # noqa: F401
@ -104,19 +103,10 @@ class GitRepo:
dict(role="user", content=content), dict(role="user", content=content),
] ]
commit_message = None
for model in [models.GPT35.name, models.GPT35_16k.name]: for model in [models.GPT35.name, models.GPT35_16k.name]:
try: commit_message = simple_send_with_retries(model, messages)
_hash, response = send_with_retries( if commit_message:
model=models.GPT35.name,
messages=messages,
functions=None,
stream=False,
)
commit_message = response.choices[0].message.content
break break
except (AttributeError, openai.error.InvalidRequestError):
pass
if not commit_message: if not commit_message:
self.io.tool_error("Failed to generate commit message!") self.io.tool_error("Failed to generate commit message!")

View file

@ -42,3 +42,16 @@ def send_with_retries(model, messages, functions, stream):
res = openai.ChatCompletion.create(**kwargs) res = openai.ChatCompletion.create(**kwargs)
return hash_object, res return hash_object, res
def simple_send_with_retries(model, messages):
try:
_hash, response = send_with_retries(
model=model,
messages=messages,
functions=None,
stream=False,
)
return response.choices[0].message.content
except (AttributeError, openai.error.InvalidRequestError):
return

View file

@ -2,7 +2,7 @@ import os
import tempfile import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
from unittest.mock import MagicMock, patch from unittest.mock import patch
import git import git
@ -12,13 +12,9 @@ from aider.repo import GitRepo
class TestRepo(unittest.TestCase): class TestRepo(unittest.TestCase):
@patch("aider.repo.send_with_retries") @patch("aider.repo.simple_send_with_retries")
def test_get_commit_message(self, mock_send): def test_get_commit_message(self, mock_send):
# Set the return value of the mocked function mock_send.return_value = "a good commit message"
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "a good commit message"
mock_send.return_value = (None, mock_response)
repo = GitRepo(InputOutput(), None) repo = GitRepo(InputOutput(), None)
# Call the get_commit_message method with dummy diff and context # Call the get_commit_message method with dummy diff and context
@ -27,13 +23,9 @@ class TestRepo(unittest.TestCase):
# Assert that the returned message is the expected one # Assert that the returned message is the expected one
self.assertEqual(result, "a good commit message") self.assertEqual(result, "a good commit message")
@patch("aider.repo.send_with_retries") @patch("aider.repo.simple_send_with_retries")
def test_get_commit_message_strip_quotes(self, mock_send): def test_get_commit_message_strip_quotes(self, mock_send):
# Set the return value of the mocked function mock_send.return_value = '"a good commit message"'
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = '"a good commit message"'
mock_send.return_value = (None, mock_response)
repo = GitRepo(InputOutput(), None) repo = GitRepo(InputOutput(), None)
# Call the get_commit_message method with dummy diff and context # Call the get_commit_message method with dummy diff and context
@ -42,13 +34,9 @@ class TestRepo(unittest.TestCase):
# Assert that the returned message is the expected one # Assert that the returned message is the expected one
self.assertEqual(result, "a good commit message") self.assertEqual(result, "a good commit message")
@patch("aider.repo.send_with_retries") @patch("aider.repo.simple_send_with_retries")
def test_get_commit_message_no_strip_unmatched_quotes(self, mock_send): def test_get_commit_message_no_strip_unmatched_quotes(self, mock_send):
# Set the return value of the mocked function mock_send.return_value = 'a good "commit message"'
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = 'a good "commit message"'
mock_send.return_value = (None, mock_response)
repo = GitRepo(InputOutput(), None) repo = GitRepo(InputOutput(), None)
# Call the get_commit_message method with dummy diff and context # Call the get_commit_message method with dummy diff and context