Merge branch 'main' into call-graph

Paul Gauthier committed 2023-06-03 17:43:45 -07:00
commit 05e3d2bfdb
2 changed files with 67 additions and 15 deletions


@@ -8,6 +8,7 @@ from pathlib import Path
 import git
 import openai
+import requests
 from openai.error import RateLimitError
 from rich.console import Console
 from rich.live import Live
@@ -390,6 +391,24 @@ class Coder:
         return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames))

+    def send_with_retries(self, model, messages):
+        while True:
+            try:
+                return openai.ChatCompletion.create(
+                    model=model,
+                    messages=messages,
+                    temperature=0,
+                    stream=True,
+                )
+            except RateLimitError as err:
+                self.io.tool_error(f"RateLimitError: {err}")
+            except requests.exceptions.ConnectionError as err:
+                self.io.tool_error(f"ConnectionError: {err}")
+
+            retry_after = 1
+            self.io.tool_error(f"Retry in {retry_after} seconds.")
+            time.sleep(retry_after)
+
     def send(self, messages, model=None, silent=False):
         if not model:
             model = self.main_model
@@ -397,21 +416,7 @@ class Coder:
         self.resp = ""
         interrupted = False
         try:
-            while True:
-                try:
-                    completion = openai.ChatCompletion.create(
-                        model=model,
-                        messages=messages,
-                        temperature=0,
-                        stream=True,
-                    )
-                    break
-                except RateLimitError as err:
-                    retry_after = 1
-                    self.io.tool_error(f"RateLimitError: {err}")
-                    self.io.tool_error(f"Retry in {retry_after} seconds.")
-                    time.sleep(retry_after)
+            completion = self.send_with_retries(model, messages)
             self.show_send_output(completion, silent)
         except KeyboardInterrupt:
             interrupted = True
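The change above factors the inline retry loop out of send() into a dedicated send_with_retries() helper, which now also catches requests.exceptions.ConnectionError in addition to RateLimitError and retries indefinitely with a fixed one-second pause. The sketch below is a minimal, standalone illustration of that same control flow, not aider's code: the names retry_forever, make_request, and log_error are hypothetical.

# Minimal sketch of the fixed-delay retry pattern used by send_with_retries().
# retry_forever, make_request, and log_error are hypothetical names; only the
# control flow mirrors the diff above.
import time

RETRY_AFTER = 1  # seconds between attempts, matching retry_after = 1 above


def retry_forever(make_request, log_error, retryable_exceptions):
    """Call make_request() until it succeeds, logging and sleeping on failure."""
    while True:
        try:
            return make_request()
        except retryable_exceptions as err:
            log_error(f"{type(err).__name__}: {err}")
        # Only reached when an exception was caught; success returns above.
        log_error(f"Retry in {RETRY_AFTER} seconds.")
        time.sleep(RETRY_AFTER)


# Example use: retry a flaky callable, treating ConnectionError as retryable.
# result = retry_forever(lambda: fetch(), print, (ConnectionError,))

Isolating the loop this way is also what keeps the two new unit tests below simple: they only need to patch openai.ChatCompletion.create and time.sleep.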


@@ -2,6 +2,9 @@ import os
 import unittest
 from unittest.mock import MagicMock, patch

+import openai
+import requests
+
 from aider.coder import Coder
@@ -117,6 +120,50 @@ class TestCoder(unittest.TestCase):
         # Assert that the returned message is the expected one
         self.assertEqual(result, 'a good "commit message"')

+    @patch("aider.coder.openai.ChatCompletion.create")
+    @patch("aider.coder.time.sleep")
+    def test_send_with_retries_rate_limit_error(self, mock_sleep, mock_chat_completion_create):
+        # Mock the IO object
+        mock_io = MagicMock()
+
+        # Initialize the Coder object with the mocked IO and mocked repo
+        coder = Coder(io=mock_io, openai_api_key="fake_key")
+
+        # Set up the mock to raise RateLimitError on
+        # the first call and return None on the second call
+        mock_chat_completion_create.side_effect = [
+            openai.error.RateLimitError("Rate limit exceeded"),
+            None,
+        ]
+
+        # Call the send_with_retries method
+        coder.send_with_retries("model", ["message"])
+
+        # Assert that time.sleep was called once
+        mock_sleep.assert_called_once()
+
+    @patch("aider.coder.openai.ChatCompletion.create")
+    @patch("aider.coder.time.sleep")
+    def test_send_with_retries_connection_error(self, mock_sleep, mock_chat_completion_create):
+        # Mock the IO object
+        mock_io = MagicMock()
+
+        # Initialize the Coder object with the mocked IO and mocked repo
+        coder = Coder(io=mock_io, openai_api_key="fake_key")
+
+        # Set up the mock to raise ConnectionError on the first call
+        # and return None on the second call
+        mock_chat_completion_create.side_effect = [
+            requests.exceptions.ConnectionError("Connection error"),
+            None,
+        ]
+
+        # Call the send_with_retries method
+        coder.send_with_retries("model", ["message"])
+
+        # Assert that time.sleep was called once
+        mock_sleep.assert_called_once()
+
 if __name__ == "__main__":
     unittest.main()
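Both tests lean on the same unittest.mock idiom: a list assigned to side_effect makes the patched ChatCompletion.create raise the exception on the first call and return None on the second, so send_with_retries() sleeps exactly once before succeeding. A self-contained sketch of that behaviour, using the hypothetical name flaky rather than anything from the diff:

# Illustration of the side_effect pattern the tests rely on: exception
# instances in the list are raised, other items are returned in order.
from unittest.mock import MagicMock

flaky = MagicMock(side_effect=[ConnectionError("boom"), "ok"])

try:
    flaky()                     # first call raises ConnectionError
except ConnectionError:
    pass

assert flaky() == "ok"          # second call returns the next item
assert flaky.call_count == 2

Assuming the new tests live alongside the rest of the suite, a plain python -m unittest run should pick them up.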