Refactor filename handling for clarity

This commit is contained in:
Paul Gauthier 2023-05-10 16:14:32 -07:00
parent d50e9755ed
commit 9ad44d5d7a
2 changed files with 34 additions and 32 deletions

View file

@ -25,7 +25,7 @@ openai.api_key = os.getenv("OPENAI_API_KEY")
class Coder: class Coder:
fnames = set() abs_fnames = set()
last_modified = 0 last_modified = 0
repo = None repo = None
@ -54,7 +54,7 @@ class Coder:
else: else:
self.console.print(f"[red]Loading {fname}") self.console.print(f"[red]Loading {fname}")
self.fnames.add(os.path.abspath(str(fname))) self.abs_fnames.add(os.path.abspath(str(fname)))
self.set_repo() self.set_repo()
if not self.repo: if not self.repo:
@ -67,13 +67,13 @@ class Coder:
self.show_diffs = show_diffs self.show_diffs = show_diffs
def find_common_root(self): def find_common_root(self):
common_prefix = os.path.commonpath(list(self.fnames)) common_prefix = os.path.commonpath(list(self.abs_fnames))
self.root = os.path.dirname(common_prefix) self.root = os.path.dirname(common_prefix)
self.console.print(f"[red]Common root directory: {self.root}") self.console.print(f"[red]Common root directory: {self.root}")
def set_repo(self): def set_repo(self):
repo_paths = [] repo_paths = []
for fname in self.fnames: for fname in self.abs_fnames:
try: try:
repo_path = git.Repo(fname, search_parent_directories=True).git_dir repo_path = git.Repo(fname, search_parent_directories=True).git_dir
repo_paths.append(repo_path) repo_paths.append(repo_path)
@ -94,7 +94,7 @@ class Coder:
self.root = repo.working_tree_dir self.root = repo.working_tree_dir
new_files = [] new_files = []
for fname in self.fnames: for fname in self.abs_fnames:
relative_fname = os.path.relpath(fname, repo.working_tree_dir) relative_fname = os.path.relpath(fname, repo.working_tree_dir)
tracked_files = set(repo.git.ls_files().splitlines()) tracked_files = set(repo.git.ls_files().splitlines())
if relative_fname not in tracked_files: if relative_fname not in tracked_files:
@ -124,13 +124,13 @@ class Coder:
def get_files_content(self): def get_files_content(self):
prompt = "" prompt = ""
for fname in self.fnames: for fname in self.abs_fnames:
relative_fname = os.path.relpath(fname, self.root) relative_fname = os.path.relpath(fname, self.root)
prompt += utils.quoted_file(fname, relative_fname) prompt += utils.quoted_file(fname, relative_fname)
return prompt return prompt
def get_last_modified(self): def get_last_modified(self):
return max(Path(fname).stat().st_mtime for fname in self.fnames) return max(Path(fname).stat().st_mtime for fname in self.abs_fnames)
def get_files_messages(self): def get_files_messages(self):
files_content = prompts.files_content_prefix files_content = prompts.files_content_prefix
@ -180,7 +180,7 @@ class Coder:
else: else:
print() print()
inp = get_input(self.history_file, self.fnames, self.commands) inp = get_input(self.history_file, self.abs_fnames, self.commands)
if inp.startswith("/"): if inp.startswith("/"):
self.commands.run(inp) self.commands.run(inp)
@ -355,7 +355,7 @@ class Coder:
full_path = os.path.abspath(os.path.join(self.root, path)) full_path = os.path.abspath(os.path.join(self.root, path))
if full_path not in self.fnames: if full_path not in self.abs_fnames:
if not Path(full_path).exists(): if not Path(full_path).exists():
question = f"[red]Allow creation of new file {path}?" question = f"[red]Allow creation of new file {path}?"
else: else:
@ -367,7 +367,7 @@ class Coder:
continue continue
Path(full_path).touch() Path(full_path).touch()
self.fnames.add(full_path) self.abs_fnames.add(full_path)
if self.repo and Confirm.ask( if self.repo and Confirm.ask(
f"[red]Add {path} to git?", console=self.console, default="y" f"[red]Add {path} to git?", console=self.console, default="y"
@ -393,7 +393,7 @@ class Coder:
diffs = "" diffs = ""
dirty_fnames = [] dirty_fnames = []
relative_dirty_fnames = [] relative_dirty_fnames = []
for fname in self.fnames: for fname in self.abs_fnames:
relative_fname = os.path.relpath(fname, repo.working_tree_dir) relative_fname = os.path.relpath(fname, repo.working_tree_dir)
if self.pretty: if self.pretty:
these_diffs = repo.git.diff("HEAD", "--color", relative_fname) these_diffs = repo.git.diff("HEAD", "--color", relative_fname)
@ -478,10 +478,14 @@ class Coder:
return commit_hash, commit_message return commit_hash, commit_message
def get_active_files(self): def get_inchat_relative_files(self):
if self.repo: files = [os.path.relpath(fname, self.root) for fname in self.abs_fnames]
files = sorted(self.repo.git.ls_files().splitlines()) return sorted(set(files))
else:
files = self.fnames
return files def get_all_relative_files(self):
if self.repo:
files = self.repo.git.ls_files().splitlines()
else:
files = self.get_inchat_relative_files()
return sorted(set(files))

View file

@ -77,7 +77,7 @@ class Commands:
self.coder.repo.git.add( self.coder.repo.git.add(
*[ *[
os.path.relpath(fname, self.coder.repo.working_tree_dir) os.path.relpath(fname, self.coder.repo.working_tree_dir)
for fname in self.coder.fnames for fname in self.coder.abs_fnames
] ]
) )
self.coder.repo.git.commit("-m", commit_message, "--no-verify") self.coder.repo.git.commit("-m", commit_message, "--no-verify")
@ -128,7 +128,7 @@ class Commands:
def cmd_add(self, args): def cmd_add(self, args):
"Add matching files to the chat" "Add matching files to the chat"
files = self.coder.get_active_files() files = self.coder.get_all_relative_files()
for word in args.split(): for word in args.split():
matched_files = [file for file in files if word in file] matched_files = [file for file in files if word in file]
if not matched_files: if not matched_files:
@ -137,22 +137,19 @@ class Commands:
abs_file_path = os.path.abspath( abs_file_path = os.path.abspath(
os.path.join(self.coder.root, matched_file) os.path.join(self.coder.root, matched_file)
) )
if abs_file_path not in self.coder.fnames: if abs_file_path not in self.coder.abs_fnames:
self.coder.fnames.add(abs_file_path) self.coder.abs_fnames.add(abs_file_path)
self.console.print(f"[red]Added {matched_file} to the chat") self.console.print(f"[red]Added {matched_file} to the chat")
else: else:
self.console.print(f"[red]{matched_file} is already in the chat") self.console.print(f"[red]{matched_file} is already in the chat")
def completions_add(self): def completions_add(self):
return self.coder.get_active_files() res = set(self.coder.get_all_relative_files())
res = res - set(self.coder.get_inchat_relative_files())
return res
def completions_drop(self): def completions_drop(self):
active_files = self.coder.get_active_files() return self.coder.get_inchat_relative_files()
return [
os.path.relpath(file, self.coder.root)
for file in self.coder.fnames
if file in active_files
]
def cmd_drop(self, args): def cmd_drop(self, args):
"Remove matching files from the chat" "Remove matching files from the chat"
@ -160,26 +157,27 @@ class Commands:
for word in args.split(): for word in args.split():
matched_files = [ matched_files = [
file file
for file in self.coder.fnames for file in self.coder.abs_fnames
if word in os.path.relpath(file, self.coder.root) if word in os.path.relpath(file, self.coder.root)
] ]
if not matched_files: if not matched_files:
self.console.print(f"[red]No files matched '{word}'") self.console.print(f"[red]No files matched '{word}'")
for matched_file in matched_files: for matched_file in matched_files:
relative_fname = os.path.relpath(matched_file, self.coder.root) relative_fname = os.path.relpath(matched_file, self.coder.root)
self.coder.fnames.remove(matched_file) self.coder.abs_fnames.remove(matched_file)
self.console.print(f"[red]Removed {relative_fname} from the chat") self.console.print(f"[red]Removed {relative_fname} from the chat")
def cmd_ls(self, args): def cmd_ls(self, args):
"List files and show their chat status" "List files and show their chat status"
files = self.coder.get_active_files() files = self.coder.get_all_relative_files()
self.console.print("[red]Files in chat:\n") self.console.print("[red]Files in chat:\n")
other_files = [] other_files = []
for file in files: for file in files:
abs_file_path = os.path.abspath(os.path.join(self.coder.root, file)) abs_file_path = os.path.abspath(os.path.join(self.coder.root, file))
if abs_file_path in self.coder.fnames: if abs_file_path in self.coder.abs_fnames:
self.console.print(f"[red] {file}") self.console.print(f"[red] {file}")
else: else:
other_files.append(file) other_files.append(file)