Refactor the cmd_ funcs to Commands

This commit is contained in:
Paul Gauthier 2023-05-10 15:39:39 -07:00
parent 9bd635a7a4
commit 409cf8d93b
2 changed files with 122 additions and 100 deletions

View file

@ -442,10 +442,11 @@ class Coder:
commit_message = commit_message.strip().strip('"').strip()
if interrupted:
self.console.print("[red]Unable to get commit message from gpt-3.5-turbo. Use /commit to try again.\n")
self.console.print(
"[red]Unable to get commit message from gpt-3.5-turbo. Use /commit to try again.\n"
)
return
if prefix:
commit_message = prefix + commit_message
@ -486,96 +487,3 @@ class Coder:
files = self.fnames
return files
def cmd_commit(self, args):
"Commit outstanding changes to the chat files. Aider will provide a commit message if yo u don't."
if not self.repo:
self.console.print("[red]No git repository found.")
return
if not self.repo.is_dirty():
self.console.print("[red]No changes to commit.")
return
commit_message = args.strip()
if commit_message:
self.repo.git.add(*[os.path.relpath(fname, self.repo.working_tree_dir) for fname in self.fnames])
self.repo.git.commit("-m", commit_message, "--no-verify")
commit_hash = self.repo.head.commit.hexsha[:7]
self.console.print(f"[red]{commit_hash} {commit_message}")
return
self.commit()
def cmd_undo(self, args):
"Undo the last git commit if it was done by aider"
if not self.repo:
self.console.print("[red]No git repository found.")
return
last_commit = self.repo.head.commit
if not last_commit.message.startswith("aider:") or last_commit.hexsha[:7] != self.last_aider_commit_hash:
self.console.print("[red]The last commit was not made by aider in this chat session.")
return
self.repo.git.reset("--hard", "HEAD~1")
self.console.print(f"[red]Undid the last commit: {last_commit.message.strip()}")
def cmd_diff(self, args):
"Display the diff of the last aider commit"
if not self.repo:
self.console.print("[red]No git repository found.")
return
if not self.last_aider_commit_hash:
self.console.print("[red]No previous aider commit found.")
return
commits = f"{self.last_aider_commit_hash}~1"
if self.pretty:
diff = self.repo.git.diff(commits, "--color", self.last_aider_commit_hash)
else:
diff = self.repo.git.diff(commits, self.last_aider_commit_hash)
self.console.print(Text(diff))
def cmd_add(self, args):
"Add matching files to the chat"
files = self.get_active_files()
for word in args.split():
matched_files = [file for file in files if word in file]
if not matched_files:
self.console.print(f"[red]No files matched '{word}'")
for matched_file in matched_files:
abs_file_path = os.path.abspath(os.path.join(self.root, matched_file))
if abs_file_path not in self.fnames:
self.fnames.add(abs_file_path)
self.console.print(f"[red]Added {matched_file} to the chat")
else:
self.console.print(f"[red]{matched_file} is already in the chat")
def cmd_drop(self, args):
"Remove matching files from the chat"
for word in args.split():
matched_files = [file for file in self.fnames if word in os.path.relpath(file, self.root)]
if not matched_files:
self.console.print(f"[red]No files matched '{word}'")
for matched_file in matched_files:
relative_fname = os.path.relpath(matched_file, self.root)
self.fnames.remove(matched_file)
self.console.print(f"[red]Removed {relative_fname} from the chat")
def cmd_ls(self, args):
"List files and show their chat status"
self.console.print(f"* denotes files included in the chat\n")
files = self.get_active_files()
for file in files:
abs_file_path = os.path.abspath(os.path.join(self.root, file))
if abs_file_path in self.fnames:
self.console.print(f"* {file}")
else:
self.console.print(f" {file}")