Merge branch 'main' into swe-bench

This commit is contained in:
Paul Gauthier 2024-05-19 15:20:46 -07:00
commit e758b01fb6
33 changed files with 997 additions and 177 deletions

View file

@ -21,6 +21,7 @@ from aider import __version__, models, prompts, utils
from aider.commands import Commands
from aider.history import ChatSummary
from aider.io import InputOutput
from aider.linter import Linter
from aider.litellm import litellm
from aider.mdstream import MarkdownStream
from aider.repo import GitRepo
@ -55,10 +56,16 @@ class Coder:
num_exhausted_context_windows = 0
num_malformed_responses = 0
last_keyboard_interrupt = None
max_apply_update_errors = 3
num_reflections = 0
max_reflections = 3
edit_format = None
yield_stream = False
temperature = 0
auto_lint = True
auto_test = False
test_cmd = None
lint_outcome = None
test_outcome = None
@classmethod
def create(
@ -195,6 +202,10 @@ class Coder:
done_messages=None,
max_chat_history_tokens=None,
restore_chat_history=False,
auto_lint=True,
auto_test=False,
lint_cmds=None,
test_cmd=None,
):
if not fnames:
fnames = []
@ -305,6 +316,14 @@ class Coder:
self.done_messages = utils.split_chat_history_markdown(history_md)
self.summarize_start()
# Linting and testing
self.linter = Linter(root=self.root, encoding=io.encoding)
self.auto_lint = auto_lint
self.setup_lint_cmds(lint_cmds)
self.auto_test = auto_test
self.test_cmd = test_cmd
# validate the functions jsonschema
if self.functions:
for function in self.functions:
@ -314,6 +333,12 @@ class Coder:
self.io.tool_output("JSON Schema:")
self.io.tool_output(json.dumps(self.functions, indent=4))
def setup_lint_cmds(self, lint_cmds):
if not lint_cmds:
return
for lang, cmd in lint_cmds.items():
self.linter.set_linter(lang, cmd)
def show_announcements(self):
for line in self.get_announcements():
self.io.tool_output(line)
@ -524,12 +549,20 @@ class Coder:
def run_stream(self, user_message):
self.io.user_input(user_message)
self.reflected_message = None
self.init_before_message()
yield from self.send_new_user_message(user_message)
def init_before_message(self):
self.reflected_message = None
self.num_reflections = 0
self.lint_outcome = None
self.test_outcome = None
self.edit_outcome = None
def run(self, with_message=None):
while True:
self.num_malformed_responses = 0
self.init_before_message()
try:
if with_message:
new_user_message = with_message
@ -540,7 +573,14 @@ class Coder:
while new_user_message:
self.reflected_message = None
list(self.send_new_user_message(new_user_message))
new_user_message = self.reflected_message
if self.num_reflections < self.max_reflections:
self.num_reflections += 1
new_user_message = self.reflected_message
else:
self.io.tool_error(
f"Only {self.max_reflections} reflections allowed, stopping."
)
new_user_message = None
if with_message:
return self.partial_response_content
@ -761,10 +801,33 @@ class Coder:
self.cur_messages += [dict(role="assistant", content=content)]
return
edited, edit_error = self.apply_updates()
if edit_error:
edited = self.apply_updates()
if self.reflected_message:
self.edit_outcome = False
self.update_cur_messages(set())
self.reflected_message = edit_error
return
if edited:
self.edit_outcome = True
if edited and self.auto_lint:
lint_errors = self.lint_edited(edited)
self.lint_outcome = not lint_errors
if lint_errors:
ok = self.io.confirm_ask("Attempt to fix lint errors?")
if ok:
self.reflected_message = lint_errors
self.update_cur_messages(set())
return
if edited and self.auto_test:
test_errors = self.commands.cmd_test(self.test_cmd)
self.test_outcome = not test_errors
if test_errors:
ok = self.io.confirm_ask("Attempt to fix test errors?")
if ok:
self.reflected_message = test_errors
self.update_cur_messages(set())
return
self.update_cur_messages(edited)
@ -786,6 +849,20 @@ class Coder:
else:
self.reflected_message = add_rel_files_message
def lint_edited(self, fnames):
res = ""
for fname in fnames:
errors = self.linter.lint(self.abs_root_path(fname))
if errors:
res += "\n"
res += errors
res += "\n"
if res:
self.io.tool_error(res)
return res
def update_cur_messages(self, edited):
if self.partial_response_content:
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
@ -1114,8 +1191,6 @@ class Coder:
)
self.warning_given = True
apply_update_errors = 0
def prepare_to_edit(self, edits):
res = []
seen = dict()
@ -1149,37 +1224,30 @@ class Coder:
edited = self.update_files()
except ValueError as err:
self.num_malformed_responses += 1
err = err.args[0]
self.apply_update_errors += 1
if self.apply_update_errors < self.max_apply_update_errors:
self.io.tool_error(f"Malformed response #{self.apply_update_errors}, retrying...")
self.io.tool_error("https://aider.chat/docs/faq.html#aider-isnt-editing-my-files")
self.io.tool_error(str(err), strip=False)
return None, err
else:
self.io.tool_error(f"Malformed response #{self.apply_update_errors}, aborting.")
self.io.tool_error("https://aider.chat/docs/faq.html#aider-isnt-editing-my-files")
self.io.tool_error(str(err), strip=False)
return False, None
self.io.tool_error("The LLM did not conform to the edit format.")
self.io.tool_error(
"For more info see: https://aider.chat/docs/faq.html#aider-isnt-editing-my-files"
)
self.io.tool_error()
self.io.tool_error(str(err), strip=False)
self.reflected_message = str(err)
return
except git.exc.GitCommandError as err:
self.io.tool_error(str(err))
return False, None
return
except Exception as err:
print(err)
print()
traceback.print_exc()
self.apply_update_errors += 1
if self.apply_update_errors < self.max_apply_update_errors:
self.io.tool_error(f"Update exception #{self.apply_update_errors}, retrying...")
self.io.tool_error(str(err), strip=False)
return None, str(err)
else:
self.io.tool_error(f"Update exception #{self.apply_update_errors}, aborting")
self.io.tool_error(str(err), strip=False)
return False, None
self.io.tool_error("Exception while updating files:")
self.io.tool_error(str(err), strip=False)
self.apply_update_errors = 0
traceback.print_exc()
self.reflected_message = str(err)
return
for path in edited:
if self.dry_run:
@ -1187,7 +1255,7 @@ class Coder:
else:
self.io.tool_output(f"Applied edit to {path}")
return edited, None
return edited
def parse_partial_args(self):
# dump(self.partial_response_function_call)