From 775fbe95f9237b7b7fa156ab3067cfbdefe7e713 Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Wed, 10 May 2023 19:27:45 -0700 Subject: [PATCH] allow directories to be provided on the command line; use them to find the git repo --- aider/coder.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/aider/coder.py b/aider/coder.py index c07f7afe9..d8e3f3ba5 100755 --- a/aider/coder.py +++ b/aider/coder.py @@ -33,7 +33,7 @@ class Coder: repo = None last_aider_commit_hash = None - def __init__(self, main_model, files, pretty, history_file, show_diffs): + def __init__(self, main_model, fnames, pretty, history_file, show_diffs): self.history_file = history_file if pretty: @@ -48,17 +48,8 @@ class Coder: f"[red bold]This tool will almost certainly fail to work with {main_model}" ) - for fname in files: - fname = Path(fname) - if not fname.exists(): - self.console.print(f"[red]Creating {fname}") - fname.touch() - else: - self.console.print(f"[red]Loading {fname}") + self.set_repo(fnames) - self.abs_fnames.add(os.path.abspath(str(fname))) - - self.set_repo() if not self.repo: self.console.print( "[red bold]No suitable git repo, will not automatically commit edits." @@ -73,9 +64,14 @@ class Coder: self.root = os.path.dirname(common_prefix) self.console.print(f"[red]Common root directory: {self.root}") - def set_repo(self): + def set_repo(self, cmd_line_fnames): + abs_fnames = [Path(fn).resolve() for fn in cmd_line_fnames] + repo_paths = [] - for fname in self.abs_fnames: + for fname in abs_fnames: + if not fname.exists(): + self.console.print(f"[red]Creating {fname}") + fname.touch() try: repo_path = git.Repo(fname, search_parent_directories=True).git_dir repo_paths.append(repo_path) @@ -96,7 +92,14 @@ class Coder: self.root = repo.working_tree_dir new_files = [] - for fname in self.abs_fnames: + for fname in abs_fnames: + if fname.is_dir(): + continue + self.console.print(f"[red]Loading {fname}") + + fname = fname.resolve() + self.abs_fnames.add(str(fname)) + relative_fname = os.path.relpath(fname, repo.working_tree_dir) tracked_files = set(repo.git.ls_files().splitlines()) if relative_fname not in tracked_files: @@ -135,7 +138,9 @@ class Coder: return prompt def get_last_modified(self): - return max(Path(fname).stat().st_mtime for fname in self.abs_fnames) + if self.abs_fnames: + return max(Path(fname).stat().st_mtime for fname in self.abs_fnames) + return 0 def get_files_messages(self): files_content = prompts.files_content_prefix