mirror of https://github.com/Aider-AI/aider.git
synced 2025-05-24 22:34:59 +00:00

commit c770fc4380 (parent f1a31d3944)

    cleaned up client refs

7 changed files with 16 additions and 30 deletions
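This commit stops threading an explicit OpenAI client object through the
code base: Coder.create/Coder.__init__, ChatSummary, the sendchat retry
helpers, and Voice.__init__ all drop their client parameter, and the call
sites in main.py and repo.py are updated to match. is_gpt4_with_openai_base_url
now inspects only the model name. Two TODO markers flag follow-up work: the
lost base-URL check in sendchat.py and a still-dangling self.client reference
in voice.py.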
--- a/aider/coders/base_coder.py
+++ b/aider/coders/base_coder.py
@@ -42,7 +42,6 @@ def wrap_fence(name):
 
 
 class Coder:
-    client = None
     abs_fnames = None
     repo = None
     last_aider_commit_hash = None
@@ -62,7 +61,6 @@ class Coder:
         main_model=None,
         edit_format=None,
         io=None,
-        client=None,
         **kwargs,
     ):
         from . import EditBlockCoder, UnifiedDiffCoder, WholeFileCoder
@@ -74,17 +72,16 @@ class Coder:
             edit_format = main_model.edit_format
 
         if edit_format == "diff":
-            return EditBlockCoder(client, main_model, io, **kwargs)
+            return EditBlockCoder(main_model, io, **kwargs)
         elif edit_format == "whole":
-            return WholeFileCoder(client, main_model, io, **kwargs)
+            return WholeFileCoder(main_model, io, **kwargs)
         elif edit_format == "udiff":
-            return UnifiedDiffCoder(client, main_model, io, **kwargs)
+            return UnifiedDiffCoder(main_model, io, **kwargs)
         else:
             raise ValueError(f"Unknown edit format {edit_format}")
 
     def __init__(
         self,
-        client,
         main_model,
         io,
         fnames=None,
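For context, a self-contained toy sketch of the edit_format dispatch that
Coder.create performs after this change; the strings stand in for the real
coder classes in aider.coders, so this is an illustration, not the actual
implementation:

    # Toy mirror of the updated dispatch: edit_format alone picks the
    # coder class; no client object is involved anymore.
    def create(edit_format):
        if edit_format == "diff":
            return "EditBlockCoder"
        elif edit_format == "whole":
            return "WholeFileCoder"
        elif edit_format == "udiff":
            return "UnifiedDiffCoder"
        else:
            raise ValueError(f"Unknown edit format {edit_format}")

    print(create("udiff"))  # UnifiedDiffCoder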
@@ -103,8 +100,6 @@ class Coder:
         voice_language=None,
         aider_ignore_file=None,
     ):
-        self.client = client
-
         if not fnames:
             fnames = []
 
@@ -217,7 +212,6 @@ class Coder:
             self.io.tool_output(f"Added {fname} to the chat.")
 
         self.summarizer = ChatSummary(
-            self.client,
             self.main_model.weak_model(),
             self.main_model.max_chat_history_tokens,
         )
@@ -368,7 +362,7 @@ class Coder:
         return files_messages
 
     def get_images_message(self):
-        if not utils.is_gpt4_with_openai_base_url(self.main_model.name, self.client):
+        if not utils.is_gpt4_with_openai_base_url(self.main_model.name):
             return None
 
         image_messages = []
@@ -650,9 +644,7 @@ class Coder:
 
         interrupted = False
         try:
-            hash_object, completion = send_with_retries(
-                self.client, model, messages, functions, self.stream
-            )
+            hash_object, completion = send_with_retries(model, messages, functions, self.stream)
             self.chat_completion_call_hashes.append(hash_object.hexdigest())
 
             if self.stream:
--- a/aider/history.py
+++ b/aider/history.py
@@ -7,8 +7,7 @@ from aider.sendchat import simple_send_with_retries
 
 
 class ChatSummary:
-    def __init__(self, client, model=None, max_tokens=1024):
-        self.client = client
+    def __init__(self, model=None, max_tokens=1024):
         self.tokenizer = model.tokenizer
         self.max_tokens = max_tokens
         self.model = model
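A minimal, runnable sketch of the new two-argument construction, assuming
only the attributes this hunk reads; FakeModel is a hypothetical stand-in,
not an aider class:

    # Stand-in exposing just what ChatSummary.__init__ touches.
    class FakeModel:
        tokenizer = None
        name = "fake-weak-model"

    class ChatSummary:
        def __init__(self, model=None, max_tokens=1024):
            self.tokenizer = model.tokenizer
            self.max_tokens = max_tokens
            self.model = model

    summarizer = ChatSummary(FakeModel(), max_tokens=1024)
    print(summarizer.model.name)  # fake-weak-model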
@@ -85,7 +84,7 @@ class ChatSummary:
             dict(role="user", content=content),
         ]
 
-        summary = simple_send_with_retries(self.client, self.model.name, messages)
+        summary = simple_send_with_retries(self.model.name, messages)
         if summary is None:
             raise ValueError(f"summarizer unexpectedly failed for {self.model.name}")
         summary = prompts.summary_prefix + summary
--- a/aider/main.py
+++ b/aider/main.py
@@ -587,7 +587,6 @@ def main(argv=None, input=None, output=None, force_git_root=None):
         main_model=main_model,
         edit_format=args.edit_format,
         io=io,
-        client=None,
         ##
         fnames=fnames,
         git_dname=git_dname,
--- a/aider/repo.py
+++ b/aider/repo.py
@@ -121,7 +121,7 @@ class GitRepo:
         ]
 
         for model in self.models:
-            commit_message = simple_send_with_retries(None, model.name, messages)
+            commit_message = simple_send_with_retries(model.name, messages)
             if commit_message:
                 break
 
--- a/aider/sendchat.py
+++ b/aider/sendchat.py
@@ -30,7 +30,7 @@ CACHE = None
             f"{details.get('exception','Exception')}\nRetry in {details['wait']:.1f} seconds."
         ),
     )
-def send_with_retries(client, model_name, messages, functions, stream):
+def send_with_retries(model_name, messages, functions, stream):
     kwargs = dict(
         model=model_name,
         messages=messages,
@@ -41,7 +41,7 @@ def send_with_retries(client, model_name, messages, functions, stream):
         kwargs["functions"] = functions
 
     # Check conditions to switch to gpt-4-vision-preview or strip out image_url messages
-    if client and is_gpt4_with_openai_base_url(model_name, client):
+    if is_gpt4_with_openai_base_url(model_name):
         if any(
             isinstance(msg.get("content"), list)
             and any("image_url" in item for item in msg.get("content") if isinstance(item, dict))
@@ -67,10 +67,9 @@ def send_with_retries(client, model_name, messages, functions, stream):
     return hash_object, res
 
 
-def simple_send_with_retries(client, model_name, messages):
+def simple_send_with_retries(model_name, messages):
     try:
         _hash, response = send_with_retries(
-            client=client,
             model_name=model_name,
             messages=messages,
             functions=None,
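Every call site in this diff updates the same way: drop the leading client
argument. A hedged illustration with a toy stand-in for the helper (the
real function lives in aider/sendchat.py and talks to the API):

    # Toy stand-in for aider.sendchat.simple_send_with_retries.
    def simple_send_with_retries(model_name, messages):
        return f"<reply from {model_name} to {len(messages)} message(s)>"

    messages = [dict(role="user", content="write a commit message")]
    # before this commit: simple_send_with_retries(None, model_name, messages)
    print(simple_send_with_retries("gpt-3.5-turbo", messages))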
@@ -106,14 +106,12 @@ def show_messages(messages, title=None, functions=None):
         dump(functions)
 
 
-def is_gpt4_with_openai_base_url(model_name, client):
+# TODO: fix this
+def is_gpt4_with_openai_base_url(model_name):
     """
     Check if the model_name starts with 'gpt-4' and the client base URL includes 'api.openai.com'.
 
     :param model_name: The name of the model to check.
-    :param client: The OpenAI client instance.
     :return: True if conditions are met, False otherwise.
     """
-    if client is None or not hasattr(client, "base_url"):
-        return False
-    return model_name.startswith("gpt-4") and "api.openai.com" in client.base_url.host
+    return model_name.startswith("gpt-4")
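The behavioral consequence of this hunk, as a self-contained before/after
sketch; both bodies are reconstructed verbatim from the diff above:

    def old_check(model_name, client):
        # Old: also required a client whose base URL pointed at api.openai.com.
        if client is None or not hasattr(client, "base_url"):
            return False
        return model_name.startswith("gpt-4") and "api.openai.com" in client.base_url.host

    def new_check(model_name):
        # New: the model-name prefix alone decides; hence the TODO above.
        return model_name.startswith("gpt-4")

    print(old_check("gpt-4-vision-preview", None))  # False: no client to inspect
    print(new_check("gpt-4-vision-preview"))        # True for any gpt-4* name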
--- a/aider/voice.py
+++ b/aider/voice.py
@@ -26,7 +26,7 @@ class Voice:
 
     threshold = 0.15
 
-    def __init__(self, client):
+    def __init__(self):
         if sf is None:
             raise SoundDeviceError
         try:
@@ -37,8 +37,6 @@ class Voice:
         except (OSError, ModuleNotFoundError):
             raise SoundDeviceError
 
-        self.client = client
-
     def callback(self, indata, frames, time, status):
         """This is called (from a separate thread) for each audio block."""
         rms = np.sqrt(np.mean(indata**2))
@@ -88,6 +86,7 @@ class Voice:
             while not self.q.empty():
                 file.write(self.q.get())
 
+        # TODO: fix client!
         with open(filename, "rb") as fh:
             transcript = self.client.audio.transcriptions.create(
                 model="whisper-1", file=fh, prompt=history, language=language
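Why the "# TODO: fix client!" matters: __init__ no longer assigns
self.client, yet the transcription call still reads it. A minimal sketch of
the resulting failure mode, using a toy class rather than aider's Voice:

    class VoiceSketch:
        def __init__(self):
            pass  # self.client is no longer set anywhere

        def transcribe(self):
            # Mirrors the surviving self.client.audio... access above.
            return self.client.audio.transcriptions.create(model="whisper-1")

    try:
        VoiceSketch().transcribe()
    except AttributeError as exc:
        print(exc)  # 'VoiceSketch' object has no attribute 'client'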