Display syntax errors with tree context

This commit is contained in:
Paul Gauthier 2024-05-17 14:22:03 -07:00
parent 86fdeb0597
commit cb8a487c89

View file

@ -2,46 +2,61 @@ import os
import tree_sitter
import sys
import warnings
from pathlib import Path
from aider.dump import dump
from grep_ast import TreeContext, filename_to_lang
# tree_sitter is throwing a FutureWarning
warnings.simplefilter("ignore", category=FutureWarning)
from tree_sitter_languages import get_language, get_parser # noqa: E402
def parse_file_for_errors(file_path):
def basic_lint(fname, code):
lang = "python"
lang = filename_to_lang(fname)
language = get_language(lang)
parser = get_parser(lang)
# Read the file content
with open(file_path, 'r') as file:
content = file.read()
tree = parser.parse(bytes(code, "utf-8"))
tree = parser.parse(bytes(content, "utf8"))
errors = traverse_tree(tree.root_node)
if not errors:
return
# Traverse the tree to find errors and print context
def traverse_tree(node):
if node.type == 'ERROR' or node.is_missing:
error_type = 'Syntax error' if node.type == 'ERROR' else 'Missing element'
start_line = max(0, node.start_point[0] - 3)
end_line = node.end_point[0] + 3
error_line = node.start_point[0] + 1
context = TreeContext(
fname,
code,
color=False,
line_number=False,
child_context=False,
last_line=False,
margin=0,
mark_lois=True,
loi_pad=5,
# header_max=30,
show_top_of_file_parent_scope=False,
)
context.add_lines_of_interest(errors)
context.add_context()
output = "# Syntax Errors found on the lines marked with █\n"
output += fname + ":\n"
output += context.format()
with open(file_path, 'r') as file:
lines = file.readlines()
return output
print(f"{error_type} at line: {error_line}")
print("Context:")
for i in range(start_line, min(end_line, len(lines))):
line_number = i + 1
prefix = ">> " if line_number == error_line else " "
print(f"{prefix}{line_number}: {lines[i].rstrip()}")
print("\n")
for child in node.children:
traverse_tree(child)
# Traverse the tree to find errors and print context
def traverse_tree(node):
errors = []
if node.type == 'ERROR' or node.is_missing:
line_no = node.start_point[0]
errors.append(line_no)
traverse_tree(tree.root_node)
for child in node.children:
errors += traverse_tree(child)
return errors
def main():
"""
@ -52,8 +67,10 @@ def main():
sys.exit(1)
for file_path in sys.argv[1:]:
print(f"Checking file: {file_path}")
parse_file_for_errors(file_path)
code = Path(file_path).read_text()
errors = basic_lint(file_path, code)
if errors:
print(errors)
if __name__ == "__main__":
main()