refactor send_with_retries

Paul Gauthier 2023-07-21 11:21:41 -03:00
parent e34e6ff897
commit 289887d94f
2 changed files with 48 additions and 41 deletions

aider/coders/base_coder.py

@@ -9,12 +9,9 @@ import traceback
 from json.decoder import JSONDecodeError
 from pathlib import Path, PurePosixPath
 
-import backoff
 import git
 import openai
-import requests
 from jsonschema import Draft7Validator
-from openai.error import APIError, RateLimitError, ServiceUnavailableError, Timeout
 from rich.console import Console, Text
 from rich.live import Live
 from rich.markdown import Markdown
@@ -22,6 +19,7 @@ from rich.markdown import Markdown
 from aider import models, prompts, utils
 from aider.commands import Commands
 from aider.repomap import RepoMap
+from aider.sendchat import send_with_retries
 
 from ..dump import dump  # noqa: F401
 
@@ -555,43 +553,6 @@ class Coder:
 
         return prompts.added_files.format(fnames=", ".join(mentioned_rel_fnames))
 
-    @backoff.on_exception(
-        backoff.expo,
-        (
-            Timeout,
-            APIError,
-            ServiceUnavailableError,
-            RateLimitError,
-            requests.exceptions.ConnectionError,
-        ),
-        max_tries=10,
-        on_backoff=lambda details: print(
-            f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
-        ),
-    )
-    def send_with_retries(self, model, messages, functions):
-        kwargs = dict(
-            model=model,
-            messages=messages,
-            temperature=0,
-            stream=self.stream,
-        )
-        if functions is not None:
-            kwargs["functions"] = self.functions
-
-        # we are abusing the openai object to stash these values
-        if hasattr(openai, "api_deployment_id"):
-            kwargs["deployment_id"] = openai.api_deployment_id
-        if hasattr(openai, "api_engine"):
-            kwargs["engine"] = openai.api_engine
-
-        # Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
-        hash_object = hashlib.sha1(json.dumps(kwargs, sort_keys=True).encode())
-        self.chat_completion_call_hashes.append(hash_object.hexdigest())
-
-        res = openai.ChatCompletion.create(**kwargs)
-        return res
-
     def send(self, messages, model=None, silent=False, functions=None):
         if not model:
             model = self.main_model.name
@@ -601,7 +562,9 @@ class Coder:
 
         interrupted = False
         try:
-            completion = self.send_with_retries(model, messages, functions)
+            hash_object, completion = send_with_retries(model, messages, functions, self.stream)
+            self.chat_completion_call_hashes.append(hash_object.hexdigest())
+
             if self.stream:
                 self.show_send_output_stream(completion, silent)
             else:

aider/sendchat.py (new file, +44 lines)

@@ -0,0 +1,44 @@
+import hashlib
+import json
+
+import backoff
+import openai
+import requests
+from openai.error import APIError, RateLimitError, ServiceUnavailableError, Timeout
+
+
+@backoff.on_exception(
+    backoff.expo,
+    (
+        Timeout,
+        APIError,
+        ServiceUnavailableError,
+        RateLimitError,
+        requests.exceptions.ConnectionError,
+    ),
+    max_tries=10,
+    on_backoff=lambda details: print(
+        f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
+    ),
+)
+def send_with_retries(model, messages, functions, stream):
+    kwargs = dict(
+        model=model,
+        messages=messages,
+        temperature=0,
+        stream=stream,
+    )
+    if functions is not None:
+        kwargs["functions"] = functions
+
+    # we are abusing the openai object to stash these values
+    if hasattr(openai, "api_deployment_id"):
+        kwargs["deployment_id"] = openai.api_deployment_id
+    if hasattr(openai, "api_engine"):
+        kwargs["engine"] = openai.api_engine
+
+    # Generate SHA1 hash of kwargs and append it to chat_completion_call_hashes
+    hash_object = hashlib.sha1(json.dumps(kwargs, sort_keys=True).encode())
+
+    res = openai.ChatCompletion.create(**kwargs)
+    return hash_object, res
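
After the refactor, send_with_retries returns the request hash alongside the completion, and each caller records the hash itself. Below is a minimal caller-side sketch mirroring the change to base_coder.py above; the model name, message content, and the pre-1.0 openai response access are illustrative assumptions, not part of this commit.

# Caller-side sketch of the refactored helper (illustrative values; assumes the
# pre-1.0 openai package with a configured API key).
from aider.sendchat import send_with_retries

chat_completion_call_hashes = []

hash_object, completion = send_with_retries(
    model="gpt-3.5-turbo",                            # placeholder model name
    messages=[{"role": "user", "content": "hello"}],  # placeholder message
    functions=None,                                   # no function-calling schema here
    stream=False,                                     # non-streaming: full response object
)

# The hash is now recorded by the caller, not inside send_with_retries.
chat_completion_call_hashes.append(hash_object.hexdigest())

print(completion.choices[0].message.content)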