add format_files_for_input and compute_minimal_fileids

This commit is contained in:
Jonathan Ellis 2024-10-08 22:10:58 -05:00
parent 0fe5247d4c
commit 464c3e29e1

View file

@ -5,6 +5,8 @@ from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from jinja2.lexer import TOKEN_DOT
from numpy.distutils.misc_util import rel_path
from prompt_toolkit.completion import Completer, Completion, ThreadedCompleter
from prompt_toolkit.cursor_shapes import ModalCursorShapeConfig
from prompt_toolkit.enums import EditingMode
@ -358,7 +360,8 @@ class InputOutput:
rel_fnames = list(rel_fnames)
show = ""
if rel_fnames:
show = " ".join(rel_fnames) + "\n"
rel_read_only_fnames = [os.path.relpath(fname, root) for fname in (abs_read_only_fnames or [])]
show = self.format_files_for_input(rel_fnames, rel_read_only_fnames)
if edit_format:
show += edit_format
show += "> "
@ -695,3 +698,59 @@ class InputOutput:
" Permission denied."
)
self.chat_history_file = None # Disable further attempts to write
def format_files_for_input(self, rel_fnames, rel_read_only_fnames):
minimal_unique_fileids = self.compute_minimal_fileids(rel_fnames)
# Format the filename for display in the prompt, with disambiguating path
# in parentheses, if needed.
def format_minimal_fileid(fname):
pth = Path(minimal_unique_fileids[fname])
if len(pth.parts) > 1:
return f"{pth.name} ({'/'.join(pth.parts[:-1])})"
else:
return pth.name
read_only_files = []
for full_path in (rel_read_only_fnames or []):
name = format_minimal_fileid(full_path)
read_only_files.append(f" R {name}")
editable_files = []
for full_path in rel_fnames:
if full_path in rel_read_only_fnames:
continue
name = format_minimal_fileid(full_path)
editable_files.append(f" {name}")
return "\n".join(read_only_files + editable_files) + '\n'
def compute_minimal_fileids(self, rel_fnames):
# First pass: group files with the same name
grouped_fnames = defaultdict(list)
for full_path in rel_fnames:
pth = Path(full_path)
fname = pth.name
grouped_fnames[fname].append(list(pth.parts))
# Second pass: compute the shared prefix of each group of files.
shared_prefixes = {}
for fname, paths in grouped_fnames.items():
shared_prefix = []
while all(len(path) > 1 for path in paths):
next_part = paths[0][0]
if not all(path[0] == next_part for path in paths):
break
shared_prefix.append(next_part)
paths = [path[1:] for path in paths]
shared_prefixes[fname] = Path(*shared_prefix)
# Third pass: subtract the shared prefix from the full path to get the minimal unique id
minimal_unique_ids = {}
for full_path in rel_fnames:
pth = Path(full_path)
fname = pth.name
prefix = shared_prefixes[fname]
minimal_unique_ids[full_path] = str(pth.relative_to(prefix))
return minimal_unique_ids