From 464c3e29e1ce77bc8af7ee4171d068d2dbbd52a9 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Tue, 8 Oct 2024 22:10:58 -0500 Subject: [PATCH] add format_files_for_input and compute_minimal_fileids --- aider/io.py | 61 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/aider/io.py b/aider/io.py index 81829a83b..fca30499c 100644 --- a/aider/io.py +++ b/aider/io.py @@ -5,6 +5,8 @@ from dataclasses import dataclass from datetime import datetime from pathlib import Path +from jinja2.lexer import TOKEN_DOT +from numpy.distutils.misc_util import rel_path from prompt_toolkit.completion import Completer, Completion, ThreadedCompleter from prompt_toolkit.cursor_shapes import ModalCursorShapeConfig from prompt_toolkit.enums import EditingMode @@ -358,7 +360,8 @@ class InputOutput: rel_fnames = list(rel_fnames) show = "" if rel_fnames: - show = " ".join(rel_fnames) + "\n" + rel_read_only_fnames = [os.path.relpath(fname, root) for fname in (abs_read_only_fnames or [])] + show = self.format_files_for_input(rel_fnames, rel_read_only_fnames) if edit_format: show += edit_format show += "> " @@ -695,3 +698,59 @@ class InputOutput: " Permission denied." ) self.chat_history_file = None # Disable further attempts to write + + def format_files_for_input(self, rel_fnames, rel_read_only_fnames): + minimal_unique_fileids = self.compute_minimal_fileids(rel_fnames) + + # Format the filename for display in the prompt, with disambiguating path + # in parentheses, if needed. + def format_minimal_fileid(fname): + pth = Path(minimal_unique_fileids[fname]) + if len(pth.parts) > 1: + return f"{pth.name} ({'/'.join(pth.parts[:-1])})" + else: + return pth.name + + read_only_files = [] + for full_path in (rel_read_only_fnames or []): + name = format_minimal_fileid(full_path) + read_only_files.append(f" R {name}") + + editable_files = [] + for full_path in rel_fnames: + if full_path in rel_read_only_fnames: + continue + name = format_minimal_fileid(full_path) + editable_files.append(f" {name}") + + return "\n".join(read_only_files + editable_files) + '\n' + + def compute_minimal_fileids(self, rel_fnames): + # First pass: group files with the same name + grouped_fnames = defaultdict(list) + for full_path in rel_fnames: + pth = Path(full_path) + fname = pth.name + grouped_fnames[fname].append(list(pth.parts)) + + # Second pass: compute the shared prefix of each group of files. + shared_prefixes = {} + for fname, paths in grouped_fnames.items(): + shared_prefix = [] + while all(len(path) > 1 for path in paths): + next_part = paths[0][0] + if not all(path[0] == next_part for path in paths): + break + shared_prefix.append(next_part) + paths = [path[1:] for path in paths] + shared_prefixes[fname] = Path(*shared_prefix) + + # Third pass: subtract the shared prefix from the full path to get the minimal unique id + minimal_unique_ids = {} + for full_path in rel_fnames: + pth = Path(full_path) + fname = pth.name + prefix = shared_prefixes[fname] + minimal_unique_ids[full_path] = str(pth.relative_to(prefix)) + + return minimal_unique_ids \ No newline at end of file