feat: Add LiteLLMExceptions loading in test for send chat functionality

This commit is contained in:
Paul Gauthier 2024-11-07 13:09:47 -08:00 committed by Paul Gauthier (aider)
parent 4941a360cb
commit 8a3c95d8dd
2 changed files with 7 additions and 23 deletions

View file

@ -74,7 +74,3 @@ class LiteLLMExceptions:
def get_ex_info(self, ex):
"""Return the ExInfo for a given exception instance"""
return self.exceptions.get(ex.__class__, ExInfo(None, None, None))
litellm_ex = LiteLLMExceptions()
litellm_ex._load(strict=True)

View file

@ -3,8 +3,9 @@ from unittest.mock import MagicMock, patch
import httpx
from aider.exceptions import LiteLLMExceptions
from aider.llm import litellm
from aider.sendchat import retry_exceptions, simple_send_with_retries
from aider.sendchat import simple_send_with_retries
class PrintCalled(Exception):
@ -12,9 +13,9 @@ class PrintCalled(Exception):
class TestSendChat(unittest.TestCase):
def test_retry_exceptions(self):
"""Test that retry_exceptions() can be called without raising errors"""
retry_exceptions() # Should not raise any exceptions
def test_litellm_exceptions(self):
litellm_ex = LiteLLMExceptions()
litellm_ex._load(strict=True)
@patch("litellm.completion")
@patch("builtins.print")
@ -24,7 +25,7 @@ class TestSendChat(unittest.TestCase):
# Set up the mock to raise
mock_completion.side_effect = [
litellm.exceptions.RateLimitError(
litellm.RateLimitError(
"rate limit exceeded",
response=mock,
llm_provider="llm_provider",
@ -35,17 +36,4 @@ class TestSendChat(unittest.TestCase):
# Call the simple_send_with_retries method
simple_send_with_retries("model", ["message"])
assert mock_print.call_count == 2
@patch("litellm.completion")
@patch("builtins.print")
def test_simple_send_with_retries_connection_error(self, mock_print, mock_completion):
# Set up the mock to raise
mock_completion.side_effect = [
httpx.ConnectError("Connection error"),
None,
]
# Call the simple_send_with_retries method
simple_send_with_retries("model", ["message"])
assert mock_print.call_count == 2
assert mock_print.call_count == 3