Refactored the code to use a new io tool method and removed the tool_error method.

This commit is contained in:
Paul Gauthier 2023-05-12 13:37:22 -07:00
parent 91f23f3700
commit 85515b4788
2 changed files with 27 additions and 21 deletions

View file

@ -63,7 +63,7 @@ class Coder:
else:
self.root = os.getcwd()
self.console.print(f"Common root directory: {self.root}")
self.io.tool(f"Common root directory: {self.root}")
def set_repo(self, cmd_line_fnames):
if not cmd_line_fnames:
@ -74,7 +74,7 @@ class Coder:
repo_paths = []
for fname in abs_fnames:
if not fname.exists():
self.console.print(f"Creating {fname}")
self.io.tool(f"Creating {fname}")
fname.parent.mkdir(parents=True, exist_ok=True)
fname.touch()
try:
@ -85,7 +85,7 @@ class Coder:
if fname.is_dir():
continue
self.console.print(f"Loading {fname}")
self.io.tool(f"Loading {fname}")
fname = fname.resolve()
self.abs_fnames.add(str(fname))
@ -112,17 +112,17 @@ class Coder:
new_files.append(relative_fname)
if new_files:
self.console.print(f"Files not tracked in {repo.git_dir}:")
self.io.tool(f"Files not tracked in {repo.git_dir}:")
for fn in new_files:
self.console.print(f" {fn}")
self.io.tool(f" {fn}")
if self.io.confirm_ask("Add them?"):
for relative_fname in new_files:
repo.git.add(relative_fname)
self.console.print(f"Added {relative_fname} to the git repo")
self.io.tool(f"Added {relative_fname} to the git repo")
show_files = ", ".join(new_files)
commit_message = f"Initial commit: Added new files to the git repo: {show_files}"
repo.git.commit("-m", commit_message, "--no-verify")
self.console.print(f"Committed new files with message: {commit_message}")
self.io.tool(f"Committed new files with message: {commit_message}")
else:
self.io.tool_error("Skipped adding new files to the git repo.")
return
@ -250,7 +250,7 @@ class Coder:
dict(role="assistant", content=content),
]
self.console.print()
self.io.tool()
if interrupted:
return
@ -317,7 +317,7 @@ class Coder:
return
for rel_fname in mentioned_rel_fnames:
self.console.print(f"{rel_fname}")
self.io.tool(f"{rel_fname}")
if not self.io.confirm_ask("Add {path} to git?"):
return
@ -414,7 +414,7 @@ class Coder:
edited.add(path)
if utils.do_replace(full_path, original, updated):
self.console.print(f"Applied edit to {path}")
self.io.tool(f"Applied edit to {path}")
else:
self.io.tool_error(f"Failed to apply edit to {path}")
@ -445,7 +445,7 @@ class Coder:
commit_message = commit_message.strip().strip('"').strip()
if interrupted:
self.console.print(
self.io.tool(
self.io.tool_error("Unable to get commit message from gpt-3.5-turbo. Use /commit to try again.")
)
return
@ -494,7 +494,7 @@ class Coder:
raise ValueError(f"Invalid value for 'which': {which}")
if self.show_diffs or ask:
self.console.print(diffs)
self.io.tool(diffs)
context = self.get_context_from_history(history)
if message:
@ -507,10 +507,10 @@ class Coder:
if ask:
if which == "repo_files":
self.console.print("Git repo has uncommitted changes.\n")
self.io.tool("Git repo has uncommitted changes.\n")
else:
self.console.print("Files have uncommitted changes.\n")
self.console.print(f"Suggested commit message:\n{commit_message}\n")
self.io.tool("Files have uncommitted changes.\n")
self.io.tool(f"Suggested commit message:\n{commit_message}\n")
res = self.io.prompt_ask(
"Commit before the chat proceeds [y/n/commit message]?",
@ -518,7 +518,7 @@ class Coder:
).strip()
self.last_asked_for_commit_time = self.get_last_modified()
self.console.print()
self.io.tool()
if res.lower() in ["n", "no"]:
self.io.tool_error("Skipped commmit.")
@ -531,7 +531,7 @@ class Coder:
full_commit_message = commit_message + "\n\n" + context
repo.git.commit("-m", full_commit_message, "--no-verify")
commit_hash = repo.head.commit.hexsha[:7]
self.console.print(f"{commit_hash} {commit_message}")
self.io.tool(f"{commit_hash} {commit_message}")
return commit_hash, commit_message

View file

@ -61,10 +61,6 @@ class InputOutput:
else:
self.console = Console(force_terminal=True, no_color=True)
def tool_error(self, message):
message = Text(message)
self.console.print('[red]', message)
def canned_input(self, show_prompt):
console = Console()
@ -130,6 +126,8 @@ class InputOutput:
print()
return inp
## OUTPUT
def confirm_ask(self, question, default="y"):
if self.yes:
return True
@ -139,3 +137,11 @@ class InputOutput:
if self.yes:
return True
return prompt(question + " ", default=default)
def tool_error(self, message):
message = Text(message)
self.console.print(message, style="red")
def tool(self, *messages):
messages = list(map(Text, messages))
self.console.print(*messages)