allow directories to be provided on the command line; use them to find the git repo

This commit is contained in:
Paul Gauthier 2023-05-10 19:27:45 -07:00
parent 35a4a14a6a
commit 775fbe95f9

View file

@ -33,7 +33,7 @@ class Coder:
repo = None
last_aider_commit_hash = None
def __init__(self, main_model, files, pretty, history_file, show_diffs):
def __init__(self, main_model, fnames, pretty, history_file, show_diffs):
self.history_file = history_file
if pretty:
@ -48,17 +48,8 @@ class Coder:
f"[red bold]This tool will almost certainly fail to work with {main_model}"
)
for fname in files:
fname = Path(fname)
if not fname.exists():
self.console.print(f"[red]Creating {fname}")
fname.touch()
else:
self.console.print(f"[red]Loading {fname}")
self.set_repo(fnames)
self.abs_fnames.add(os.path.abspath(str(fname)))
self.set_repo()
if not self.repo:
self.console.print(
"[red bold]No suitable git repo, will not automatically commit edits."
@ -73,9 +64,14 @@ class Coder:
self.root = os.path.dirname(common_prefix)
self.console.print(f"[red]Common root directory: {self.root}")
def set_repo(self):
def set_repo(self, cmd_line_fnames):
abs_fnames = [Path(fn).resolve() for fn in cmd_line_fnames]
repo_paths = []
for fname in self.abs_fnames:
for fname in abs_fnames:
if not fname.exists():
self.console.print(f"[red]Creating {fname}")
fname.touch()
try:
repo_path = git.Repo(fname, search_parent_directories=True).git_dir
repo_paths.append(repo_path)
@ -96,7 +92,14 @@ class Coder:
self.root = repo.working_tree_dir
new_files = []
for fname in self.abs_fnames:
for fname in abs_fnames:
if fname.is_dir():
continue
self.console.print(f"[red]Loading {fname}")
fname = fname.resolve()
self.abs_fnames.add(str(fname))
relative_fname = os.path.relpath(fname, repo.working_tree_dir)
tracked_files = set(repo.git.ls_files().splitlines())
if relative_fname not in tracked_files:
@ -135,7 +138,9 @@ class Coder:
return prompt
def get_last_modified(self):
return max(Path(fname).stat().st_mtime for fname in self.abs_fnames)
if self.abs_fnames:
return max(Path(fname).stat().st_mtime for fname in self.abs_fnames)
return 0
def get_files_messages(self):
files_content = prompts.files_content_prefix