refactor send_with_retries

This commit is contained in:
Paul Gauthier 2023-07-21 11:21:41 -03:00
parent e34e6ff897
commit 289887d94f
2 changed files with 48 additions and 41 deletions

View file

@ -9,12 +9,9 @@ import traceback
from json.decoder import JSONDecodeError
from pathlib import Path, PurePosixPath
import backoff
import git
import openai
import requests
from jsonschema import Draft7Validator
from openai.error import APIError, RateLimitError, ServiceUnavailableError, Timeout
from rich.console import Console, Text
from rich.live import Live
from rich.markdown import Markdown
@ -22,6 +19,7 @@ from rich.markdown import Markdown
from aider import models, prompts, utils
from aider.commands import Commands
from aider.repomap import RepoMap
from aider.sendchat import send_with_retries
from ..dump import dump # noqa: F401
@ -555,43 +553,6 @@ class Coder:
return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames))
@backoff.on_exception(
backoff.expo,
(
Timeout,
APIError,
ServiceUnavailableError,
RateLimitError,
requests.exceptions.ConnectionError,
),
max_tries=10,
on_backoff=lambda details: print(
f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
),
)
def send_with_retries(self, model, messages, functions):
kwargs = dict(
model=model,
messages=messages,
temperature=0,
stream=self.stream,
)
if functions is not None:
kwargs["functions"] = self.functions
# we are abusing the openai object to stash these values
if hasattr(openai, "api_deployment_id"):
kwargs["deployment_id"] = openai.api_deployment_id
if hasattr(openai, "api_engine"):
kwargs["engine"] = openai.api_engine
# Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
hash_object = hashlib.sha1(json.dumps(kwargs, sort_keys=True).encode())
self.chat_completion_call_hashes.append(hash_object.hexdigest())
res = openai.ChatCompletion.create(**kwargs)
return res
def send(self, messages, model=None, silent=False, functions=None):
if not model:
model = self.main_model.name
@ -601,7 +562,9 @@ class Coder:
interrupted = False
try:
completion = self.send_with_retries(model, messages, functions)
hash_object, completion = send_with_retries(model, messages, functions, self.stream)
self.chat_completion_call_hashes.append(hash_object.hexdigest())
if self.stream:
self.show_send_output_stream(completion, silent)
else: