allow directories to be provided on the command line; use them to find the git repo

This commit is contained in:
Paul Gauthier 2023-05-10 19:27:45 -07:00
parent 35a4a14a6a
commit 775fbe95f9

View file

@ -33,7 +33,7 @@ class Coder:
repo = None repo = None
last_aider_commit_hash = None last_aider_commit_hash = None
def __init__(self, main_model, files, pretty, history_file, show_diffs): def __init__(self, main_model, fnames, pretty, history_file, show_diffs):
self.history_file = history_file self.history_file = history_file
if pretty: if pretty:
@ -48,17 +48,8 @@ class Coder:
f"[red bold]This tool will almost certainly fail to work with {main_model}" f"[red bold]This tool will almost certainly fail to work with {main_model}"
) )
for fname in files: self.set_repo(fnames)
fname = Path(fname)
if not fname.exists():
self.console.print(f"[red]Creating {fname}")
fname.touch()
else:
self.console.print(f"[red]Loading {fname}")
self.abs_fnames.add(os.path.abspath(str(fname)))
self.set_repo()
if not self.repo: if not self.repo:
self.console.print( self.console.print(
"[red bold]No suitable git repo, will not automatically commit edits." "[red bold]No suitable git repo, will not automatically commit edits."
@ -73,9 +64,14 @@ class Coder:
self.root = os.path.dirname(common_prefix) self.root = os.path.dirname(common_prefix)
self.console.print(f"[red]Common root directory: {self.root}") self.console.print(f"[red]Common root directory: {self.root}")
def set_repo(self): def set_repo(self, cmd_line_fnames):
abs_fnames = [Path(fn).resolve() for fn in cmd_line_fnames]
repo_paths = [] repo_paths = []
for fname in self.abs_fnames: for fname in abs_fnames:
if not fname.exists():
self.console.print(f"[red]Creating {fname}")
fname.touch()
try: try:
repo_path = git.Repo(fname, search_parent_directories=True).git_dir repo_path = git.Repo(fname, search_parent_directories=True).git_dir
repo_paths.append(repo_path) repo_paths.append(repo_path)
@ -96,7 +92,14 @@ class Coder:
self.root = repo.working_tree_dir self.root = repo.working_tree_dir
new_files = [] new_files = []
for fname in self.abs_fnames: for fname in abs_fnames:
if fname.is_dir():
continue
self.console.print(f"[red]Loading {fname}")
fname = fname.resolve()
self.abs_fnames.add(str(fname))
relative_fname = os.path.relpath(fname, repo.working_tree_dir) relative_fname = os.path.relpath(fname, repo.working_tree_dir)
tracked_files = set(repo.git.ls_files().splitlines()) tracked_files = set(repo.git.ls_files().splitlines())
if relative_fname not in tracked_files: if relative_fname not in tracked_files:
@ -135,7 +138,9 @@ class Coder:
return prompt return prompt
def get_last_modified(self): def get_last_modified(self):
if self.abs_fnames:
return max(Path(fname).stat().st_mtime for fname in self.abs_fnames) return max(Path(fname).stat().st_mtime for fname in self.abs_fnames)
return 0
def get_files_messages(self): def get_files_messages(self):
files_content = prompts.files_content_prefix files_content = prompts.files_content_prefix