Refactored the code to use a new io tool method and removed the tool_error method.

This commit is contained in:
Paul Gauthier 2023-05-12 13:37:22 -07:00
parent 91f23f3700
commit 85515b4788
2 changed files with 27 additions and 21 deletions

View file

@ -63,7 +63,7 @@ class Coder:
else: else:
self.root = os.getcwd() self.root = os.getcwd()
self.console.print(f"Common root directory: {self.root}") self.io.tool(f"Common root directory: {self.root}")
def set_repo(self, cmd_line_fnames): def set_repo(self, cmd_line_fnames):
if not cmd_line_fnames: if not cmd_line_fnames:
@ -74,7 +74,7 @@ class Coder:
repo_paths = [] repo_paths = []
for fname in abs_fnames: for fname in abs_fnames:
if not fname.exists(): if not fname.exists():
self.console.print(f"Creating {fname}") self.io.tool(f"Creating {fname}")
fname.parent.mkdir(parents=True, exist_ok=True) fname.parent.mkdir(parents=True, exist_ok=True)
fname.touch() fname.touch()
try: try:
@ -85,7 +85,7 @@ class Coder:
if fname.is_dir(): if fname.is_dir():
continue continue
self.console.print(f"Loading {fname}") self.io.tool(f"Loading {fname}")
fname = fname.resolve() fname = fname.resolve()
self.abs_fnames.add(str(fname)) self.abs_fnames.add(str(fname))
@ -112,17 +112,17 @@ class Coder:
new_files.append(relative_fname) new_files.append(relative_fname)
if new_files: if new_files:
self.console.print(f"Files not tracked in {repo.git_dir}:") self.io.tool(f"Files not tracked in {repo.git_dir}:")
for fn in new_files: for fn in new_files:
self.console.print(f" {fn}") self.io.tool(f" {fn}")
if self.io.confirm_ask("Add them?"): if self.io.confirm_ask("Add them?"):
for relative_fname in new_files: for relative_fname in new_files:
repo.git.add(relative_fname) repo.git.add(relative_fname)
self.console.print(f"Added {relative_fname} to the git repo") self.io.tool(f"Added {relative_fname} to the git repo")
show_files = ", ".join(new_files) show_files = ", ".join(new_files)
commit_message = f"Initial commit: Added new files to the git repo: {show_files}" commit_message = f"Initial commit: Added new files to the git repo: {show_files}"
repo.git.commit("-m", commit_message, "--no-verify") repo.git.commit("-m", commit_message, "--no-verify")
self.console.print(f"Committed new files with message: {commit_message}") self.io.tool(f"Committed new files with message: {commit_message}")
else: else:
self.io.tool_error("Skipped adding new files to the git repo.") self.io.tool_error("Skipped adding new files to the git repo.")
return return
@ -250,7 +250,7 @@ class Coder:
dict(role="assistant", content=content), dict(role="assistant", content=content),
] ]
self.console.print() self.io.tool()
if interrupted: if interrupted:
return return
@ -317,7 +317,7 @@ class Coder:
return return
for rel_fname in mentioned_rel_fnames: for rel_fname in mentioned_rel_fnames:
self.console.print(f"{rel_fname}") self.io.tool(f"{rel_fname}")
if not self.io.confirm_ask("Add {path} to git?"): if not self.io.confirm_ask("Add {path} to git?"):
return return
@ -414,7 +414,7 @@ class Coder:
edited.add(path) edited.add(path)
if utils.do_replace(full_path, original, updated): if utils.do_replace(full_path, original, updated):
self.console.print(f"Applied edit to {path}") self.io.tool(f"Applied edit to {path}")
else: else:
self.io.tool_error(f"Failed to apply edit to {path}") self.io.tool_error(f"Failed to apply edit to {path}")
@ -445,7 +445,7 @@ class Coder:
commit_message = commit_message.strip().strip('"').strip() commit_message = commit_message.strip().strip('"').strip()
if interrupted: if interrupted:
self.console.print( self.io.tool(
self.io.tool_error("Unable to get commit message from gpt-3.5-turbo. Use /commit to try again.") self.io.tool_error("Unable to get commit message from gpt-3.5-turbo. Use /commit to try again.")
) )
return return
@ -494,7 +494,7 @@ class Coder:
raise ValueError(f"Invalid value for 'which': {which}") raise ValueError(f"Invalid value for 'which': {which}")
if self.show_diffs or ask: if self.show_diffs or ask:
self.console.print(diffs) self.io.tool(diffs)
context = self.get_context_from_history(history) context = self.get_context_from_history(history)
if message: if message:
@ -507,10 +507,10 @@ class Coder:
if ask: if ask:
if which == "repo_files": if which == "repo_files":
self.console.print("Git repo has uncommitted changes.\n") self.io.tool("Git repo has uncommitted changes.\n")
else: else:
self.console.print("Files have uncommitted changes.\n") self.io.tool("Files have uncommitted changes.\n")
self.console.print(f"Suggested commit message:\n{commit_message}\n") self.io.tool(f"Suggested commit message:\n{commit_message}\n")
res = self.io.prompt_ask( res = self.io.prompt_ask(
"Commit before the chat proceeds [y/n/commit message]?", "Commit before the chat proceeds [y/n/commit message]?",
@ -518,7 +518,7 @@ class Coder:
).strip() ).strip()
self.last_asked_for_commit_time = self.get_last_modified() self.last_asked_for_commit_time = self.get_last_modified()
self.console.print() self.io.tool()
if res.lower() in ["n", "no"]: if res.lower() in ["n", "no"]:
self.io.tool_error("Skipped commmit.") self.io.tool_error("Skipped commmit.")
@ -531,7 +531,7 @@ class Coder:
full_commit_message = commit_message + "\n\n" + context full_commit_message = commit_message + "\n\n" + context
repo.git.commit("-m", full_commit_message, "--no-verify") repo.git.commit("-m", full_commit_message, "--no-verify")
commit_hash = repo.head.commit.hexsha[:7] commit_hash = repo.head.commit.hexsha[:7]
self.console.print(f"{commit_hash} {commit_message}") self.io.tool(f"{commit_hash} {commit_message}")
return commit_hash, commit_message return commit_hash, commit_message

View file

@ -61,10 +61,6 @@ class InputOutput:
else: else:
self.console = Console(force_terminal=True, no_color=True) self.console = Console(force_terminal=True, no_color=True)
def tool_error(self, message):
message = Text(message)
self.console.print('[red]', message)
def canned_input(self, show_prompt): def canned_input(self, show_prompt):
console = Console() console = Console()
@ -130,6 +126,8 @@ class InputOutput:
print() print()
return inp return inp
## OUTPUT
def confirm_ask(self, question, default="y"): def confirm_ask(self, question, default="y"):
if self.yes: if self.yes:
return True return True
@ -139,3 +137,11 @@ class InputOutput:
if self.yes: if self.yes:
return True return True
return prompt(question + " ", default=default) return prompt(question + " ", default=default)
def tool_error(self, message):
message = Text(message)
self.console.print(message, style="red")
def tool(self, *messages):
messages = list(map(Text, messages))
self.console.print(*messages)