Added options for automatic linting and testing after changes.

This commit is contained in:
Paul Gauthier 2024-05-19 07:34:19 -07:00
parent 398a1300dc
commit 9ee332f5d9
5 changed files with 106 additions and 44 deletions

View file

@ -15,13 +15,6 @@ def get_parser(default_config_files, git_root):
auto_env_var_prefix="AIDER_", auto_env_var_prefix="AIDER_",
) )
group = parser.add_argument_group("Main") group = parser.add_argument_group("Main")
group = parser.add_argument_group("Main")
group.add_argument(
"--auto-lint",
action=argparse.BooleanOptionalAction,
default=True,
help="Enable/disable automatic linting after changes (default: True)",
)
group.add_argument( group.add_argument(
"files", "files",
metavar="FILE", metavar="FILE",
@ -310,6 +303,7 @@ def get_parser(default_config_files, git_root):
default=False, default=False,
help="Perform a dry run without modifying files (default: False)", help="Perform a dry run without modifying files (default: False)",
) )
group = parser.add_argument_group("Fixing and committing")
group.add_argument( group.add_argument(
"--commit", "--commit",
action="store_true", action="store_true",
@ -319,16 +313,42 @@ def get_parser(default_config_files, git_root):
group.add_argument( group.add_argument(
"--lint", "--lint",
action="store_true", action="store_true",
help="Commit, run the linter on all dirty files, fix problems and commit again", help="Run the linter on all dirty files, fix problems and commit",
default=False, default=False,
) )
group.add_argument( group.add_argument(
"--lint-cmd", "--lint-cmd",
action="append", action="append",
help='Specify lint commands to run for different languages, eg: "python: flake8 --select=..." (can be used multiple times)', help=(
'Specify lint commands to run for different languages, eg: "python: flake8'
' --select=..." (can be used multiple times)'
),
default=[], default=[],
) )
group.add_argument(
"--auto-lint",
action=argparse.BooleanOptionalAction,
default=True,
help="Enable/disable automatic linting after changes (default: True)",
)
group.add_argument(
"--test-cmd",
action="append",
help="Specify command to run tests",
default=[],
)
group.add_argument(
"--auto-test",
action=argparse.BooleanOptionalAction,
default=False,
help="Enable/disable automatic testing after changes (default: False)",
)
group.add_argument(
"--test",
action="store_true",
help="Run tests and fix problems found",
default=False,
)
########## ##########
group = parser.add_argument_group("Other Settings") group = parser.add_argument_group("Other Settings")

View file

@ -59,6 +59,9 @@ class Coder:
max_reflections = 5 max_reflections = 5
edit_format = None edit_format = None
yield_stream = False yield_stream = False
auto_lint = True
auto_test = False
test_cmd = None
@classmethod @classmethod
def create( def create(
@ -195,6 +198,10 @@ class Coder:
done_messages=None, done_messages=None,
max_chat_history_tokens=None, max_chat_history_tokens=None,
restore_chat_history=False, restore_chat_history=False,
auto_lint=True,
auto_test=False,
lint_cmds=None,
test_cmd=None,
): ):
if not fnames: if not fnames:
fnames = [] fnames = []
@ -289,8 +296,6 @@ class Coder:
self.verbose, self.verbose,
) )
self.linter = Linter(root=self.root, encoding=io.encoding)
if max_chat_history_tokens is None: if max_chat_history_tokens is None:
max_chat_history_tokens = self.main_model.max_chat_history_tokens max_chat_history_tokens = self.main_model.max_chat_history_tokens
self.summarizer = ChatSummary( self.summarizer = ChatSummary(
@ -307,6 +312,14 @@ class Coder:
self.done_messages = utils.split_chat_history_markdown(history_md) self.done_messages = utils.split_chat_history_markdown(history_md)
self.summarize_start() self.summarize_start()
# Linting and testing
self.linter = Linter(root=self.root, encoding=io.encoding)
self.auto_lint = auto_lint
self.setup_lint_cmds(lint_cmds)
self.auto_test = auto_test
self.test_cmd = test_cmd
# validate the functions jsonschema # validate the functions jsonschema
if self.functions: if self.functions:
for function in self.functions: for function in self.functions:
@ -316,6 +329,22 @@ class Coder:
self.io.tool_output("JSON Schema:") self.io.tool_output("JSON Schema:")
self.io.tool_output(json.dumps(self.functions, indent=4)) self.io.tool_output(json.dumps(self.functions, indent=4))
def setup_lint_cmds(self, lint_cmds):
for lint_cmd in lint_cmds:
pieces = lint_cmd.split(":")
lang = pieces[0]
cmd = lint_cmd[len(lang) + 1 :]
lang = lang.strip()
cmd = cmd.strip()
if lang and cmd:
self.linter.set_linter(lang, cmd)
else:
self.io.tool_error(f'Unable to parse --lint-cmd "{lint_cmd}"')
self.io.tool_error(f'The arg should be "language: cmd --args ..."')
self.io.tool_error('For example: --lint-cmd "python: flake8 --select=E9"')
def show_announcements(self): def show_announcements(self):
for line in self.get_announcements(): for line in self.get_announcements():
self.io.tool_output(line) self.io.tool_output(line)
@ -737,13 +766,20 @@ class Coder:
self.update_cur_messages(set()) self.update_cur_messages(set())
return return
if edited: if edited and self.auto_lint:
lint_errors = self.lint_edited(edited) lint_errors = self.lint_edited(edited)
if lint_errors: if lint_errors:
self.reflected_message = lint_errors self.reflected_message = lint_errors
self.update_cur_messages(set()) self.update_cur_messages(set())
return return
if edited and self.auto_test:
test_errors = self.commands.cmd_test(self.test_cmd)
if test_errors:
self.reflected_message = test_errors
self.update_cur_messages(set())
return
self.update_cur_messages(edited) self.update_cur_messages(edited)
if edited: if edited:

View file

@ -168,7 +168,7 @@ class Commands:
linted = False linted = False
for fname in fnames: for fname in fnames:
try: try:
errors = self.coder.linter.lint(fname) errors = self.coder.linter.lint(fname, cmd=args)
linted = True linted = True
except FileNotFoundError as err: except FileNotFoundError as err:
self.io.tool_error(f"Unable to lint {fname}") self.io.tool_error(f"Unable to lint {fname}")

View file

@ -3,8 +3,8 @@ import subprocess
import sys import sys
import traceback import traceback
import warnings import warnings
from pathlib import Path
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path
from grep_ast import TreeContext, filename_to_lang from grep_ast import TreeContext, filename_to_lang
from tree_sitter_languages import get_parser # noqa: E402 from tree_sitter_languages import get_parser # noqa: E402
@ -51,19 +51,21 @@ class Linter:
filenames_linenums = find_filenames_and_linenums(errors, [rel_fname]) filenames_linenums = find_filenames_and_linenums(errors, [rel_fname])
if filenames_linenums: if filenames_linenums:
filename, linenums = next(iter(filenames_linenums.items())) filename, linenums = next(iter(filenames_linenums.items()))
linenums = [num-1 for num in linenums] linenums = [num - 1 for num in linenums]
return LintResult(text=res, lines=linenums) return LintResult(text=res, lines=linenums)
def lint(self, fname): def lint(self, fname, cmd=None):
lang = filename_to_lang(fname)
if not lang:
return
rel_fname = self.get_rel_fname(fname) rel_fname = self.get_rel_fname(fname)
code = Path(fname).read_text(self.encoding) code = Path(fname).read_text(self.encoding)
cmd = self.languages.get(lang) if cmd:
cmd = cmd.strip()
if not cmd:
lang = filename_to_lang(fname)
if not lang:
return
cmd = self.languages.get(lang)
if callable(cmd): if callable(cmd):
linkres = cmd(fname, rel_fname, code) linkres = cmd(fname, rel_fname, code)
@ -75,15 +77,15 @@ class Linter:
if not linkres: if not linkres:
return return
res = '# Fix any errors below\n\n' res = "# Fix any errors below, if possible.\n\n"
res += linkres.text res += linkres.text
res += '\n' res += "\n"
res += tree_context(fname, code, linkres.lines) res += tree_context(fname, code, linkres.lines)
return res return res
def py_lint(self, fname, rel_fname, code): def py_lint(self, fname, rel_fname, code):
result = '' result = ""
basic_res = basic_lint(rel_fname, code) basic_res = basic_lint(rel_fname, code)
compile_res = lint_python_compile(fname, code) compile_res = lint_python_compile(fname, code)
@ -95,19 +97,20 @@ class Linter:
except FileNotFoundError: except FileNotFoundError:
flake_res = None flake_res = None
text = '' text = ""
lines = set() lines = set()
for res in [basic_res, compile_res, flake_res]: for res in [basic_res, compile_res, flake_res]:
if not res: if not res:
continue continue
if text: if text:
text += '\n' text += "\n"
text += res.text text += res.text
lines.update(res.lines) lines.update(res.lines)
if text or lines: if text or lines:
return LintResult(text, lines) return LintResult(text, lines)
@dataclass @dataclass
class LintResult: class LintResult:
text: str text: str
@ -134,7 +137,7 @@ def lint_python_compile(fname, code):
tb_lines = tb_lines[:1] + tb_lines[last_file_i + 1 :] tb_lines = tb_lines[:1] + tb_lines[last_file_i + 1 :]
res = "".join(tb_lines) res = "".join(tb_lines)
return LintResult(text = res, lines = line_numbers) return LintResult(text=res, lines=line_numbers)
def basic_lint(fname, code): def basic_lint(fname, code):
@ -153,7 +156,7 @@ def basic_lint(fname, code):
if not errors: if not errors:
return return
return LintResult(text = '', lines = errors) return LintResult(text="", lines=errors)
def tree_context(fname, code, line_nums): def tree_context(fname, code, line_nums):
@ -193,23 +196,26 @@ def traverse_tree(node):
return errors return errors
import re import re
def find_filenames_and_linenums(text, fnames): def find_filenames_and_linenums(text, fnames):
""" """
Search text for all occurrences of <filename>:\d+ and make a list of them Search text for all occurrences of <filename>:\d+ and make a list of them
where <filename> is one of the filenames in the list `fnames`. where <filename> is one of the filenames in the list `fnames`.
""" """
pattern = re.compile(r'(\b(?:' + '|'.join(re.escape(fname) for fname in fnames) + r'):\d+\b)') pattern = re.compile(r"(\b(?:" + "|".join(re.escape(fname) for fname in fnames) + r"):\d+\b)")
matches = pattern.findall(text) matches = pattern.findall(text)
result = {} result = {}
for match in matches: for match in matches:
fname, linenum = match.rsplit(':', 1) fname, linenum = match.rsplit(":", 1)
if fname not in result: if fname not in result:
result[fname] = set() result[fname] = set()
result[fname].add(int(linenum)) result[fname].add(int(linenum))
return result return result
def main(): def main():
""" """
Main function to parse files provided as command line arguments. Main function to parse files provided as command line arguments.

View file

@ -332,6 +332,10 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
aider_ignore_file=args.aiderignore, aider_ignore_file=args.aiderignore,
max_chat_history_tokens=args.max_chat_history_tokens, max_chat_history_tokens=args.max_chat_history_tokens,
restore_chat_history=args.restore_chat_history, restore_chat_history=args.restore_chat_history,
auto_lint=args.auto_lint,
auto_test=args.auto_test,
lint_cmds=args.lint_cmd,
test_cmd=args.test_cmd,
) )
except ValueError as err: except ValueError as err:
@ -343,19 +347,6 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
coder.show_announcements() coder.show_announcements()
for lint_cmd in args.lint_cmd:
pieces = lint_cmd.split(':')
lang = pieces[0]
cmd = lint_cmd[len(lang)+1:]
lang = lang.strip()
cmd = cmd.strip()
if lang and cmd:
coder.linter.set_linter(lang, cmd)
else:
io.tool_error(f"Unable to parse --lang-cmd {lang_cmd}")
if args.show_prompts: if args.show_prompts:
coder.cur_messages += [ coder.cur_messages += [
dict(role="user", content="Hello!"), dict(role="user", content="Hello!"),
@ -372,6 +363,15 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
coder.commands.cmd_lint("") coder.commands.cmd_lint("")
return return
if args.test:
if not args.test_cmd:
io.tool_error("No --test-cmd provided.")
return 1
test_errors = coder.commands.cmd_test(args.test_cmd)
if test_errors:
coder.run(test_errors)
return
if args.show_repo_map: if args.show_repo_map:
repo_map = coder.get_repo_map() repo_map = coder.get_repo_map()
if repo_map: if repo_map: