Display syntax errors with tree context

This commit is contained in:
Paul Gauthier 2024-05-17 14:22:03 -07:00
parent 86fdeb0597
commit cb8a487c89

View file

@ -2,46 +2,61 @@ import os
import tree_sitter import tree_sitter
import sys import sys
import warnings import warnings
from pathlib import Path
from aider.dump import dump
from grep_ast import TreeContext, filename_to_lang
# tree_sitter is throwing a FutureWarning # tree_sitter is throwing a FutureWarning
warnings.simplefilter("ignore", category=FutureWarning) warnings.simplefilter("ignore", category=FutureWarning)
from tree_sitter_languages import get_language, get_parser # noqa: E402 from tree_sitter_languages import get_language, get_parser # noqa: E402
def parse_file_for_errors(file_path): def basic_lint(fname, code):
lang = "python" lang = filename_to_lang(fname)
language = get_language(lang) language = get_language(lang)
parser = get_parser(lang) parser = get_parser(lang)
# Read the file content tree = parser.parse(bytes(code, "utf-8"))
with open(file_path, 'r') as file:
content = file.read()
tree = parser.parse(bytes(content, "utf8")) errors = traverse_tree(tree.root_node)
if not errors:
return
# Traverse the tree to find errors and print context context = TreeContext(
def traverse_tree(node): fname,
if node.type == 'ERROR' or node.is_missing: code,
error_type = 'Syntax error' if node.type == 'ERROR' else 'Missing element' color=False,
start_line = max(0, node.start_point[0] - 3) line_number=False,
end_line = node.end_point[0] + 3 child_context=False,
error_line = node.start_point[0] + 1 last_line=False,
margin=0,
mark_lois=True,
loi_pad=5,
# header_max=30,
show_top_of_file_parent_scope=False,
)
context.add_lines_of_interest(errors)
context.add_context()
output = "# Syntax Errors found on the lines marked with █\n"
output += fname + ":\n"
output += context.format()
with open(file_path, 'r') as file: return output
lines = file.readlines()
print(f"{error_type} at line: {error_line}") # Traverse the tree to find errors and print context
print("Context:") def traverse_tree(node):
for i in range(start_line, min(end_line, len(lines))): errors = []
line_number = i + 1 if node.type == 'ERROR' or node.is_missing:
prefix = ">> " if line_number == error_line else " " line_no = node.start_point[0]
print(f"{prefix}{line_number}: {lines[i].rstrip()}") errors.append(line_no)
print("\n")
for child in node.children:
traverse_tree(child)
traverse_tree(tree.root_node) for child in node.children:
errors += traverse_tree(child)
return errors
def main(): def main():
""" """
@ -52,8 +67,10 @@ def main():
sys.exit(1) sys.exit(1)
for file_path in sys.argv[1:]: for file_path in sys.argv[1:]:
print(f"Checking file: {file_path}") code = Path(file_path).read_text()
parse_file_for_errors(file_path) errors = basic_lint(file_path, code)
if errors:
print(errors)
if __name__ == "__main__": if __name__ == "__main__":
main() main()