From 9ee332f5d9adfb05292b33f2d65c15d1aed7e012 Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Sun, 19 May 2024 07:34:19 -0700 Subject: [PATCH] Added options for automatic linting and testing after changes. --- aider/args.py | 40 +++++++++++++++++++++++++++--------- aider/coders/base_coder.py | 42 +++++++++++++++++++++++++++++++++++--- aider/commands.py | 2 +- aider/linter.py | 40 +++++++++++++++++++++--------------- aider/main.py | 26 +++++++++++------------ 5 files changed, 106 insertions(+), 44 deletions(-) diff --git a/aider/args.py b/aider/args.py index 1689621a2..a528c78d0 100644 --- a/aider/args.py +++ b/aider/args.py @@ -15,13 +15,6 @@ def get_parser(default_config_files, git_root): auto_env_var_prefix="AIDER_", ) group = parser.add_argument_group("Main") - group = parser.add_argument_group("Main") - group.add_argument( - "--auto-lint", - action=argparse.BooleanOptionalAction, - default=True, - help="Enable/disable automatic linting after changes (default: True)", - ) group.add_argument( "files", metavar="FILE", @@ -310,6 +303,7 @@ def get_parser(default_config_files, git_root): default=False, help="Perform a dry run without modifying files (default: False)", ) + group = parser.add_argument_group("Fixing and committing") group.add_argument( "--commit", action="store_true", @@ -319,16 +313,42 @@ def get_parser(default_config_files, git_root): group.add_argument( "--lint", action="store_true", - help="Commit, run the linter on all dirty files, fix problems and commit again", + help="Run the linter on all dirty files, fix problems and commit", default=False, ) group.add_argument( "--lint-cmd", action="append", - help='Specify lint commands to run for different languages, eg: "python: flake8 --select=..." (can be used multiple times)', + help=( + 'Specify lint commands to run for different languages, eg: "python: flake8' + ' --select=..." (can be used multiple times)' + ), default=[], ) - + group.add_argument( + "--auto-lint", + action=argparse.BooleanOptionalAction, + default=True, + help="Enable/disable automatic linting after changes (default: True)", + ) + group.add_argument( + "--test-cmd", + action="append", + help="Specify command to run tests", + default=[], + ) + group.add_argument( + "--auto-test", + action=argparse.BooleanOptionalAction, + default=False, + help="Enable/disable automatic testing after changes (default: False)", + ) + group.add_argument( + "--test", + action="store_true", + help="Run tests and fix problems found", + default=False, + ) ########## group = parser.add_argument_group("Other Settings") diff --git a/aider/coders/base_coder.py b/aider/coders/base_coder.py index ccf7854f3..6a1991222 100755 --- a/aider/coders/base_coder.py +++ b/aider/coders/base_coder.py @@ -59,6 +59,9 @@ class Coder: max_reflections = 5 edit_format = None yield_stream = False + auto_lint = True + auto_test = False + test_cmd = None @classmethod def create( @@ -195,6 +198,10 @@ class Coder: done_messages=None, max_chat_history_tokens=None, restore_chat_history=False, + auto_lint=True, + auto_test=False, + lint_cmds=None, + test_cmd=None, ): if not fnames: fnames = [] @@ -289,8 +296,6 @@ class Coder: self.verbose, ) - self.linter = Linter(root=self.root, encoding=io.encoding) - if max_chat_history_tokens is None: max_chat_history_tokens = self.main_model.max_chat_history_tokens self.summarizer = ChatSummary( @@ -307,6 +312,14 @@ class Coder: self.done_messages = utils.split_chat_history_markdown(history_md) self.summarize_start() + # Linting and testing + self.linter = Linter(root=self.root, encoding=io.encoding) + self.auto_lint = auto_lint + self.setup_lint_cmds(lint_cmds) + + self.auto_test = auto_test + self.test_cmd = test_cmd + # validate the functions jsonschema if self.functions: for function in self.functions: @@ -316,6 +329,22 @@ class Coder: self.io.tool_output("JSON Schema:") self.io.tool_output(json.dumps(self.functions, indent=4)) + def setup_lint_cmds(self, lint_cmds): + for lint_cmd in lint_cmds: + pieces = lint_cmd.split(":") + lang = pieces[0] + cmd = lint_cmd[len(lang) + 1 :] + + lang = lang.strip() + cmd = cmd.strip() + + if lang and cmd: + self.linter.set_linter(lang, cmd) + else: + self.io.tool_error(f'Unable to parse --lint-cmd "{lint_cmd}"') + self.io.tool_error(f'The arg should be "language: cmd --args ..."') + self.io.tool_error('For example: --lint-cmd "python: flake8 --select=E9"') + def show_announcements(self): for line in self.get_announcements(): self.io.tool_output(line) @@ -737,13 +766,20 @@ class Coder: self.update_cur_messages(set()) return - if edited: + if edited and self.auto_lint: lint_errors = self.lint_edited(edited) if lint_errors: self.reflected_message = lint_errors self.update_cur_messages(set()) return + if edited and self.auto_test: + test_errors = self.commands.cmd_test(self.test_cmd) + if test_errors: + self.reflected_message = test_errors + self.update_cur_messages(set()) + return + self.update_cur_messages(edited) if edited: diff --git a/aider/commands.py b/aider/commands.py index 076d5dcf2..95f9b67b8 100644 --- a/aider/commands.py +++ b/aider/commands.py @@ -168,7 +168,7 @@ class Commands: linted = False for fname in fnames: try: - errors = self.coder.linter.lint(fname) + errors = self.coder.linter.lint(fname, cmd=args) linted = True except FileNotFoundError as err: self.io.tool_error(f"Unable to lint {fname}") diff --git a/aider/linter.py b/aider/linter.py index 95f436c4f..a18ff963d 100644 --- a/aider/linter.py +++ b/aider/linter.py @@ -3,8 +3,8 @@ import subprocess import sys import traceback import warnings -from pathlib import Path from dataclasses import dataclass +from pathlib import Path from grep_ast import TreeContext, filename_to_lang from tree_sitter_languages import get_parser # noqa: E402 @@ -51,19 +51,21 @@ class Linter: filenames_linenums = find_filenames_and_linenums(errors, [rel_fname]) if filenames_linenums: filename, linenums = next(iter(filenames_linenums.items())) - linenums = [num-1 for num in linenums] + linenums = [num - 1 for num in linenums] return LintResult(text=res, lines=linenums) - def lint(self, fname): - lang = filename_to_lang(fname) - if not lang: - return - + def lint(self, fname, cmd=None): rel_fname = self.get_rel_fname(fname) code = Path(fname).read_text(self.encoding) - cmd = self.languages.get(lang) + if cmd: + cmd = cmd.strip() + if not cmd: + lang = filename_to_lang(fname) + if not lang: + return + cmd = self.languages.get(lang) if callable(cmd): linkres = cmd(fname, rel_fname, code) @@ -75,15 +77,15 @@ class Linter: if not linkres: return - res = '# Fix any errors below\n\n' + res = "# Fix any errors below, if possible.\n\n" res += linkres.text - res += '\n' + res += "\n" res += tree_context(fname, code, linkres.lines) return res def py_lint(self, fname, rel_fname, code): - result = '' + result = "" basic_res = basic_lint(rel_fname, code) compile_res = lint_python_compile(fname, code) @@ -95,19 +97,20 @@ class Linter: except FileNotFoundError: flake_res = None - text = '' + text = "" lines = set() for res in [basic_res, compile_res, flake_res]: if not res: continue if text: - text += '\n' + text += "\n" text += res.text lines.update(res.lines) if text or lines: return LintResult(text, lines) + @dataclass class LintResult: text: str @@ -134,7 +137,7 @@ def lint_python_compile(fname, code): tb_lines = tb_lines[:1] + tb_lines[last_file_i + 1 :] res = "".join(tb_lines) - return LintResult(text = res, lines = line_numbers) + return LintResult(text=res, lines=line_numbers) def basic_lint(fname, code): @@ -153,7 +156,7 @@ def basic_lint(fname, code): if not errors: return - return LintResult(text = '', lines = errors) + return LintResult(text="", lines=errors) def tree_context(fname, code, line_nums): @@ -193,23 +196,26 @@ def traverse_tree(node): return errors + import re + def find_filenames_and_linenums(text, fnames): """ Search text for all occurrences of :\d+ and make a list of them where is one of the filenames in the list `fnames`. """ - pattern = re.compile(r'(\b(?:' + '|'.join(re.escape(fname) for fname in fnames) + r'):\d+\b)') + pattern = re.compile(r"(\b(?:" + "|".join(re.escape(fname) for fname in fnames) + r"):\d+\b)") matches = pattern.findall(text) result = {} for match in matches: - fname, linenum = match.rsplit(':', 1) + fname, linenum = match.rsplit(":", 1) if fname not in result: result[fname] = set() result[fname].add(int(linenum)) return result + def main(): """ Main function to parse files provided as command line arguments. diff --git a/aider/main.py b/aider/main.py index 473c2c0b1..e09275209 100644 --- a/aider/main.py +++ b/aider/main.py @@ -332,6 +332,10 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F aider_ignore_file=args.aiderignore, max_chat_history_tokens=args.max_chat_history_tokens, restore_chat_history=args.restore_chat_history, + auto_lint=args.auto_lint, + auto_test=args.auto_test, + lint_cmds=args.lint_cmd, + test_cmd=args.test_cmd, ) except ValueError as err: @@ -343,19 +347,6 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F coder.show_announcements() - for lint_cmd in args.lint_cmd: - pieces = lint_cmd.split(':') - lang = pieces[0] - cmd = lint_cmd[len(lang)+1:] - - lang = lang.strip() - cmd = cmd.strip() - - if lang and cmd: - coder.linter.set_linter(lang, cmd) - else: - io.tool_error(f"Unable to parse --lang-cmd {lang_cmd}") - if args.show_prompts: coder.cur_messages += [ dict(role="user", content="Hello!"), @@ -372,6 +363,15 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F coder.commands.cmd_lint("") return + if args.test: + if not args.test_cmd: + io.tool_error("No --test-cmd provided.") + return 1 + test_errors = coder.commands.cmd_test(args.test_cmd) + if test_errors: + coder.run(test_errors) + return + if args.show_repo_map: repo_map = coder.get_repo_map() if repo_map: