mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-24 22:34:59 +00:00
cleaned up client refs
This commit is contained in:
parent
f1a31d3944
commit
c770fc4380
7 changed files with 16 additions and 30 deletions
|
@ -42,7 +42,6 @@ def wrap_fence(name):
|
|||
|
||||
|
||||
class Coder:
|
||||
client = None
|
||||
abs_fnames = None
|
||||
repo = None
|
||||
last_aider_commit_hash = None
|
||||
|
@ -62,7 +61,6 @@ class Coder:
|
|||
main_model=None,
|
||||
edit_format=None,
|
||||
io=None,
|
||||
client=None,
|
||||
**kwargs,
|
||||
):
|
||||
from . import EditBlockCoder, UnifiedDiffCoder, WholeFileCoder
|
||||
|
@ -74,17 +72,16 @@ class Coder:
|
|||
edit_format = main_model.edit_format
|
||||
|
||||
if edit_format == "diff":
|
||||
return EditBlockCoder(client, main_model, io, **kwargs)
|
||||
return EditBlockCoder(main_model, io, **kwargs)
|
||||
elif edit_format == "whole":
|
||||
return WholeFileCoder(client, main_model, io, **kwargs)
|
||||
return WholeFileCoder(main_model, io, **kwargs)
|
||||
elif edit_format == "udiff":
|
||||
return UnifiedDiffCoder(client, main_model, io, **kwargs)
|
||||
return UnifiedDiffCoder(main_model, io, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown edit format {edit_format}")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client,
|
||||
main_model,
|
||||
io,
|
||||
fnames=None,
|
||||
|
@ -103,8 +100,6 @@ class Coder:
|
|||
voice_language=None,
|
||||
aider_ignore_file=None,
|
||||
):
|
||||
self.client = client
|
||||
|
||||
if not fnames:
|
||||
fnames = []
|
||||
|
||||
|
@ -217,7 +212,6 @@ class Coder:
|
|||
self.io.tool_output(f"Added {fname} to the chat.")
|
||||
|
||||
self.summarizer = ChatSummary(
|
||||
self.client,
|
||||
self.main_model.weak_model(),
|
||||
self.main_model.max_chat_history_tokens,
|
||||
)
|
||||
|
@ -368,7 +362,7 @@ class Coder:
|
|||
return files_messages
|
||||
|
||||
def get_images_message(self):
|
||||
if not utils.is_gpt4_with_openai_base_url(self.main_model.name, self.client):
|
||||
if not utils.is_gpt4_with_openai_base_url(self.main_model.name):
|
||||
return None
|
||||
|
||||
image_messages = []
|
||||
|
@ -650,9 +644,7 @@ class Coder:
|
|||
|
||||
interrupted = False
|
||||
try:
|
||||
hash_object, completion = send_with_retries(
|
||||
self.client, model, messages, functions, self.stream
|
||||
)
|
||||
hash_object, completion = send_with_retries(model, messages, functions, self.stream)
|
||||
self.chat_completion_call_hashes.append(hash_object.hexdigest())
|
||||
|
||||
if self.stream:
|
||||
|
|
|
@ -7,8 +7,7 @@ from aider.sendchat import simple_send_with_retries
|
|||
|
||||
|
||||
class ChatSummary:
|
||||
def __init__(self, client, model=None, max_tokens=1024):
|
||||
self.client = client
|
||||
def __init__(self, model=None, max_tokens=1024):
|
||||
self.tokenizer = model.tokenizer
|
||||
self.max_tokens = max_tokens
|
||||
self.model = model
|
||||
|
@ -85,7 +84,7 @@ class ChatSummary:
|
|||
dict(role="user", content=content),
|
||||
]
|
||||
|
||||
summary = simple_send_with_retries(self.client, self.model.name, messages)
|
||||
summary = simple_send_with_retries(self.model.name, messages)
|
||||
if summary is None:
|
||||
raise ValueError(f"summarizer unexpectedly failed for {self.model.name}")
|
||||
summary = prompts.summary_prefix + summary
|
||||
|
|
|
@ -587,7 +587,6 @@ def main(argv=None, input=None, output=None, force_git_root=None):
|
|||
main_model=main_model,
|
||||
edit_format=args.edit_format,
|
||||
io=io,
|
||||
client=None,
|
||||
##
|
||||
fnames=fnames,
|
||||
git_dname=git_dname,
|
||||
|
|
|
@ -121,7 +121,7 @@ class GitRepo:
|
|||
]
|
||||
|
||||
for model in self.models:
|
||||
commit_message = simple_send_with_retries(None, model.name, messages)
|
||||
commit_message = simple_send_with_retries(model.name, messages)
|
||||
if commit_message:
|
||||
break
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ CACHE = None
|
|||
f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
|
||||
),
|
||||
)
|
||||
def send_with_retries(client, model_name, messages, functions, stream):
|
||||
def send_with_retries(model_name, messages, functions, stream):
|
||||
kwargs = dict(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
|
@ -41,7 +41,7 @@ def send_with_retries(client, model_name, messages, functions, stream):
|
|||
kwargs["functions"] = functions
|
||||
|
||||
# Check conditions to switch to gpt-4-vision-preview or strip out image_url messages
|
||||
if client and is_gpt4_with_openai_base_url(model_name, client):
|
||||
if is_gpt4_with_openai_base_url(model_name):
|
||||
if any(
|
||||
isinstance(msg.get("content"), list)
|
||||
and any("image_url" in item for item in msg.get("content") if isinstance(item, dict))
|
||||
|
@ -67,10 +67,9 @@ def send_with_retries(client, model_name, messages, functions, stream):
|
|||
return hash_object, res
|
||||
|
||||
|
||||
def simple_send_with_retries(client, model_name, messages):
|
||||
def simple_send_with_retries(model_name, messages):
|
||||
try:
|
||||
_hash, response = send_with_retries(
|
||||
client=client,
|
||||
model_name=model_name,
|
||||
messages=messages,
|
||||
functions=None,
|
||||
|
|
|
@ -106,14 +106,12 @@ def show_messages(messages, title=None, functions=None):
|
|||
dump(functions)
|
||||
|
||||
|
||||
def is_gpt4_with_openai_base_url(model_name, client):
|
||||
# TODO: fix this
|
||||
def is_gpt4_with_openai_base_url(model_name):
|
||||
"""
|
||||
Check if the model_name starts with 'gpt-4' and the client base URL includes 'api.openai.com'.
|
||||
|
||||
:param model_name: The name of the model to check.
|
||||
:param client: The OpenAI client instance.
|
||||
:return: True if conditions are met, False otherwise.
|
||||
"""
|
||||
if client is None or not hasattr(client, "base_url"):
|
||||
return False
|
||||
return model_name.startswith("gpt-4") and "api.openai.com" in client.base_url.host
|
||||
return model_name.startswith("gpt-4")
|
||||
|
|
|
@ -26,7 +26,7 @@ class Voice:
|
|||
|
||||
threshold = 0.15
|
||||
|
||||
def __init__(self, client):
|
||||
def __init__(self):
|
||||
if sf is None:
|
||||
raise SoundDeviceError
|
||||
try:
|
||||
|
@ -37,8 +37,6 @@ class Voice:
|
|||
except (OSError, ModuleNotFoundError):
|
||||
raise SoundDeviceError
|
||||
|
||||
self.client = client
|
||||
|
||||
def callback(self, indata, frames, time, status):
|
||||
"""This is called (from a separate thread) for each audio block."""
|
||||
rms = np.sqrt(np.mean(indata**2))
|
||||
|
@ -88,6 +86,7 @@ class Voice:
|
|||
while not self.q.empty():
|
||||
file.write(self.q.get())
|
||||
|
||||
# TODO: fix client!
|
||||
with open(filename, "rb") as fh:
|
||||
transcript = self.client.audio.transcriptions.create(
|
||||
model="whisper-1", file=fh, prompt=history, language=language
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue