diff --git a/aider/repo.py b/aider/repo.py index 0b01e9978..5ae06baed 100644 --- a/aider/repo.py +++ b/aider/repo.py @@ -2,10 +2,9 @@ import os from pathlib import Path, PurePosixPath import git -import openai from aider import models, prompts, utils -from aider.sendchat import send_with_retries +from aider.sendchat import simple_send_with_retries from .dump import dump # noqa: F401 @@ -104,19 +103,10 @@ class GitRepo: dict(role="user", content=content), ] - commit_message = None for model in [models.GPT35.name, models.GPT35_16k.name]: - try: - _hash, response = send_with_retries( - model=models.GPT35.name, - messages=messages, - functions=None, - stream=False, - ) - commit_message = response.choices[0].message.content + commit_message = simple_send_with_retries(model, messages) + if commit_message: break - except (AttributeError, openai.error.InvalidRequestError): - pass if not commit_message: self.io.tool_error("Failed to generate commit message!") diff --git a/aider/sendchat.py b/aider/sendchat.py index 9a4414230..441447863 100644 --- a/aider/sendchat.py +++ b/aider/sendchat.py @@ -42,3 +42,16 @@ def send_with_retries(model, messages, functions, stream): res = openai.ChatCompletion.create(**kwargs) return hash_object, res + + +def simple_send_with_retries(model, messages): + try: + _hash, response = send_with_retries( + model=model, + messages=messages, + functions=None, + stream=False, + ) + return response.choices[0].message.content + except (AttributeError, openai.error.InvalidRequestError): + return diff --git a/tests/test_repo.py b/tests/test_repo.py index ca0985886..67983a134 100644 --- a/tests/test_repo.py +++ b/tests/test_repo.py @@ -2,7 +2,7 @@ import os import tempfile import unittest from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import patch import git @@ -12,13 +12,9 @@ from aider.repo import GitRepo class TestRepo(unittest.TestCase): - @patch("aider.repo.send_with_retries") + @patch("aider.repo.simple_send_with_retries") def test_get_commit_message(self, mock_send): - # Set the return value of the mocked function - mock_response = MagicMock() - mock_response.choices = [MagicMock()] - mock_response.choices[0].message.content = "a good commit message" - mock_send.return_value = (None, mock_response) + mock_send.return_value = "a good commit message" repo = GitRepo(InputOutput(), None) # Call the get_commit_message method with dummy diff and context @@ -27,13 +23,9 @@ class TestRepo(unittest.TestCase): # Assert that the returned message is the expected one self.assertEqual(result, "a good commit message") - @patch("aider.repo.send_with_retries") + @patch("aider.repo.simple_send_with_retries") def test_get_commit_message_strip_quotes(self, mock_send): - # Set the return value of the mocked function - mock_response = MagicMock() - mock_response.choices = [MagicMock()] - mock_response.choices[0].message.content = '"a good commit message"' - mock_send.return_value = (None, mock_response) + mock_send.return_value = '"a good commit message"' repo = GitRepo(InputOutput(), None) # Call the get_commit_message method with dummy diff and context @@ -42,13 +34,9 @@ class TestRepo(unittest.TestCase): # Assert that the returned message is the expected one self.assertEqual(result, "a good commit message") - @patch("aider.repo.send_with_retries") + @patch("aider.repo.simple_send_with_retries") def test_get_commit_message_no_strip_unmatched_quotes(self, mock_send): - # Set the return value of the mocked function - mock_response = MagicMock() - mock_response.choices = [MagicMock()] - mock_response.choices[0].message.content = 'a good "commit message"' - mock_send.return_value = (None, mock_response) + mock_send.return_value = 'a good "commit message"' repo = GitRepo(InputOutput(), None) # Call the get_commit_message method with dummy diff and context