Merge branch 'main' into swe-bench

This commit is contained in:
Paul Gauthier 2024-05-19 15:20:46 -07:00
commit e758b01fb6
33 changed files with 997 additions and 177 deletions

View file

@ -14,8 +14,6 @@ def get_parser(default_config_files, git_root):
config_file_parser_class=configargparse.YAMLConfigFileParser,
auto_env_var_prefix="AIDER_",
)
##########
group = parser.add_argument_group("Main")
group.add_argument(
"files",
@ -305,12 +303,52 @@ def get_parser(default_config_files, git_root):
default=False,
help="Perform a dry run without modifying files (default: False)",
)
group = parser.add_argument_group("Fixing and committing")
group.add_argument(
"--commit",
action="store_true",
help="Commit all pending changes with a suitable commit message, then exit",
default=False,
)
group.add_argument(
"--lint",
action="store_true",
help="Run the linter on all dirty files, fix problems and commit",
default=False,
)
group.add_argument(
"--lint-cmd",
action="append",
help=(
'Specify lint commands to run for different languages, eg: "python: flake8'
' --select=..." (can be used multiple times)'
),
default=[],
)
group.add_argument(
"--auto-lint",
action=argparse.BooleanOptionalAction,
default=True,
help="Enable/disable automatic linting after changes (default: True)",
)
group.add_argument(
"--test-cmd",
action="append",
help="Specify command to run tests",
default=[],
)
group.add_argument(
"--auto-test",
action=argparse.BooleanOptionalAction,
default=False,
help="Enable/disable automatic testing after changes (default: False)",
)
group.add_argument(
"--test",
action="store_true",
help="Run tests and fix problems found",
default=False,
)
##########
group = parser.add_argument_group("Other Settings")

View file

@ -21,6 +21,7 @@ from aider import __version__, models, prompts, utils
from aider.commands import Commands
from aider.history import ChatSummary
from aider.io import InputOutput
from aider.linter import Linter
from aider.litellm import litellm
from aider.mdstream import MarkdownStream
from aider.repo import GitRepo
@ -55,10 +56,16 @@ class Coder:
num_exhausted_context_windows = 0
num_malformed_responses = 0
last_keyboard_interrupt = None
max_apply_update_errors = 3
num_reflections = 0
max_reflections = 3
edit_format = None
yield_stream = False
temperature = 0
auto_lint = True
auto_test = False
test_cmd = None
lint_outcome = None
test_outcome = None
@classmethod
def create(
@ -195,6 +202,10 @@ class Coder:
done_messages=None,
max_chat_history_tokens=None,
restore_chat_history=False,
auto_lint=True,
auto_test=False,
lint_cmds=None,
test_cmd=None,
):
if not fnames:
fnames = []
@ -305,6 +316,14 @@ class Coder:
self.done_messages = utils.split_chat_history_markdown(history_md)
self.summarize_start()
# Linting and testing
self.linter = Linter(root=self.root, encoding=io.encoding)
self.auto_lint = auto_lint
self.setup_lint_cmds(lint_cmds)
self.auto_test = auto_test
self.test_cmd = test_cmd
# validate the functions jsonschema
if self.functions:
for function in self.functions:
@ -314,6 +333,12 @@ class Coder:
self.io.tool_output("JSON Schema:")
self.io.tool_output(json.dumps(self.functions, indent=4))
def setup_lint_cmds(self, lint_cmds):
if not lint_cmds:
return
for lang, cmd in lint_cmds.items():
self.linter.set_linter(lang, cmd)
def show_announcements(self):
for line in self.get_announcements():
self.io.tool_output(line)
@ -524,12 +549,20 @@ class Coder:
def run_stream(self, user_message):
self.io.user_input(user_message)
self.reflected_message = None
self.init_before_message()
yield from self.send_new_user_message(user_message)
def init_before_message(self):
self.reflected_message = None
self.num_reflections = 0
self.lint_outcome = None
self.test_outcome = None
self.edit_outcome = None
def run(self, with_message=None):
while True:
self.num_malformed_responses = 0
self.init_before_message()
try:
if with_message:
new_user_message = with_message
@ -540,7 +573,14 @@ class Coder:
while new_user_message:
self.reflected_message = None
list(self.send_new_user_message(new_user_message))
new_user_message = self.reflected_message
if self.num_reflections < self.max_reflections:
self.num_reflections += 1
new_user_message = self.reflected_message
else:
self.io.tool_error(
f"Only {self.max_reflections} reflections allowed, stopping."
)
new_user_message = None
if with_message:
return self.partial_response_content
@ -761,10 +801,33 @@ class Coder:
self.cur_messages += [dict(role="assistant", content=content)]
return
edited, edit_error = self.apply_updates()
if edit_error:
edited = self.apply_updates()
if self.reflected_message:
self.edit_outcome = False
self.update_cur_messages(set())
self.reflected_message = edit_error
return
if edited:
self.edit_outcome = True
if edited and self.auto_lint:
lint_errors = self.lint_edited(edited)
self.lint_outcome = not lint_errors
if lint_errors:
ok = self.io.confirm_ask("Attempt to fix lint errors?")
if ok:
self.reflected_message = lint_errors
self.update_cur_messages(set())
return
if edited and self.auto_test:
test_errors = self.commands.cmd_test(self.test_cmd)
self.test_outcome = not test_errors
if test_errors:
ok = self.io.confirm_ask("Attempt to fix test errors?")
if ok:
self.reflected_message = test_errors
self.update_cur_messages(set())
return
self.update_cur_messages(edited)
@ -786,6 +849,20 @@ class Coder:
else:
self.reflected_message = add_rel_files_message
def lint_edited(self, fnames):
res = ""
for fname in fnames:
errors = self.linter.lint(self.abs_root_path(fname))
if errors:
res += "\n"
res += errors
res += "\n"
if res:
self.io.tool_error(res)
return res
def update_cur_messages(self, edited):
if self.partial_response_content:
self.cur_messages += [dict(role="assistant", content=self.partial_response_content)]
@ -1114,8 +1191,6 @@ class Coder:
)
self.warning_given = True
apply_update_errors = 0
def prepare_to_edit(self, edits):
res = []
seen = dict()
@ -1149,37 +1224,30 @@ class Coder:
edited = self.update_files()
except ValueError as err:
self.num_malformed_responses += 1
err = err.args[0]
self.apply_update_errors += 1
if self.apply_update_errors < self.max_apply_update_errors:
self.io.tool_error(f"Malformed response #{self.apply_update_errors}, retrying...")
self.io.tool_error("https://aider.chat/docs/faq.html#aider-isnt-editing-my-files")
self.io.tool_error(str(err), strip=False)
return None, err
else:
self.io.tool_error(f"Malformed response #{self.apply_update_errors}, aborting.")
self.io.tool_error("https://aider.chat/docs/faq.html#aider-isnt-editing-my-files")
self.io.tool_error(str(err), strip=False)
return False, None
self.io.tool_error("The LLM did not conform to the edit format.")
self.io.tool_error(
"For more info see: https://aider.chat/docs/faq.html#aider-isnt-editing-my-files"
)
self.io.tool_error()
self.io.tool_error(str(err), strip=False)
self.reflected_message = str(err)
return
except git.exc.GitCommandError as err:
self.io.tool_error(str(err))
return False, None
return
except Exception as err:
print(err)
print()
traceback.print_exc()
self.apply_update_errors += 1
if self.apply_update_errors < self.max_apply_update_errors:
self.io.tool_error(f"Update exception #{self.apply_update_errors}, retrying...")
self.io.tool_error(str(err), strip=False)
return None, str(err)
else:
self.io.tool_error(f"Update exception #{self.apply_update_errors}, aborting")
self.io.tool_error(str(err), strip=False)
return False, None
self.io.tool_error("Exception while updating files:")
self.io.tool_error(str(err), strip=False)
self.apply_update_errors = 0
traceback.print_exc()
self.reflected_message = str(err)
return
for path in edited:
if self.dry_run:
@ -1187,7 +1255,7 @@ class Coder:
else:
self.io.tool_output(f"Applied edit to {path}")
return edited, None
return edited
def parse_partial_args(self):
# dump(self.partial_response_function_call)

View file

@ -124,6 +124,9 @@ Every *SEARCH/REPLACE block* must use this format:
Every *SEARCH* section must *EXACTLY MATCH* the existing source code, character for character, including all comments, docstrings, etc.
*SEARCH/REPLACE* blocks will replace *all* matching occurrences.
Include enough lines to make the SEARCH blocks unique.
Include *ALL* the code being searched and replaced!
Only create *SEARCH/REPLACE* blocks for files that the user has added to the chat!

View file

@ -1,6 +1,6 @@
from pathlib import Path
from aider import diffs
from pathlib import Path
from ..dump import dump # noqa: F401
from .base_coder import Coder

View file

@ -153,7 +153,43 @@ class Commands:
commit_message = args.strip()
self.coder.repo.commit(message=commit_message)
def cmd_clear(self, args=""):
def cmd_lint(self, args):
"Commit, run the linter on all dirty files, fix problems and commit again"
if not self.coder.repo:
self.io.tool_error("No git repository found.")
return
if not self.coder.repo.is_dirty():
self.io.tool_error("No more changes to commit.")
return
fnames = self.coder.repo.get_dirty_files()
linted = False
for fname in fnames:
try:
errors = self.coder.linter.lint(fname, cmd=args)
linted = True
except FileNotFoundError as err:
self.io.tool_error(f"Unable to lint {fname}")
self.io.tool_error(str(err))
continue
if errors:
# Commit everything before we start fixing lint errors
if self.coder.repo.is_dirty():
self.cmd_commit("")
self.io.tool_error(errors)
abs_file_path = self.coder.abs_root_path(fname)
self.coder.abs_fnames.add(abs_file_path)
self.coder.run(errors)
if linted and self.coder.repo.is_dirty():
self.cmd_commit("")
def cmd_clear(self, args):
"Clear the chat history"
self.coder.done_messages = []

236
aider/linter.py Normal file
View file

@ -0,0 +1,236 @@
import os
import subprocess
import sys
import traceback
import warnings
from dataclasses import dataclass
from pathlib import Path
from grep_ast import TreeContext, filename_to_lang
from tree_sitter_languages import get_parser # noqa: E402
# tree_sitter is throwing a FutureWarning
warnings.simplefilter("ignore", category=FutureWarning)
class Linter:
def __init__(self, encoding="utf-8", root=None):
self.encoding = encoding
self.root = root
self.languages = dict(
python=self.py_lint,
)
def set_linter(self, lang, cmd):
self.languages[lang] = cmd
def get_rel_fname(self, fname):
if self.root:
return os.path.relpath(fname, self.root)
else:
return fname
def run_cmd(self, cmd, rel_fname, code):
cmd += " " + rel_fname
cmd = cmd.split()
process = subprocess.Popen(
cmd, cwd=self.root, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
)
stdout, _ = process.communicate()
errors = stdout.decode()
if process.returncode == 0:
return # zero exit status
cmd = " ".join(cmd)
res = f"## Running: {cmd}\n\n"
res += errors
linenums = []
filenames_linenums = find_filenames_and_linenums(errors, [rel_fname])
if filenames_linenums:
filename, linenums = next(iter(filenames_linenums.items()))
linenums = [num - 1 for num in linenums]
return LintResult(text=res, lines=linenums)
def lint(self, fname, cmd=None):
rel_fname = self.get_rel_fname(fname)
code = Path(fname).read_text(self.encoding)
if cmd:
cmd = cmd.strip()
if not cmd:
lang = filename_to_lang(fname)
if not lang:
return
cmd = self.languages.get(lang)
if callable(cmd):
linkres = cmd(fname, rel_fname, code)
elif cmd:
linkres = self.run_cmd(cmd, rel_fname, code)
else:
linkres = basic_lint(rel_fname, code)
if not linkres:
return
res = "# Fix any errors below, if possible.\n\n"
res += linkres.text
res += "\n"
res += tree_context(fname, code, linkres.lines)
return res
def py_lint(self, fname, rel_fname, code):
result = ""
basic_res = basic_lint(rel_fname, code)
compile_res = lint_python_compile(fname, code)
fatal = "E9,F821,F823,F831,F406,F407,F701,F702,F704,F706"
flake8 = f"flake8 --select={fatal} --show-source"
try:
flake_res = self.run_cmd(flake8, rel_fname, code)
except FileNotFoundError:
flake_res = None
text = ""
lines = set()
for res in [basic_res, compile_res, flake_res]:
if not res:
continue
if text:
text += "\n"
text += res.text
lines.update(res.lines)
if text or lines:
return LintResult(text, lines)
@dataclass
class LintResult:
text: str
lines: list
def lint_python_compile(fname, code):
try:
compile(code, fname, "exec") # USE TRACEBACK BELOW HERE
return
except Exception as err:
line_numbers = list(range(err.lineno - 1, err.end_lineno))
tb_lines = traceback.format_exception(type(err), err, err.__traceback__)
last_file_i = 0
target = "# USE TRACEBACK"
target += " BELOW HERE"
for i in range(len(tb_lines)):
if target in tb_lines[i]:
last_file_i = i
break
tb_lines = tb_lines[:1] + tb_lines[last_file_i + 1 :]
res = "".join(tb_lines)
return LintResult(text=res, lines=line_numbers)
def basic_lint(fname, code):
"""
Use tree-sitter to look for syntax errors, display them with tree context.
"""
lang = filename_to_lang(fname)
if not lang:
return
parser = get_parser(lang)
tree = parser.parse(bytes(code, "utf-8"))
errors = traverse_tree(tree.root_node)
if not errors:
return
return LintResult(text="", lines=errors)
def tree_context(fname, code, line_nums):
context = TreeContext(
fname,
code,
color=False,
line_number=True,
child_context=False,
last_line=False,
margin=0,
mark_lois=True,
loi_pad=3,
# header_max=30,
show_top_of_file_parent_scope=False,
)
line_nums = set(line_nums)
context.add_lines_of_interest(line_nums)
context.add_context()
s = "s" if len(line_nums) > 1 else ""
output = f"## See relevant line{s} below marked with █.\n\n"
output += fname + ":\n"
output += context.format()
return output
# Traverse the tree to find errors
def traverse_tree(node):
errors = []
if node.type == "ERROR" or node.is_missing:
line_no = node.start_point[0]
errors.append(line_no)
for child in node.children:
errors += traverse_tree(child)
return errors
import re
def find_filenames_and_linenums(text, fnames):
"""
Search text for all occurrences of <filename>:\d+ and make a list of them
where <filename> is one of the filenames in the list `fnames`.
"""
pattern = re.compile(r"(\b(?:" + "|".join(re.escape(fname) for fname in fnames) + r"):\d+\b)")
matches = pattern.findall(text)
result = {}
for match in matches:
fname, linenum = match.rsplit(":", 1)
if fname not in result:
result[fname] = set()
result[fname].add(int(linenum))
return result
def main():
"""
Main function to parse files provided as command line arguments.
"""
if len(sys.argv) < 2:
print("Usage: python linter.py <file1> <file2> ...")
sys.exit(1)
linter = Linter(root=os.getcwd())
for file_path in sys.argv[1:]:
errors = linter.lint(file_path)
if errors:
print(errors)
if __name__ == "__main__":
main()

View file

@ -8,4 +8,6 @@ os.environ["OR_APP_NAME"] = "Aider"
import litellm # noqa: E402
litellm.suppress_debug_info = True
__all__ = [litellm]

View file

@ -177,6 +177,29 @@ def launch_gui(args):
# sys.argv = ['streamlit', 'run', '--'] + args
def parse_lint_cmds(lint_cmds, io):
err = False
res = dict()
for lint_cmd in lint_cmds:
pieces = lint_cmd.split(":")
lang = pieces[0]
cmd = lint_cmd[len(lang) + 1 :]
lang = lang.strip()
cmd = cmd.strip()
if lang and cmd:
res[lang] = cmd
else:
io.tool_error(f'Unable to parse --lint-cmd "{lint_cmd}"')
io.tool_error('The arg should be "language: cmd --args ..."')
io.tool_error('For example: --lint-cmd "python: flake8 --select=E9"')
err = True
if err:
return
return res
def main(argv=None, input=None, output=None, force_git_root=None, return_coder=False):
if argv is None:
argv = sys.argv[1:]
@ -306,6 +329,10 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
main_model = models.Model(args.model, weak_model=args.weak_model)
lint_cmds = parse_lint_cmds(args.lint_cmd, io)
if lint_cmds is None:
return 1
if args.show_model_warnings:
models.sanity_check_models(io, main_model)
@ -332,6 +359,10 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
aider_ignore_file=args.aiderignore,
max_chat_history_tokens=args.max_chat_history_tokens,
restore_chat_history=args.restore_chat_history,
auto_lint=args.auto_lint,
auto_test=args.auto_test,
lint_cmds=lint_cmds,
test_cmd=args.test_cmd,
)
except ValueError as err:
@ -355,6 +386,19 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
coder.commands.cmd_commit("")
return
if args.lint:
coder.commands.cmd_lint("")
return
if args.test:
if not args.test_cmd:
io.tool_error("No --test-cmd provided.")
return 1
test_errors = coder.commands.cmd_test(args.test_cmd)
if test_errors:
coder.run(test_errors)
return
if args.show_repo_map:
repo_map = coder.get_repo_map()
if repo_map:

View file

@ -242,6 +242,23 @@ class GitRepo:
res = Path(self.root) / path
return utils.safe_abs_path(res)
def get_dirty_files(self):
"""
Returns a list of all files which are dirty (not committed), either staged or in the working
directory.
"""
dirty_files = set()
# Get staged files
staged_files = self.repo.git.diff("--name-only", "--cached").splitlines()
dirty_files.update(staged_files)
# Get unstaged files
unstaged_files = self.repo.git.diff("--name-only").splitlines()
dirty_files.update(unstaged_files)
return list(dirty_files)
def is_dirty(self, path=None):
if path and not self.path_in_repo(path):
return True