Merge pull request #12 from paul-gauthier/handle-openai-exceptions

This commit is contained in:
paul-gauthier 2023-06-03 06:31:00 -07:00 committed by GitHub
commit 23f972fb2e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 67 additions and 15 deletions

View file

@ -8,6 +8,7 @@ from pathlib import Path
import git
import openai
import requests
from openai.error import RateLimitError
from rich.console import Console
from rich.live import Live
@ -390,6 +391,24 @@ class Coder:
return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames))
def send_with_retries(self, model, messages):
while True:
try:
return openai.ChatCompletion.create(
model=model,
messages=messages,
temperature=0,
stream=True,
)
except RateLimitError as err:
self.io.tool_error(f"RateLimitError: {err}")
except requests.exceptions.ConnectionError as err:
self.io.tool_error(f"ConnectionError: {err}")
retry_after = 1
self.io.tool_error(f"Retry in {retry_after} seconds.")
time.sleep(retry_after)
def send(self, messages, model=None, silent=False):
if not model:
model = self.main_model
@ -397,21 +416,7 @@ class Coder:
self.resp = ""
interrupted = False
try:
while True:
try:
completion = openai.ChatCompletion.create(
model=model,
messages=messages,
temperature=0,
stream=True,
)
break
except RateLimitError as err:
retry_after = 1
self.io.tool_error(f"RateLimitError: {err}")
self.io.tool_error(f"Retry in {retry_after} seconds.")
time.sleep(retry_after)
completion = self.send_with_retries(model, messages)
self.show_send_output(completion, silent)
except KeyboardInterrupt:
interrupted = True

View file

@ -2,6 +2,9 @@ import os
import unittest
from unittest.mock import MagicMock, patch
import openai
import requests
from aider.coder import Coder
@ -117,6 +120,50 @@ class TestCoder(unittest.TestCase):
# Assert that the returned message is the expected one
self.assertEqual(result, 'a good "commit message"')
@patch("aider.coder.openai.ChatCompletion.create")
@patch("aider.coder.time.sleep")
def test_send_with_retries_rate_limit_error(self, mock_sleep, mock_chat_completion_create):
# Mock the IO object
mock_io = MagicMock()
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder(io=mock_io, openai_api_key="fake_key")
# Set up the mock to raise RateLimitError on
# the first call and return None on the second call
mock_chat_completion_create.side_effect = [
openai.error.RateLimitError("Rate limit exceeded"),
None,
]
# Call the send_with_retries method
coder.send_with_retries("model", ["message"])
# Assert that time.sleep was called once
mock_sleep.assert_called_once()
@patch("aider.coder.openai.ChatCompletion.create")
@patch("aider.coder.time.sleep")
def test_send_with_retries_connection_error(self, mock_sleep, mock_chat_completion_create):
# Mock the IO object
mock_io = MagicMock()
# Initialize the Coder object with the mocked IO and mocked repo
coder = Coder(io=mock_io, openai_api_key="fake_key")
# Set up the mock to raise ConnectionError on the first call
# and return None on the second call
mock_chat_completion_create.side_effect = [
requests.exceptions.ConnectionError("Connection error"),
None,
]
# Call the send_with_retries method
coder.send_with_retries("model", ["message"])
# Assert that time.sleep was called once
mock_sleep.assert_called_once()
if __name__ == "__main__":
unittest.main()