wip: Refactor Input class to InputOutput and add tool_error method.

This commit is contained in:
Paul Gauthier 2023-05-12 13:17:37 -07:00
parent 57f1b3139a
commit 1d7fffe8ab
2 changed files with 20 additions and 9 deletions

View file

@ -32,7 +32,7 @@ class Coder:
def __init__(self, main_model, fnames, pretty, history_file, show_diffs, auto_commits, yes): def __init__(self, main_model, fnames, pretty, history_file, show_diffs, auto_commits, yes):
self.abs_fnames = set() self.abs_fnames = set()
self.input = getinput.Input(yes) self.io = getinput.InputOutput(pretty, yes)
self.history_file = history_file self.history_file = history_file
self.auto_commits = auto_commits self.auto_commits = auto_commits
@ -115,7 +115,7 @@ class Coder:
self.console.print(f"Files not tracked in {repo.git_dir}:") self.console.print(f"Files not tracked in {repo.git_dir}:")
for fn in new_files: for fn in new_files:
self.console.print(f" {fn}") self.console.print(f" {fn}")
if self.input.confirm_ask("Add them?"): if self.io.confirm_ask("Add them?"):
for relative_fname in new_files: for relative_fname in new_files:
repo.git.add(relative_fname) repo.git.add(relative_fname)
self.console.print(f"Added {relative_fname} to the git repo") self.console.print(f"Added {relative_fname} to the git repo")
@ -201,7 +201,7 @@ class Coder:
else: else:
print() print()
inp = self.input.get_input(self.history_file, self.abs_fnames, self.commands) inp = self.io.get_input(self.history_file, self.abs_fnames, self.commands)
self.num_control_c = 0 self.num_control_c = 0
@ -319,7 +319,7 @@ class Coder:
for rel_fname in mentioned_rel_fnames: for rel_fname in mentioned_rel_fnames:
self.console.print(f"{rel_fname}") self.console.print(f"{rel_fname}")
if not self.input.confirm_ask("Add {path} to git?"): if not self.io.confirm_ask("Add {path} to git?"):
return return
for rel_fname in mentioned_rel_fnames: for rel_fname in mentioned_rel_fnames:
@ -401,7 +401,7 @@ class Coder:
question = ( question = (
f"Allow edits to {path} which was not previously provided?" # noqa: E501 f"Allow edits to {path} which was not previously provided?" # noqa: E501
) )
if not self.input.confirm_ask(question): if not self.io.confirm_ask(question):
self.console.print(f"[red]Skipping edit to {path}") self.console.print(f"[red]Skipping edit to {path}")
continue continue
@ -409,7 +409,7 @@ class Coder:
Path(full_path).touch() Path(full_path).touch()
self.abs_fnames.add(full_path) self.abs_fnames.add(full_path)
if self.repo and self.input.confirm_ask(f"Add {path} to git?"): if self.repo and self.io.confirm_ask(f"Add {path} to git?"):
self.repo.git.add(full_path) self.repo.git.add(full_path)
edited.add(path) edited.add(path)
@ -512,7 +512,7 @@ class Coder:
self.console.print("Files have uncommitted changes.\n") self.console.print("Files have uncommitted changes.\n")
self.console.print(f"Suggested commit message:\n{commit_message}\n") self.console.print(f"Suggested commit message:\n{commit_message}\n")
res = self.input.prompt_ask( res = self.io.prompt_ask(
"Commit before the chat proceeds [y/n/commit message]?", "Commit before the chat proceeds [y/n/commit message]?",
default=commit_message, default=commit_message,
).strip() ).strip()

View file

@ -8,6 +8,7 @@ from prompt_toolkit.completion import Completer, Completion
from prompt_toolkit.history import FileHistory from prompt_toolkit.history import FileHistory
from prompt_toolkit.shortcuts import CompleteStyle from prompt_toolkit.shortcuts import CompleteStyle
from rich.console import Console from rich.console import Console
from rich.text import Text
import sys import sys
import time import time
import random import random
@ -50,10 +51,20 @@ class FileContentCompleter(Completer):
yield Completion(word, start_position=-len(last_word)) yield Completion(word, start_position=-len(last_word))
class Input: class InputOutput:
def __init__(self, yes): def __init__(self, pretty, yes):
self.pretty = pretty
self.yes = yes self.yes = yes
if pretty:
self.console = Console()
else:
self.console = Console(force_terminal=True, no_color=True)
def tool_error(self, message):
message = Text(message)
self.console.print('[red]', message)
def canned_input(self, show_prompt): def canned_input(self, show_prompt):
console = Console() console = Console()