diff --git a/aider/linter.py b/aider/linter.py index 8fc19a79b..e463a77a6 100644 --- a/aider/linter.py +++ b/aider/linter.py @@ -2,46 +2,61 @@ import os import tree_sitter import sys import warnings +from pathlib import Path + +from aider.dump import dump + +from grep_ast import TreeContext, filename_to_lang # tree_sitter is throwing a FutureWarning warnings.simplefilter("ignore", category=FutureWarning) from tree_sitter_languages import get_language, get_parser # noqa: E402 -def parse_file_for_errors(file_path): +def basic_lint(fname, code): - lang = "python" + lang = filename_to_lang(fname) language = get_language(lang) parser = get_parser(lang) - # Read the file content - with open(file_path, 'r') as file: - content = file.read() + tree = parser.parse(bytes(code, "utf-8")) - tree = parser.parse(bytes(content, "utf8")) + errors = traverse_tree(tree.root_node) + if not errors: + return - # Traverse the tree to find errors and print context - def traverse_tree(node): - if node.type == 'ERROR' or node.is_missing: - error_type = 'Syntax error' if node.type == 'ERROR' else 'Missing element' - start_line = max(0, node.start_point[0] - 3) - end_line = node.end_point[0] + 3 - error_line = node.start_point[0] + 1 + context = TreeContext( + fname, + code, + color=False, + line_number=False, + child_context=False, + last_line=False, + margin=0, + mark_lois=True, + loi_pad=5, + # header_max=30, + show_top_of_file_parent_scope=False, + ) + context.add_lines_of_interest(errors) + context.add_context() + output = "# Syntax Errors found on the lines marked with █\n" + output += fname + ":\n" + output += context.format() - with open(file_path, 'r') as file: - lines = file.readlines() + return output - print(f"{error_type} at line: {error_line}") - print("Context:") - for i in range(start_line, min(end_line, len(lines))): - line_number = i + 1 - prefix = ">> " if line_number == error_line else " " - print(f"{prefix}{line_number}: {lines[i].rstrip()}") - print("\n") - for child in node.children: - traverse_tree(child) +# Traverse the tree to find errors and print context +def traverse_tree(node): + errors = [] + if node.type == 'ERROR' or node.is_missing: + line_no = node.start_point[0] + errors.append(line_no) - traverse_tree(tree.root_node) + for child in node.children: + errors += traverse_tree(child) + + return errors def main(): """ @@ -52,8 +67,10 @@ def main(): sys.exit(1) for file_path in sys.argv[1:]: - print(f"Checking file: {file_path}") - parse_file_for_errors(file_path) + code = Path(file_path).read_text() + errors = basic_lint(file_path, code) + if errors: + print(errors) if __name__ == "__main__": main()