diff --git a/aider/coder.py b/aider/coder.py index 846d366e6..8ca5a34b7 100755 --- a/aider/coder.py +++ b/aider/coder.py @@ -33,7 +33,8 @@ class Coder: def __init__(self, main_model, fnames, pretty, history_file, show_diffs, auto_commits, yes): self.abs_fnames = set() - self.yes = yes + self.input = getinput.Input(yes) + self.history_file = history_file self.auto_commits = auto_commits @@ -117,7 +118,7 @@ class Coder: self.console.print(f"Files not tracked in {repo.git_dir}:") for fn in new_files: self.console.print(f" {fn}") - if getinput.confirm_ask("Add them?", default="y"): + if self.input.confirm_ask("Add them?"): for relative_fname in new_files: repo.git.add(relative_fname) self.console.print(f"Added {relative_fname} to the git repo") @@ -203,7 +204,7 @@ class Coder: else: print() - inp = getinput.get_input(self.history_file, self.abs_fnames, self.commands) + inp = self.input.get_input(self.history_file, self.abs_fnames, self.commands) self.num_control_c = 0 @@ -321,7 +322,7 @@ class Coder: for rel_fname in mentioned_rel_fnames: self.console.print(f"{rel_fname}") - if not getinput.confirm_ask("Add {path} to git?", default="y"): + if not self.input.confirm_ask("Add {path} to git?"): return for rel_fname in mentioned_rel_fnames: @@ -403,7 +404,7 @@ class Coder: question = ( f"Allow edits to {path} which was not previously provided?" # noqa: E501 ) - if not getinput.confirm_ask(question, default="y"): + if not self.input.confirm_ask(question): self.console.print(f"[red]Skipping edit to {path}") continue @@ -411,10 +412,7 @@ class Coder: Path(full_path).touch() self.abs_fnames.add(full_path) - if self.repo and getinput.confirm_ask( - f"Add {path} to git?", - default="y", - ): + if self.repo and self.input.confirm_ask(f"Add {path} to git?"): self.repo.git.add(full_path) edited.add(path) @@ -517,7 +515,7 @@ class Coder: self.console.print("Files have uncommitted changes.\n") self.console.print(f"Suggested commit message:\n{commit_message}\n") - res = getinput.prompt_ask( + res = self.input.prompt_ask( "Commit before the chat proceeds [y/n/commit message]?", default=commit_message, ).strip() diff --git a/aider/getinput.py b/aider/getinput.py index 8754530ae..c0ead35c6 100644 --- a/aider/getinput.py +++ b/aider/getinput.py @@ -50,80 +50,81 @@ class FileContentCompleter(Completer): yield Completion(word, start_position=-len(last_word)) -def canned_input(show_prompt): - console = Console() +class Input: + def __init__(self, yes): + self.yes = yes - input_line = input() + def canned_input(self, show_prompt): + console = Console() - console.print(show_prompt, end="", style="green") - for char in input_line: - console.print(char, end="", style="green") - time.sleep(random.uniform(0.01, 0.15)) - console.print() - console.print() - return input_line + input_line = input() + console.print(show_prompt, end="", style="green") + for char in input_line: + console.print(char, end="", style="green") + time.sleep(random.uniform(0.01, 0.15)) + console.print() + console.print() + return input_line -def get_input(history_file, fnames, commands): - fnames = list(fnames) - if len(fnames) > 1: - common_prefix = os.path.commonpath(fnames) - if not common_prefix.endswith(os.path.sep): - common_prefix += os.path.sep - short_fnames = [fname.replace(common_prefix, "", 1) for fname in fnames] - elif len(fnames): - short_fnames = [os.path.basename(fnames[0])] - else: - short_fnames = [] - - show = " ".join(short_fnames) - if len(show) > 10: - show += "\n" - show += "> " - - if not sys.stdin.isatty(): - return canned_input(show) - - inp = "" - multiline_input = False - - style = Style.from_dict({"": "green"}) - - while True: - completer_instance = FileContentCompleter(fnames, commands) - if multiline_input: - show = ". " - - line = prompt( - show, - completer=completer_instance, - history=FileHistory(history_file), - style=style, - reserve_space_for_menu=4, - complete_style=CompleteStyle.MULTI_COLUMN, - ) - if line.strip() == "{" and not multiline_input: - multiline_input = True - continue - elif line.strip() == "}" and multiline_input: - break - elif multiline_input: - inp += line + "\n" + def get_input(self, history_file, fnames, commands): + fnames = list(fnames) + if len(fnames) > 1: + common_prefix = os.path.commonpath(fnames) + if not common_prefix.endswith(os.path.sep): + common_prefix += os.path.sep + short_fnames = [fname.replace(common_prefix, "", 1) for fname in fnames] + elif len(fnames): + short_fnames = [os.path.basename(fnames[0])] else: - inp = line - break + short_fnames = [] - print() - return inp + show = " ".join(short_fnames) + if len(show) > 10: + show += "\n" + show += "> " + if not sys.stdin.isatty(): + return self.canned_input(show) -def confirm_ask(question, default=None, yes=False): - if yes: - return True - return prompt(question + " ", default=default) + inp = "" + multiline_input = False + style = Style.from_dict({"": "green"}) -def prompt_ask(question, default=None, yes=False): - if yes: - return True - return prompt(question + " ", default=default) + while True: + completer_instance = FileContentCompleter(fnames, commands) + if multiline_input: + show = ". " + + line = prompt( + show, + completer=completer_instance, + history=FileHistory(history_file), + style=style, + reserve_space_for_menu=4, + complete_style=CompleteStyle.MULTI_COLUMN, + ) + if line.strip() == "{" and not multiline_input: + multiline_input = True + continue + elif line.strip() == "}" and multiline_input: + break + elif multiline_input: + inp += line + "\n" + else: + inp = line + break + + print() + return inp + + def confirm_ask(self, question, default="y"): + if self.yes: + return True + return prompt(question + " ", default=default) + + def prompt_ask(self, question, default=None): + if self.yes: + return True + return prompt(question + " ", default=default)