aider/aider/linter.py
2024-06-21 17:43:14 -07:00

277 lines
7.2 KiB
Python

import io
import os
import re
import runpy
import subprocess
import sys
import traceback
import warnings
from dataclasses import dataclass
from pathlib import Path
from grep_ast import TreeContext, filename_to_lang
from tree_sitter_languages import get_parser # noqa: E402
from aider.dump import dump # noqa: F401
# tree_sitter is throwing a FutureWarning; ignore that warning category so it
# does not clutter the output. No module filter is given, so this applies to
# every FutureWarning raised after this point, not only tree_sitter's.
warnings.simplefilter("ignore", category=FutureWarning)
class Linter:
    """Run a linter on a source file and format the findings as text.

    ``self.languages`` maps a language name to either a callable taking
    ``(fname, rel_fname, code)`` or a command string that is run as a
    subprocess.  When ``self.all_lint_cmd`` is set it is used instead of the
    per-language entry.
    """

    def __init__(self, encoding="utf-8", root=None):
        """Store the file encoding and the root directory used for relative paths."""
        self.encoding = encoding
        self.root = root

        # Python is the only language with a built-in (callable) linter.
        self.languages = dict(
            python=self.py_lint,
        )
        self.all_lint_cmd = None

    def set_linter(self, lang, cmd):
        """Register ``cmd`` for ``lang``, or for every language when ``lang`` is falsy."""
        if lang:
            self.languages[lang] = cmd
            return

        self.all_lint_cmd = cmd

    def get_rel_fname(self, fname):
        """Return ``fname`` relative to ``self.root``, or unchanged if no root is set."""
        if self.root:
            return os.path.relpath(fname, self.root)
        else:
            return fname

    def run_cmd(self, cmd, rel_fname, code):
        """Run the external lint command ``cmd`` on ``rel_fname``.

        ``cmd`` is split on whitespace and ``rel_fname`` is appended as one
        argument, so file names containing spaces reach the linter intact.
        The process runs with ``self.root`` as its working directory and
        stderr is merged into stdout.  ``code`` is not used by this method.

        Returns None when the command exits with status zero; otherwise the
        result of ``errors_to_lint_result`` for the captured output.
        """
        argv = cmd.split() + [rel_fname]

        process = subprocess.Popen(
            argv, cwd=self.root, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
        )
        stdout, _ = process.communicate()
        if process.returncode == 0:
            return  # zero exit status

        # Linter output is not guaranteed to be valid UTF-8; never crash on it.
        errors = stdout.decode(errors="replace")

        res = f"## Running: {' '.join(argv)}\n\n"
        res += errors

        return self.errors_to_lint_result(rel_fname, res)

    def errors_to_lint_result(self, rel_fname, errors):
        """Wrap linter output in a LintResult, extracting referenced line numbers.

        Returns None when ``errors`` is empty.  Line numbers found as
        ``<rel_fname>:<n>`` in the text are converted to zero-based values.
        """
        if not errors:
            return

        linenums = []
        filenames_linenums = find_filenames_and_linenums(errors, [rel_fname])
        if filenames_linenums:
            # Only rel_fname was searched for, so there is at most one entry.
            _, linenums = next(iter(filenames_linenums.items()))
            linenums = [num - 1 for num in linenums]

        return LintResult(text=errors, lines=linenums)

    def lint(self, fname, cmd=None):
        """Lint ``fname`` and return a report string, or None if nothing was found.

        An explicit non-blank ``cmd`` takes precedence.  Otherwise the language
        is detected from the file name (returning None if it is unknown) and
        ``self.all_lint_cmd`` or the per-language entry is used.  When no
        linter is configured for the language, ``basic_lint`` is the fallback.
        """
        rel_fname = self.get_rel_fname(fname)
        code = Path(fname).read_text(self.encoding)

        if cmd:
            cmd = cmd.strip()
        if not cmd:
            lang = filename_to_lang(fname)
            if not lang:
                return
            if self.all_lint_cmd:
                cmd = self.all_lint_cmd
            else:
                cmd = self.languages.get(lang)

        if callable(cmd):
            lintres = cmd(fname, rel_fname, code)
        elif cmd:
            lintres = self.run_cmd(cmd, rel_fname, code)
        else:
            lintres = basic_lint(rel_fname, code)

        if not lintres:
            return

        res = "# Fix any errors below, if possible.\n\n"
        res += lintres.text
        res += "\n"
        res += tree_context(rel_fname, code, lintres.lines)

        return res

    def py_lint(self, fname, rel_fname, code):
        """Lint Python code with tree-sitter, ``compile()`` and flake8 combined.

        Returns a LintResult whose text joins the individual reports with a
        newline and whose ``lines`` is the set union of their line numbers,
        or None when none of the checks reported anything.
        """
        basic_res = basic_lint(rel_fname, code)
        compile_res = lint_python_compile(fname, code)
        flake_res = self.flake8_lint(rel_fname)

        text = ""
        lines = set()
        for res in [basic_res, compile_res, flake_res]:
            if not res:
                continue
            if text:
                text += "\n"
            text += res.text
            lines.update(res.lines)

        if text or lines:
            return LintResult(text, lines)

    def flake8_lint(self, rel_fname):
        """Run flake8 in-process on ``rel_fname``, selecting only fatal error codes.

        flake8 is executed through ``runpy`` with ``sys.argv`` and
        ``sys.stdout`` temporarily replaced; both are always restored.
        Returns None when flake8 exits with status zero or prints nothing,
        otherwise the result of ``errors_to_lint_result``.
        """
        fatal = "E9,F821,F823,F831,F406,F407,F701,F702,F704,F706"
        flake8 = f"flake8 --select={fatal} --show-source --isolated"

        original_argv = sys.argv
        original_stdout = sys.stdout

        # Keep our own handle on the capture buffer so reading it never
        # depends on what sys.stdout currently points to.
        captured = io.TextIOWrapper(io.BytesIO(), encoding="utf-8")

        sys.argv = flake8.split() + [rel_fname]
        sys.stdout = captured

        text = f"## Running: {' '.join(sys.argv)}\n\n"

        # Stays None if flake8 returns without raising SystemExit.
        errors = None
        try:
            runpy.run_module("flake8", run_name="__main__")
        except SystemExit as e:
            if e.code != 0:
                captured.seek(0)
                errors = captured.read()
        finally:
            sys.stdout = original_stdout
            sys.argv = original_argv

        if not errors:
            return

        text += errors
        return self.errors_to_lint_result(rel_fname, text)
@dataclass
class LintResult:
    """Outcome of a lint pass: the report text and the lines it refers to."""

    # Linter output to show; basic_lint passes an empty string here.
    text: str
    # Zero-based line numbers of interest. Linter.py_lint passes a set rather
    # than a list; tree_context converts whatever it receives with set().
    lines: list
def lint_python_compile(fname, code):
    """Compile ``code`` as a Python module and report any compile-time error.

    Returns None when ``compile()`` succeeds.  Otherwise returns a LintResult
    whose text is the formatted traceback with this module's own frame
    removed, and whose lines are the zero-based line numbers spanned by the
    error.  The line list is empty when the exception carries no line number.
    """
    try:
        compile(code, fname, "exec")  # USE TRACEBACK BELOW HERE
        return
    except Exception as err:
        # Not every exception raised by compile() has a usable lineno or
        # end_lineno: the attribute may be missing or None.
        lineno = getattr(err, "lineno", None)
        if lineno is None:
            line_numbers = []
        else:
            end_lineno = getattr(err, "end_lineno", None)
            if end_lineno is None:
                end_lineno = lineno
            line_numbers = list(range(lineno - 1, end_lineno))

        tb_lines = traceback.format_exception(type(err), err, err.__traceback__)

        # The marker is assembled from two pieces so that the only source line
        # containing the full text is the compile() call above.
        target = "# USE TRACEBACK"
        target += " BELOW HERE"
        last_file_i = next(
            (i for i, tb_line in enumerate(tb_lines) if target in tb_line),
            0,
        )

        # Keep the traceback header, drop everything up to and including the
        # frame for the compile() call.
        tb_lines = tb_lines[:1] + tb_lines[last_file_i + 1 :]

    res = "".join(tb_lines)
    return LintResult(text=res, lines=line_numbers)
def basic_lint(fname, code):
    """
    Parse the code with tree-sitter and report the lines holding syntax errors.

    Returns None when the language cannot be determined from the file name or
    when the parse tree contains no error nodes; otherwise a LintResult with
    empty text and the offending line numbers.
    """
    language = filename_to_lang(fname)
    if language:
        syntax_tree = get_parser(language).parse(bytes(code, "utf-8"))
        error_lines = traverse_tree(syntax_tree.root_node)
        if error_lines:
            return LintResult(text="", lines=error_lines)
    return None
def tree_context(fname, code, line_nums):
    """Render the given lines of ``code`` with surrounding tree context.

    The lines in ``line_nums`` are registered as lines of interest on a
    ``TreeContext``; the formatted result is returned under a short header
    and the file name.
    """
    options = dict(
        color=False,
        line_number=True,
        child_context=False,
        last_line=False,
        margin=0,
        mark_lois=True,
        loi_pad=3,
        show_top_of_file_parent_scope=False,
    )
    ctx = TreeContext(fname, code, **options)

    lois = set(line_nums)
    ctx.add_lines_of_interest(lois)
    ctx.add_context()

    # Pluralise the header only when more than one distinct line is marked.
    plural = "" if len(lois) <= 1 else "s"
    header = f"## See relevant line{plural} below marked with █.\n\n"

    return header + fname + ":\n" + ctx.format()
def traverse_tree(node):
    """Collect the starting line of every error node under ``node``.

    A node counts as an error when its ``type`` is ``"ERROR"`` or its
    ``is_missing`` flag is set.  Lines are taken from ``start_point[0]`` and
    returned in pre-order (a node before its children, children left to
    right).  The walk uses an explicit stack so that deeply nested trees do
    not exceed the interpreter's recursion limit.
    """
    errors = []
    pending = [node]
    while pending:
        current = pending.pop()
        if current.type == "ERROR" or current.is_missing:
            errors.append(current.start_point[0])
        # Push children reversed so the leftmost child is visited next.
        pending.extend(reversed(current.children))
    return errors
def find_filenames_and_linenums(text, fnames):
    """
    Search text for all occurrences of <filename>:\\d+ and make a list of them
    where <filename> is one of the filenames in the list `fnames`.

    Returns a dict mapping each filename that was found to the set of line
    numbers (as ints) that followed it.  Empty filenames are ignored, and an
    empty `fnames` yields an empty dict.
    """
    # An empty alternative would make the pattern match a bare ":<digits>".
    candidates = [fname for fname in fnames if fname]
    if not candidates:
        return {}

    alternatives = "|".join(re.escape(fname) for fname in candidates)
    # A lookbehind is used instead of a leading \b because \b can never match
    # in front of a name that starts with a non-word character (e.g. "../x.py")
    # when it follows whitespace or the start of the text.
    pattern = re.compile(r"(?<!\w)(" + alternatives + r"):(\d+)\b")

    result = {}
    for fname, linenum in pattern.findall(text):
        result.setdefault(fname, set()).add(int(linenum))
    return result
def main():
    """
    Lint every file named on the command line and print any findings.

    Prints a usage message and exits with status 1 when no files are given.
    """
    paths = sys.argv[1:]
    if not paths:
        print("Usage: python linter.py <file1> <file2> ...")
        sys.exit(1)

    # Paths in the reports are shown relative to the current directory.
    linter = Linter(root=os.getcwd())
    for path in paths:
        report = linter.lint(path)
        if report:
            print(report)
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()