Merge branch 'main' into json-coders

This commit is contained in:
Paul Gauthier 2024-08-13 17:03:30 -07:00
commit e1b83ba6b5
81 changed files with 2210 additions and 720 deletions

View file

@ -30,7 +30,7 @@ from aider.llm import litellm
from aider.mdstream import MarkdownStream
from aider.repo import GitRepo
from aider.repomap import RepoMap
from aider.sendchat import send_with_retries
from aider.sendchat import retry_exceptions, send_completion
from aider.utils import format_content, format_messages, is_image_file
from ..dump import dump # noqa: F401
@ -50,6 +50,7 @@ def wrap_fence(name):
class Coder:
abs_fnames = None
abs_read_only_fnames = None
repo = None
last_aider_commit_hash = None
aider_edited_files = None
@ -70,6 +71,11 @@ class Coder:
lint_outcome = None
test_outcome = None
multi_response_content = ""
partial_response_content = ""
commit_before_message = []
message_cost = 0.0
message_tokens_sent = 0
message_tokens_received = 0
@classmethod
def create(
@ -89,6 +95,8 @@ class Coder:
else:
main_model = models.Model(models.DEFAULT_MODEL_NAME)
if edit_format == "code":
edit_format = None
if edit_format is None:
if from_coder:
edit_format = from_coder.edit_format
@ -112,6 +120,7 @@ class Coder:
# Bring along context from the old Coder
update = dict(
fnames=list(from_coder.abs_fnames),
read_only_fnames=list(from_coder.abs_read_only_fnames), # Copy read-only files
done_messages=done_messages,
cur_messages=from_coder.cur_messages,
aider_commit_hashes=from_coder.aider_commit_hashes,
@ -143,7 +152,10 @@ class Coder:
main_model = self.main_model
weak_model = main_model.weak_model
prefix = "Model:"
output = f" {main_model.name} with {self.edit_format} edit format"
output = f" {main_model.name} with"
if main_model.info.get("supports_assistant_prefill"):
output += " ♾️"
output += f" {self.edit_format} edit format"
if weak_model is not main_model:
prefix = "Models:"
output += f", weak model {weak_model.name}"
@ -193,7 +205,7 @@ class Coder:
io,
repo=None,
fnames=None,
pretty=True,
read_only_fnames=None,
show_diffs=False,
auto_commits=True,
dirty_commits=True,
@ -217,6 +229,7 @@ class Coder:
summarizer=None,
total_cost=0.0,
):
self.commit_before_message = []
self.aider_commit_hashes = set()
self.rejected_urls = set()
self.abs_root_path_cache = {}
@ -240,6 +253,7 @@ class Coder:
self.verbose = verbose
self.abs_fnames = set()
self.abs_read_only_fnames = set()
if cur_messages:
self.cur_messages = cur_messages
@ -263,9 +277,9 @@ class Coder:
self.code_theme = code_theme
self.dry_run = dry_run
self.pretty = pretty
self.pretty = self.io.pretty
if pretty:
if self.pretty:
self.console = Console()
else:
self.console = Console(force_terminal=False, no_color=True)
@ -314,6 +328,15 @@ class Coder:
if not self.repo:
self.find_common_root()
if read_only_fnames:
self.abs_read_only_fnames = set()
for fname in read_only_fnames:
abs_fname = self.abs_root_path(fname)
if os.path.exists(abs_fname):
self.abs_read_only_fnames.add(abs_fname)
else:
self.io.tool_error(f"Error: Read-only file {fname} does not exist. Skipping.")
if map_tokens is None:
use_repo_map = main_model.use_repo_map
map_tokens = 1024
@ -376,8 +399,10 @@ class Coder:
self.linter.set_linter(lang, cmd)
def show_announcements(self):
bold = True
for line in self.get_announcements():
self.io.tool_output(line)
self.io.tool_output(line, bold=bold)
bold = False
def find_common_root(self):
if len(self.abs_fnames) == 1:
@ -444,6 +469,10 @@ class Coder:
all_content = ""
for _fname, content in self.get_abs_fnames_content():
all_content += content + "\n"
for _fname in self.abs_read_only_fnames:
content = self.io.read_text(_fname)
if content is not None:
all_content += content + "\n"
good = False
for fence_open, fence_close in self.fences:
@ -485,6 +514,19 @@ class Coder:
return prompt
def get_read_only_files_content(self):
prompt = ""
for fname in self.abs_read_only_fnames:
content = self.io.read_text(fname)
if content is not None and not is_image_file(fname):
relative_fname = self.get_rel_fname(fname)
prompt += "\n"
prompt += relative_fname
prompt += f"\n{self.fence[0]}\n"
prompt += content
prompt += f"{self.fence[1]}\n"
return prompt
def get_cur_message_text(self):
text = ""
for msg in self.cur_messages:
@ -522,9 +564,13 @@ class Coder:
mentioned_fnames.update(self.get_ident_filename_matches(mentioned_idents))
other_files = set(self.get_all_abs_files()) - set(self.abs_fnames)
all_abs_files = set(self.get_all_abs_files())
repo_abs_read_only_fnames = set(self.abs_read_only_fnames) & all_abs_files
chat_files = set(self.abs_fnames) | repo_abs_read_only_fnames
other_files = all_abs_files - chat_files
repo_content = self.repo_map.get_repo_map(
self.abs_fnames,
chat_files,
other_files,
mentioned_fnames=mentioned_fnames,
mentioned_idents=mentioned_idents,
@ -534,7 +580,7 @@ class Coder:
if not repo_content:
repo_content = self.repo_map.get_repo_map(
set(),
set(self.get_all_abs_files()),
all_abs_files,
mentioned_fnames=mentioned_fnames,
mentioned_idents=mentioned_idents,
)
@ -543,7 +589,7 @@ class Coder:
if not repo_content:
repo_content = self.repo_map.get_repo_map(
set(),
set(self.get_all_abs_files()),
all_abs_files,
)
return repo_content
@ -572,12 +618,6 @@ class Coder:
files_content = self.gpt_prompts.files_no_full_files
files_reply = "Ok."
if files_content:
files_messages += [
dict(role="user", content=files_content),
dict(role="assistant", content=files_reply),
]
images_message = self.get_images_message()
if images_message is not None:
files_messages += [
@ -585,6 +625,24 @@ class Coder:
dict(role="assistant", content="Ok."),
]
read_only_content = self.get_read_only_files_content()
if read_only_content:
files_messages += [
dict(
role="user", content=self.gpt_prompts.read_only_files_prefix + read_only_content
),
dict(
role="assistant",
content="Ok, I will use these files as references.",
),
]
if files_content:
files_messages += [
dict(role="user", content=files_content),
dict(role="assistant", content=files_reply),
]
return files_messages
def get_images_message(self):
@ -597,9 +655,11 @@ class Coder:
mime_type, _ = mimetypes.guess_type(fname)
if mime_type and mime_type.startswith("image/"):
image_url = f"data:{mime_type};base64,{content}"
image_messages.append(
{"type": "image_url", "image_url": {"url": image_url, "detail": "high"}}
)
rel_fname = self.get_rel_fname(fname)
image_messages += [
{"type": "text", "text": f"Image file: {rel_fname}"},
{"type": "image_url", "image_url": {"url": image_url, "detail": "high"}},
]
if not image_messages:
return None
@ -609,7 +669,7 @@ class Coder:
def run_stream(self, user_message):
self.io.user_input(user_message)
self.init_before_message()
yield from self.send_new_user_message(user_message)
yield from self.send_message(user_message)
def init_before_message(self):
self.reflected_message = None
@ -617,48 +677,39 @@ class Coder:
self.lint_outcome = None
self.test_outcome = None
self.edit_outcome = None
if self.repo:
self.commit_before_message.append(self.repo.get_head())
def run(self, with_message=None):
while True:
self.init_before_message()
def run(self, with_message=None, preproc=True):
try:
if with_message:
self.io.user_input(with_message)
self.run_one(with_message, preproc)
return self.partial_response_content
try:
if with_message:
new_user_message = with_message
self.io.user_input(with_message)
else:
new_user_message = self.run_loop()
while True:
try:
user_message = self.get_input()
self.run_one(user_message, preproc)
self.show_undo_hint()
except KeyboardInterrupt:
self.keyboard_interrupt()
except EOFError:
return
while new_user_message:
self.reflected_message = None
list(self.send_new_user_message(new_user_message))
new_user_message = None
if self.reflected_message:
if self.num_reflections < self.max_reflections:
self.num_reflections += 1
new_user_message = self.reflected_message
else:
self.io.tool_error(
f"Only {self.max_reflections} reflections allowed, stopping."
)
if with_message:
return self.partial_response_content
except KeyboardInterrupt:
self.keyboard_interrupt()
except EOFError:
return
def run_loop(self):
inp = self.io.get_input(
def get_input(self):
inchat_files = self.get_inchat_relative_files()
read_only_files = [self.get_rel_fname(fname) for fname in self.abs_read_only_fnames]
all_files = sorted(set(inchat_files + read_only_files))
return self.io.get_input(
self.root,
self.get_inchat_relative_files(),
all_files,
self.get_addable_relative_files(),
self.commands,
self.abs_read_only_fnames,
)
def preproc_user_input(self, inp):
if not inp:
return
@ -670,6 +721,28 @@ class Coder:
return inp
def run_one(self, user_message, preproc):
self.init_before_message()
if preproc:
message = self.preproc_user_input(user_message)
else:
message = user_message
while message:
self.reflected_message = None
list(self.send_message(message))
if not self.reflected_message:
break
if self.num_reflections >= self.max_reflections:
self.io.tool_error(f"Only {self.max_reflections} reflections allowed, stopping.")
return
self.num_reflections += 1
message = self.reflected_message
def check_for_urls(self, inp):
url_pattern = re.compile(r"(https?://[^\s/$.?#].[^\s]*[^\s,.])")
urls = list(set(url_pattern.findall(inp))) # Use set to remove duplicates
@ -678,7 +751,7 @@ class Coder:
if url not in self.rejected_urls:
if self.io.confirm_ask(f"Add {url} to the chat?"):
inp += "\n\n"
inp += self.commands.cmd_web(url)
inp += self.commands.cmd_web(url, paginate=False)
added_urls.append(url)
else:
self.rejected_urls.add(url)
@ -826,6 +899,7 @@ class Coder:
self.summarize_end()
messages += self.done_messages
messages += self.get_files_messages()
if self.gpt_prompts.system_reminder:
@ -852,7 +926,7 @@ class Coder:
final = messages[-1]
max_input_tokens = self.main_model.info.get("max_input_tokens")
max_input_tokens = self.main_model.info.get("max_input_tokens") or 0
# Add the reminder prompt if we still have room to include it.
if (
max_input_tokens is None
@ -872,7 +946,7 @@ class Coder:
return messages
def send_new_user_message(self, inp):
def send_message(self, inp):
self.aider_edited_files = None
self.cur_messages += [
@ -891,6 +965,8 @@ class Coder:
else:
self.mdstream = None
retry_delay = 0.125
self.usage_report = None
exhausted = False
interrupted = False
@ -899,6 +975,14 @@ class Coder:
try:
yield from self.send(messages, functions=self.functions)
break
except retry_exceptions() as err:
self.io.tool_error(str(err))
retry_delay *= 2
if retry_delay > 60:
break
self.io.tool_output(f"Retrying in {retry_delay:.1f} seconds...")
time.sleep(retry_delay)
continue
except KeyboardInterrupt:
interrupted = True
break
@ -911,7 +995,7 @@ class Coder:
return
except FinishReasonLength:
# We hit the output limit!
if not self.main_model.can_prefill:
if not self.main_model.info.get("supports_assistant_prefill"):
exhausted = True
break
@ -920,7 +1004,9 @@ class Coder:
if messages[-1]["role"] == "assistant":
messages[-1]["content"] = self.multi_response_content
else:
messages.append(dict(role="assistant", content=self.multi_response_content))
messages.append(
dict(role="assistant", content=self.multi_response_content, prefix=True)
)
except Exception as err:
self.io.tool_error(f"Unexpected error: {err}")
traceback.print_exc()
@ -935,8 +1021,7 @@ class Coder:
self.io.tool_output()
if self.usage_report:
self.io.tool_output(self.usage_report)
self.show_usage_report()
if exhausted:
self.show_exhausted_error()
@ -1011,10 +1096,10 @@ class Coder:
output_tokens = 0
if self.partial_response_content:
output_tokens = self.main_model.token_count(self.partial_response_content)
max_output_tokens = self.main_model.info.get("max_output_tokens", 0)
max_output_tokens = self.main_model.info.get("max_output_tokens") or 0
input_tokens = self.main_model.token_count(self.format_messages())
max_input_tokens = self.main_model.info.get("max_input_tokens", 0)
max_input_tokens = self.main_model.info.get("max_input_tokens") or 0
total_tokens = input_tokens + output_tokens
@ -1159,9 +1244,8 @@ class Coder:
self.io.log_llm_history("TO LLM", format_messages(messages))
interrupted = False
try:
hash_object, completion = send_with_retries(
hash_object, completion = send_completion(
model.name,
messages,
functions,
@ -1176,9 +1260,9 @@ class Coder:
yield from self.show_send_output_stream(completion)
else:
self.show_send_output(completion)
except KeyboardInterrupt:
except KeyboardInterrupt as kbi:
self.keyboard_interrupt()
interrupted = True
raise kbi
finally:
self.io.log_llm_history(
"LLM RESPONSE",
@ -1193,10 +1277,7 @@ class Coder:
if args:
self.io.ai_output(json.dumps(args, indent=4))
if interrupted:
raise KeyboardInterrupt
self.calculate_and_show_tokens_and_cost(messages, completion)
self.calculate_and_show_tokens_and_cost(messages, completion)
def show_send_output(self, completion):
if self.verbose:
@ -1218,7 +1299,7 @@ class Coder:
show_func_err = func_err
try:
self.partial_response_content = completion.choices[0].message.content
self.partial_response_content = completion.choices[0].message.content or ""
except AttributeError as content_err:
show_content_err = content_err
@ -1312,13 +1393,19 @@ class Coder:
prompt_tokens = self.main_model.token_count(messages)
completion_tokens = self.main_model.token_count(self.partial_response_content)
self.usage_report = f"Tokens: {prompt_tokens:,} sent, {completion_tokens:,} received."
self.message_tokens_sent += prompt_tokens
self.message_tokens_received += completion_tokens
tokens_report = (
f"Tokens: {self.message_tokens_sent:,} sent, {self.message_tokens_received:,} received."
)
if self.main_model.info.get("input_cost_per_token"):
cost += prompt_tokens * self.main_model.info.get("input_cost_per_token")
if self.main_model.info.get("output_cost_per_token"):
cost += completion_tokens * self.main_model.info.get("output_cost_per_token")
self.total_cost += cost
self.message_cost += cost
def format_cost(value):
if value == 0:
@ -1329,13 +1416,24 @@ class Coder:
else:
return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"
self.usage_report += (
f" Cost: ${format_cost(cost)} request, ${format_cost(self.total_cost)} session."
cost_report = (
f" Cost: ${format_cost(self.message_cost)} message,"
f" ${format_cost(self.total_cost)} session."
)
self.usage_report = tokens_report + cost_report
else:
self.usage_report = tokens_report
def show_usage_report(self):
if self.usage_report:
self.io.tool_output(self.usage_report)
self.message_cost = 0.0
self.message_tokens_sent = 0
self.message_tokens_received = 0
def get_multi_response_content(self, final=False):
cur = self.multi_response_content
new = self.partial_response_content
cur = self.multi_response_content or ""
new = self.partial_response_content or ""
if new.rstrip() != new and not final:
new = new.rstrip()
@ -1377,7 +1475,10 @@ class Coder:
return max(path.stat().st_mtime for path in files)
def get_addable_relative_files(self):
return set(self.get_all_relative_files()) - set(self.get_inchat_relative_files())
all_files = set(self.get_all_relative_files())
inchat_files = set(self.get_inchat_relative_files())
read_only_files = set(self.get_rel_fname(fname) for fname in self.abs_read_only_fnames)
return all_files - inchat_files - read_only_files
def check_for_dirty_commit(self, path):
if not self.repo:
@ -1590,7 +1691,11 @@ class Coder:
if self.show_diffs:
self.commands.cmd_diff()
self.io.tool_output(f"You can use /undo to revert and discard commit {commit_hash}.")
def show_undo_hint(self):
if not self.commit_before_message:
return
if self.commit_before_message[-1] != self.repo.get_head():
self.io.tool_output("You can use /undo to undo and discard each aider commit.")
def dirty_commit(self):
if not self.need_commit_before_edits: