mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-25 23:05:00 +00:00
Merge pull request #12 from paul-gauthier/handle-openai-exceptions
This commit is contained in:
commit
23f972fb2e
2 changed files with 67 additions and 15 deletions
|
@ -8,6 +8,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import git
|
import git
|
||||||
import openai
|
import openai
|
||||||
|
import requests
|
||||||
from openai.error import RateLimitError
|
from openai.error import RateLimitError
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.live import Live
|
from rich.live import Live
|
||||||
|
@ -390,6 +391,24 @@ class Coder:
|
||||||
|
|
||||||
return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames))
|
return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames))
|
||||||
|
|
||||||
|
def send_with_retries(self, model, messages):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return openai.ChatCompletion.create(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=0,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
except RateLimitError as err:
|
||||||
|
self.io.tool_error(f"RateLimitError: {err}")
|
||||||
|
except requests.exceptions.ConnectionError as err:
|
||||||
|
self.io.tool_error(f"ConnectionError: {err}")
|
||||||
|
|
||||||
|
retry_after = 1
|
||||||
|
self.io.tool_error(f"Retry in {retry_after} seconds.")
|
||||||
|
time.sleep(retry_after)
|
||||||
|
|
||||||
def send(self, messages, model=None, silent=False):
|
def send(self, messages, model=None, silent=False):
|
||||||
if not model:
|
if not model:
|
||||||
model = self.main_model
|
model = self.main_model
|
||||||
|
@ -397,21 +416,7 @@ class Coder:
|
||||||
self.resp = ""
|
self.resp = ""
|
||||||
interrupted = False
|
interrupted = False
|
||||||
try:
|
try:
|
||||||
while True:
|
completion = self.send_with_retries(model, messages)
|
||||||
try:
|
|
||||||
completion = openai.ChatCompletion.create(
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
temperature=0,
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except RateLimitError as err:
|
|
||||||
retry_after = 1
|
|
||||||
self.io.tool_error(f"RateLimitError: {err}")
|
|
||||||
self.io.tool_error(f"Retry in {retry_after} seconds.")
|
|
||||||
time.sleep(retry_after)
|
|
||||||
|
|
||||||
self.show_send_output(completion, silent)
|
self.show_send_output(completion, silent)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
interrupted = True
|
interrupted = True
|
||||||
|
|
|
@ -2,6 +2,9 @@ import os
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import requests
|
||||||
|
|
||||||
from aider.coder import Coder
|
from aider.coder import Coder
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,6 +120,50 @@ class TestCoder(unittest.TestCase):
|
||||||
# Assert that the returned message is the expected one
|
# Assert that the returned message is the expected one
|
||||||
self.assertEqual(result, 'a good "commit message"')
|
self.assertEqual(result, 'a good "commit message"')
|
||||||
|
|
||||||
|
@patch("aider.coder.openai.ChatCompletion.create")
|
||||||
|
@patch("aider.coder.time.sleep")
|
||||||
|
def test_send_with_retries_rate_limit_error(self, mock_sleep, mock_chat_completion_create):
|
||||||
|
# Mock the IO object
|
||||||
|
mock_io = MagicMock()
|
||||||
|
|
||||||
|
# Initialize the Coder object with the mocked IO and mocked repo
|
||||||
|
coder = Coder(io=mock_io, openai_api_key="fake_key")
|
||||||
|
|
||||||
|
# Set up the mock to raise RateLimitError on
|
||||||
|
# the first call and return None on the second call
|
||||||
|
mock_chat_completion_create.side_effect = [
|
||||||
|
openai.error.RateLimitError("Rate limit exceeded"),
|
||||||
|
None,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Call the send_with_retries method
|
||||||
|
coder.send_with_retries("model", ["message"])
|
||||||
|
|
||||||
|
# Assert that time.sleep was called once
|
||||||
|
mock_sleep.assert_called_once()
|
||||||
|
|
||||||
|
@patch("aider.coder.openai.ChatCompletion.create")
|
||||||
|
@patch("aider.coder.time.sleep")
|
||||||
|
def test_send_with_retries_connection_error(self, mock_sleep, mock_chat_completion_create):
|
||||||
|
# Mock the IO object
|
||||||
|
mock_io = MagicMock()
|
||||||
|
|
||||||
|
# Initialize the Coder object with the mocked IO and mocked repo
|
||||||
|
coder = Coder(io=mock_io, openai_api_key="fake_key")
|
||||||
|
|
||||||
|
# Set up the mock to raise ConnectionError on the first call
|
||||||
|
# and return None on the second call
|
||||||
|
mock_chat_completion_create.side_effect = [
|
||||||
|
requests.exceptions.ConnectionError("Connection error"),
|
||||||
|
None,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Call the send_with_retries method
|
||||||
|
coder.send_with_retries("model", ["message"])
|
||||||
|
|
||||||
|
# Assert that time.sleep was called once
|
||||||
|
mock_sleep.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue