cleaned up client refs

This commit is contained in:
Paul Gauthier 2024-04-17 15:47:07 -07:00
parent f1a31d3944
commit c770fc4380
7 changed files with 16 additions and 30 deletions

View file

@@ -42,7 +42,6 @@ def wrap_fence(name):
class Coder:
client = None
abs_fnames = None
repo = None
last_aider_commit_hash = None
@@ -62,7 +61,6 @@ class Coder:
main_model=None,
edit_format=None,
io=None,
client=None,
**kwargs,
):
from . import EditBlockCoder, UnifiedDiffCoder, WholeFileCoder
@@ -74,17 +72,16 @@ class Coder:
edit_format = main_model.edit_format
if edit_format == "diff":
return EditBlockCoder(client, main_model, io, **kwargs)
return EditBlockCoder(main_model, io, **kwargs)
elif edit_format == "whole":
return WholeFileCoder(client, main_model, io, **kwargs)
return WholeFileCoder(main_model, io, **kwargs)
elif edit_format == "udiff":
return UnifiedDiffCoder(client, main_model, io, **kwargs)
return UnifiedDiffCoder(main_model, io, **kwargs)
else:
raise ValueError(f"Unknown edit format {edit_format}")
def __init__(
self,
client,
main_model,
io,
fnames=None,
@@ -103,8 +100,6 @@ class Coder:
voice_language=None,
aider_ignore_file=None,
):
self.client = client
if not fnames:
fnames = []
@@ -217,7 +212,6 @@ class Coder:
self.io.tool_output(f"Added {fname} to the chat.")
self.summarizer = ChatSummary(
self.client,
self.main_model.weak_model(),
self.main_model.max_chat_history_tokens,
)
@@ -368,7 +362,7 @@ class Coder:
return files_messages
def get_images_message(self):
if not utils.is_gpt4_with_openai_base_url(self.main_model.name, self.client):
if not utils.is_gpt4_with_openai_base_url(self.main_model.name):
return None
image_messages = []
@@ -650,9 +644,7 @@ class Coder:
interrupted = False
try:
hash_object, completion = send_with_retries(
self.client, model, messages, functions, self.stream
)
hash_object, completion = send_with_retries(model, messages, functions, self.stream)
self.chat_completion_call_hashes.append(hash_object.hexdigest())
if self.stream:

View file

@@ -7,8 +7,7 @@ from aider.sendchat import simple_send_with_retries
class ChatSummary:
def __init__(self, client, model=None, max_tokens=1024):
self.client = client
def __init__(self, model=None, max_tokens=1024):
self.tokenizer = model.tokenizer
self.max_tokens = max_tokens
self.model = model
@@ -85,7 +84,7 @@ class ChatSummary:
dict(role="user", content=content),
]
summary = simple_send_with_retries(self.client, self.model.name, messages)
summary = simple_send_with_retries(self.model.name, messages)
if summary is None:
raise ValueError(f"summarizer unexpectedly failed for {self.model.name}")
summary = prompts.summary_prefix + summary

View file

@@ -587,7 +587,6 @@ def main(argv=None, input=None, output=None, force_git_root=None):
main_model=main_model,
edit_format=args.edit_format,
io=io,
client=None,
##
fnames=fnames,
git_dname=git_dname,

View file

@@ -121,7 +121,7 @@ class GitRepo:
]
for model in self.models:
commit_message = simple_send_with_retries(None, model.name, messages)
commit_message = simple_send_with_retries(model.name, messages)
if commit_message:
break

View file

@@ -30,7 +30,7 @@ CACHE = None
f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
),
)
def send_with_retries(client, model_name, messages, functions, stream):
def send_with_retries(model_name, messages, functions, stream):
kwargs = dict(
model=model_name,
messages=messages,
@@ -41,7 +41,7 @@ def send_with_retries(client, model_name, messages, functions, stream):
kwargs["functions"] = functions
# Check conditions to switch to gpt-4-vision-preview or strip out image_url messages
if client and is_gpt4_with_openai_base_url(model_name, client):
if is_gpt4_with_openai_base_url(model_name):
if any(
isinstance(msg.get("content"), list)
and any("image_url" in item for item in msg.get("content") if isinstance(item, dict))
@@ -67,10 +67,9 @@ def send_with_retries(client, model_name, messages, functions, stream):
return hash_object, res
def simple_send_with_retries(client, model_name, messages):
def simple_send_with_retries(model_name, messages):
try:
_hash, response = send_with_retries(
client=client,
model_name=model_name,
messages=messages,
functions=None,

View file

@@ -106,14 +106,12 @@ def show_messages(messages, title=None, functions=None):
dump(functions)
def is_gpt4_with_openai_base_url(model_name, client):
# TODO: fix this
def is_gpt4_with_openai_base_url(model_name):
"""
Check if the model_name starts with 'gpt-4' and the client base URL includes 'api.openai.com'.
:param model_name: The name of the model to check.
:param client: The OpenAI client instance.
:return: True if conditions are met, False otherwise.
"""
if client is None or not hasattr(client, "base_url"):
return False
return model_name.startswith("gpt-4") and "api.openai.com" in client.base_url.host
return model_name.startswith("gpt-4")

View file

@@ -26,7 +26,7 @@ class Voice:
threshold = 0.15
def __init__(self, client):
def __init__(self):
if sf is None:
raise SoundDeviceError
try:
@@ -37,8 +37,6 @@ class Voice:
except (OSError, ModuleNotFoundError):
raise SoundDeviceError
self.client = client
def callback(self, indata, frames, time, status):
"""This is called (from a separate thread) for each audio block."""
rms = np.sqrt(np.mean(indata**2))
@@ -88,6 +86,7 @@ class Voice:
while not self.q.empty():
file.write(self.q.get())
# TODO: fix client!
with open(filename, "rb") as fh:
transcript = self.client.audio.transcriptions.create(
model="whisper-1", file=fh, prompt=history, language=language