# aider/aider/io.py
# 2024-10-09 14:37:22 -05:00
#
# 754 lines
# No EOL
# 25 KiB
# Python

import base64
import os
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from prompt_toolkit.completion import Completer, Completion, ThreadedCompleter
from prompt_toolkit.cursor_shapes import ModalCursorShapeConfig
from prompt_toolkit.enums import EditingMode
from prompt_toolkit.history import FileHistory
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.lexers import PygmentsLexer
from prompt_toolkit.shortcuts import CompleteStyle, PromptSession
from prompt_toolkit.styles import Style
from pygments.lexers import MarkdownLexer, guess_lexer_for_filename
from pygments.token import Token
from rich.console import Console
from rich.markdown import Markdown
from rich.style import Style as RichStyle
from rich.text import Text
from aider.mdstream import MarkdownStream
from .dump import dump # noqa: F401
from .utils import is_image_file
@dataclass
class ConfirmGroup:
    """Shared state for a batch of related confirmation prompts."""

    # Answer remembered for the rest of the batch; None until one is recorded.
    preference: str = None
    # Whether group-wide choices should be offered for this batch.
    show_group: bool = True

    def __init__(self, items=None):
        """Offer group-wide choices only when *items* holds more than one entry.

        When *items* is omitted the class-level defaults stay in effect.
        """
        if items is None:
            return
        item_count = len(items)
        self.show_group = item_count > 1
class AutoCompleter(Completer):
    """Completer for the chat prompt: file names, source identifiers and slash-commands.

    Candidate words start out as the known relative file names. On first use,
    the files in the chat (and any read-only files) are lexed with Pygments
    and every ``Token.Name`` token is added as a candidate.
    """

    def __init__(
        self, root, rel_fnames, addable_rel_fnames, commands, encoding, abs_read_only_fnames=None
    ):
        """Record the file lists and prepare the lookup tables.

        Args:
            root: Directory the names in *rel_fnames* are relative to.
            rel_fnames: Relative names of files whose content is lexed for identifiers.
            addable_rel_fnames: Relative names offered as file-name completions.
            commands: Object providing ``get_commands``, ``matching_commands``,
                ``get_raw_completions`` and ``get_completions``; may be falsy,
                in which case no slash-command completions are produced.
            encoding: Text encoding used when reading files for lexing.
            abs_read_only_fnames: Optional paths of extra files to lex.
        """
        self.addable_rel_fnames = addable_rel_fnames
        self.rel_fnames = rel_fnames
        self.encoding = encoding
        self.abs_read_only_fnames = abs_read_only_fnames or []

        # Map a bare file name to every addable path that ends in it, so that
        # typing "foo.py" can offer "src/foo.py".
        fname_to_rel_fnames = defaultdict(list)
        for rel_fname in addable_rel_fnames:
            fname = os.path.basename(rel_fname)
            if fname != rel_fname:
                fname_to_rel_fnames[fname].append(rel_fname)
        self.fname_to_rel_fnames = fname_to_rel_fnames

        self.commands = commands
        self.command_completions = dict()
        # Always defined, so command completion never hits a missing attribute
        # when no commands object was supplied.
        self.command_names = self.commands.get_commands() if commands else []

        self.words = set()
        self.words.update(addable_rel_fnames)
        self.words.update(rel_fnames)

        all_fnames = [Path(root) / rel_fname for rel_fname in rel_fnames]
        if abs_read_only_fnames:
            all_fnames.extend(abs_read_only_fnames)
        self.all_fnames = all_fnames

        self.tokenized = False

    def tokenize(self):
        """Lex every known file once and add its identifier tokens to ``self.words``.

        Files that cannot be read, decoded or matched to a lexer are skipped.
        """
        if self.tokenized:
            return
        self.tokenized = True

        for fname in self.all_fnames:
            try:
                with open(fname, "r", encoding=self.encoding) as f:
                    content = f.read()
            except (OSError, UnicodeDecodeError):
                # OSError covers missing files, directories and permission
                # problems alike; completion is best-effort.
                continue
            try:
                lexer = guess_lexer_for_filename(fname, content)
            except Exception:  # On Windows, bad ref to time.clock which is deprecated
                continue

            # Stored as (text to match, text to insert): identifiers are
            # inserted wrapped in backticks.
            self.words.update(
                (token[1], f"`{token[1]}`")
                for token in lexer.get_tokens(content)
                if token[0] in Token.Name
            )

    def get_command_completions(self, document, complete_event, text, words):
        """Yield completions for input that starts with ``/``.

        With a single partial word the command name itself is completed;
        otherwise the last word is completed as an argument of the command
        named by the first word.
        """
        if not self.commands:
            return

        ends_with_space = text[-1].isspace()

        if len(words) == 1 and not ends_with_space:
            partial = words[0].lower()
            candidates = [cmd for cmd in self.command_names if cmd.startswith(partial)]
            for candidate in sorted(candidates):
                yield Completion(candidate, start_position=-len(words[-1]))
            return

        if len(words) <= 1 or ends_with_space:
            return

        cmd = words[0]
        partial = words[-1].lower()

        matches, _, _ = self.commands.matching_commands(cmd)
        if len(matches) == 1:
            cmd = matches[0]
        elif cmd not in matches:
            return

        # A command may supply its own prompt_toolkit-style completer.
        raw_completer = self.commands.get_raw_completions(cmd)
        if raw_completer:
            yield from raw_completer(document, complete_event)
            return

        # Cache the per-command candidate list after the first lookup.
        if cmd not in self.command_completions:
            self.command_completions[cmd] = self.commands.get_completions(cmd)
        candidates = self.command_completions[cmd]
        if candidates is None:
            return

        candidates = [word for word in candidates if partial in word.lower()]
        for candidate in sorted(candidates):
            yield Completion(candidate, start_position=-len(words[-1]))

    def get_completions(self, document, complete_event):
        """Yield completions for the word immediately before the cursor.

        Nothing is offered for empty input or right after whitespace. Input
        starting with ``/`` is delegated to ``get_command_completions``.
        Otherwise candidates are matched case-insensitively by prefix; a
        matching bare file name also offers every addable path with that name.
        """
        self.tokenize()

        text = document.text_before_cursor
        words = text.split()
        if not words:
            return

        if text[-1].isspace():
            # don't keep completing after a space
            return

        if text[0] == "/":
            yield from self.get_command_completions(document, complete_event, text, words)
            return

        # Build a fresh set rather than mutating self.words on every call.
        candidates = self.words | set(self.fname_to_rel_fnames)
        candidates = [word if isinstance(word, tuple) else (word, word) for word in candidates]

        last_word = words[-1]
        prefix = last_word.lower()
        start = -len(last_word)

        # A set, so a path reachable both directly and through its bare file
        # name is offered only once.
        completions = set()
        for word_match, word_insert in candidates:
            if not word_match.lower().startswith(prefix):
                continue
            completions.add((word_insert, start, word_match))
            for rel_fname in self.fname_to_rel_fnames.get(word_match, []):
                completions.add((rel_fname, start, rel_fname))

        for ins, pos, match in sorted(completions):
            yield Completion(ins, start_position=pos, display=match)
class InputOutput:
    """Terminal front end: prompts the user, prints tool/assistant output, keeps history files."""

    # Number of tool_error() calls so far.
    num_error_outputs = 0
    # Number of confirm_ask()/prompt_ask() calls so far.
    num_user_asks = 0

    def __init__(
        self,
        pretty=True,
        yes=None,
        input_history_file=None,
        chat_history_file=None,
        input=None,
        output=None,
        user_input_color="blue",
        tool_output_color=None,
        tool_error_color="red",
        tool_warning_color="#FFA500",
        assistant_output_color="blue",
        completion_menu_color=None,
        completion_menu_bg_color=None,
        completion_menu_current_color=None,
        completion_menu_current_bg_color=None,
        code_theme="default",
        encoding="utf-8",
        dry_run=False,
        llm_history_file=None,
        editingmode=EditingMode.EMACS,
    ):
        """Configure colors, history files and the console / prompt session.

        Args:
            pretty: Enable colored output and the prompt_toolkit session. Forced
                off when the ``NO_COLOR`` environment variable is non-empty or
                when *output* is given.
            yes: ``True``/``False`` answers every confirmation automatically;
                ``None`` asks the user.
            input_history_file: File backing the prompt history, or ``None``.
            chat_history_file: File the chat transcript is appended to, or ``None``.
            input, output: Passed through to ``PromptSession``.
            *_color: Style strings; the tool, user-input and completion-menu
                colors are dropped when *pretty* is off.
            code_theme: Code theme passed to the Markdown renderer.
            encoding: Encoding used for every text file read or written.
            dry_run: When true, ``write_text`` does nothing.
            llm_history_file: File ``log_llm_history`` appends to, or ``None``.
            editingmode: prompt_toolkit editing mode for the prompt session.
        """
        self.never_prompts = set()
        self.editingmode = editingmode

        no_color = os.environ.get("NO_COLOR")
        if no_color is not None and no_color != "":
            pretty = False

        self.user_input_color = user_input_color if pretty else None
        self.tool_output_color = tool_output_color if pretty else None
        self.tool_error_color = tool_error_color if pretty else None
        self.tool_warning_color = tool_warning_color if pretty else None
        self.assistant_output_color = assistant_output_color
        self.completion_menu_color = completion_menu_color if pretty else None
        self.completion_menu_bg_color = completion_menu_bg_color if pretty else None
        self.completion_menu_current_color = completion_menu_current_color if pretty else None
        self.completion_menu_current_bg_color = completion_menu_current_bg_color if pretty else None

        self.code_theme = code_theme

        self.input = input
        self.output = output

        self.pretty = pretty
        if self.output:
            self.pretty = False

        self.yes = yes

        self.input_history_file = input_history_file
        self.llm_history_file = llm_history_file
        if chat_history_file is not None:
            self.chat_history_file = Path(chat_history_file)
        else:
            self.chat_history_file = None

        self.encoding = encoding
        self.dry_run = dry_run

        # Set up the console BEFORE anything is logged: a failed chat-history
        # write is reported through tool_error(), which needs self.console.
        self.prompt_session = None
        session_error = None
        if self.pretty:
            # Initialize PromptSession
            session_kwargs = {
                "input": self.input,
                "output": self.output,
                "lexer": PygmentsLexer(MarkdownLexer),
                "editing_mode": self.editingmode,
                "cursor": ModalCursorShapeConfig(),
            }
            if self.input_history_file is not None:
                session_kwargs["history"] = FileHistory(self.input_history_file)
            try:
                self.prompt_session = PromptSession(**session_kwargs)
                self.console = Console()  # pretty console
            except Exception as err:
                self.console = Console(force_terminal=False, no_color=True)
                session_error = err
        else:
            self.console = Console(force_terminal=False, no_color=True)  # non-pretty

        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.append_chat_history(f"\n# aider chat started at {current_time}\n\n")

        # Reported after the header so the transcript still starts with it.
        if session_error is not None:
            self.tool_error(f"Can't initialize prompt toolkit: {session_error}")  # non-pretty

    def _get_style(self):
        """Build the prompt_toolkit ``Style`` for prompts; empty when not pretty."""
        style_dict = {}
        if not self.pretty:
            return Style.from_dict(style_dict)

        if self.user_input_color:
            style_dict.setdefault("", self.user_input_color)
            style_dict.update(
                {
                    "pygments.literal.string": f"bold italic {self.user_input_color}",
                }
            )

        # Conditionally add 'completion-menu' style
        completion_menu_style = []
        if self.completion_menu_bg_color:
            completion_menu_style.append(f"bg:{self.completion_menu_bg_color}")
        if self.completion_menu_color:
            completion_menu_style.append(self.completion_menu_color)
        if completion_menu_style:
            style_dict["completion-menu"] = " ".join(completion_menu_style)

        # Conditionally add 'completion-menu.completion.current' style
        completion_menu_current_style = []
        if self.completion_menu_current_bg_color:
            completion_menu_current_style.append(f"bg:{self.completion_menu_current_bg_color}")
        if self.completion_menu_current_color:
            completion_menu_current_style.append(self.completion_menu_current_color)
        if completion_menu_current_style:
            style_dict["completion-menu.completion.current"] = " ".join(
                completion_menu_current_style
            )

        return Style.from_dict(style_dict)

    def read_image(self, filename):
        """Return the file's bytes base64-encoded as a ``str``.

        On failure an error is reported via ``tool_error`` and ``None`` is returned.
        """
        # The specific OSError subclasses must be listed before OSError itself,
        # otherwise their handlers can never run.
        try:
            with open(str(filename), "rb") as image_file:
                encoded_string = base64.b64encode(image_file.read())
                return encoded_string.decode("utf-8")
        except FileNotFoundError:
            self.tool_error(f"{filename}: file not found error")
        except IsADirectoryError:
            self.tool_error(f"{filename}: is a directory")
        except OSError as err:
            self.tool_error(f"{filename}: unable to read: {err}")
        except Exception as e:
            self.tool_error(f"{filename}: {e}")
        return None

    def read_text(self, filename):
        """Return the file's text, or base64 image data when it is an image file.

        On failure an error is reported via ``tool_error`` and ``None`` is returned.
        """
        if is_image_file(filename):
            return self.read_image(filename)

        # Specific OSError subclasses first; see read_image.
        try:
            with open(str(filename), "r", encoding=self.encoding) as f:
                return f.read()
        except FileNotFoundError:
            self.tool_error(f"{filename}: file not found error")
        except IsADirectoryError:
            self.tool_error(f"{filename}: is a directory")
        except OSError as err:
            self.tool_error(f"{filename}: unable to read: {err}")
        except UnicodeError as e:
            self.tool_error(f"{filename}: {e}")
            self.tool_error("Use --encoding to set the unicode encoding.")
        return None

    def write_text(self, filename, content):
        """Write *content* to *filename*; a no-op in dry-run mode. Errors are reported, not raised."""
        if self.dry_run:
            return
        try:
            with open(str(filename), "w", encoding=self.encoding) as f:
                f.write(content)
        except OSError as err:
            self.tool_error(f"Unable to write file {filename}: {err}")

    def rule(self):
        """Print a horizontal rule when pretty, otherwise a blank line."""
        if self.pretty:
            style = dict(style=self.user_input_color) if self.user_input_color else dict()
            self.console.rule(**style)
        else:
            print()

    def get_input(
        self,
        root,
        rel_fnames,
        addable_rel_fnames,
        commands,
        abs_read_only_fnames=None,
        edit_format=None,
    ):
        """Prompt for the next user message and return it.

        The prompt shows the files in the chat, the optional *edit_format*
        and ``> ``. A line starting with ``{`` opens multi-line input, which
        continues until a line ending in ``}``. Returns ``""`` when the prompt
        raises ``UnicodeEncodeError``. The result is logged via ``user_input``.
        """
        self.rule()

        rel_fnames = list(rel_fnames)
        show = ""
        if rel_fnames:
            rel_read_only_fnames = [
                os.path.relpath(fname, root) for fname in (abs_read_only_fnames or [])
            ]
            show = self.format_files_for_input(rel_fnames, rel_read_only_fnames)
        if edit_format:
            show += edit_format
        show += "> "

        inp = ""
        multiline_input = False

        style = self._get_style()

        completer_instance = ThreadedCompleter(
            AutoCompleter(
                root,
                rel_fnames,
                addable_rel_fnames,
                commands,
                self.encoding,
                abs_read_only_fnames=abs_read_only_fnames,
            )
        )

        kb = KeyBindings()

        @kb.add("c-space")
        def _(event):
            "Ignore Ctrl when pressing space bar"
            event.current_buffer.insert_text(" ")

        @kb.add("escape", "c-m", eager=True)
        def _(event):
            event.current_buffer.insert_text("\n")

        while True:
            if multiline_input:
                show = ". "

            try:
                if self.prompt_session:
                    line = self.prompt_session.prompt(
                        show,
                        completer=completer_instance,
                        reserve_space_for_menu=4,
                        complete_style=CompleteStyle.MULTI_COLUMN,
                        style=style,
                        key_bindings=kb,
                    )
                else:
                    line = input(show)
            except UnicodeEncodeError as err:
                self.tool_error(str(err))
                return ""

            if line and line[0] == "{" and not multiline_input:
                # Opening brace: start collecting, dropping the brace itself.
                multiline_input = True
                inp += line[1:] + "\n"
                continue
            elif line and line[-1] == "}" and multiline_input:
                # Closing brace: finish collecting, dropping the brace itself.
                inp += line[:-1] + "\n"
                break
            elif multiline_input:
                inp += line + "\n"
            else:
                inp = line
                break

        print()
        self.user_input(inp)
        return inp

    def add_to_input_history(self, inp):
        """Record *inp* in the input history file and in the live prompt session's history."""
        if not self.input_history_file:
            return

        session_history = getattr(self.prompt_session, "history", None)

        # When the session's history is backed by this same file, a single
        # append_string() updates memory and disk; writing through a second
        # FileHistory as well would store the entry twice.
        if getattr(session_history, "filename", None) == self.input_history_file:
            session_history.append_string(inp)
            return

        FileHistory(self.input_history_file).append_string(inp)
        # Also add to the in-memory history if it exists
        if session_history is not None:
            session_history.append_string(inp)

    def get_input_history(self):
        """Return the strings stored in the input history file (empty list when none is configured)."""
        if not self.input_history_file:
            return []

        fh = FileHistory(self.input_history_file)
        return fh.load_history_strings()

    def log_llm_history(self, role, content):
        """Append a timestamped ``ROLE`` header and *content* to the LLM history file, if configured."""
        if not self.llm_history_file:
            return
        timestamp = datetime.now().isoformat(timespec="seconds")
        with open(self.llm_history_file, "a", encoding=self.encoding) as log_file:
            log_file.write(f"{role.upper()} {timestamp}\n")
            log_file.write(content + "\n")

    def user_input(self, inp, log_only=True):
        """Log user input to the chat history as ``####`` lines; also print it unless *log_only*."""
        if not log_only:
            if self.pretty and self.user_input_color:
                style = dict(style=self.user_input_color)
            else:
                style = dict()
            self.console.print(Text(inp), **style)

        prefix = "####"
        hist = inp.splitlines() if inp else ["<blank>"]
        hist = f" \n{prefix} ".join(hist)
        hist = f"\n{prefix} {hist}"
        self.append_chat_history(hist, linebreak=True)

    # OUTPUT

    def ai_output(self, content):
        """Append the assistant's reply to the chat history."""
        hist = "\n" + content.strip() + "\n\n"
        self.append_chat_history(hist)

    def confirm_ask(
        self,
        question,
        default="y",
        subject=None,
        explicit_yes_required=False,
        group=None,
        allow_never=False,
    ):
        """Ask a yes/no question and return ``True`` for yes.

        Args:
            question: Text of the question; the option list is appended to it.
            default: Answer used when the user just presses Enter. It is shown
                in brackets at the end of the prompt.
            subject: Optional text printed (bold) above the question.
            explicit_yes_required: When true, ``self.yes is True`` and "all"
                do not count as yes; only an explicit "y" does.
            group: Optional ``ConfirmGroup``. When it is shown, "all"/"skip all"
                are offered and the choice is stored in ``group.preference``
                and reused for later questions in the group.
            allow_never: Offer "don't ask again", which remembers the
                (question, subject) pair and answers ``False`` from then on.

        Returns:
            bool: whether the answer was yes.
        """
        self.num_user_asks += 1

        # Keyed on the question as passed in, before options are appended.
        question_id = (question, subject)
        if question_id in self.never_prompts:
            return False

        if group and not group.show_group:
            group = None
        if group:
            allow_never = True

        valid_responses = ["yes", "no"]
        options = " (Y)es/(N)o"
        if group:
            if not explicit_yes_required:
                options += "/(A)ll"
                valid_responses.append("all")
            options += "/(S)kip all"
            valid_responses.append("skip")
        if allow_never:
            options += "/(D)on't ask again"
            valid_responses.append("don't")

        # Advertise the default that an empty answer will actually select.
        default = default or ""
        default_key = default.strip().lower()
        if default_key.startswith("y"):
            default_label = "Yes"
        elif default_key.startswith("n"):
            default_label = "No"
        else:
            default_label = default
        question += options + f" [{default_label}]: "

        if subject:
            self.tool_output()
            if "\n" in subject:
                # Pad every line to the same width so the bold (reverse-video)
                # block is rectangular.
                lines = subject.splitlines()
                max_length = max(len(line) for line in lines)
                padded_lines = [line.ljust(max_length) for line in lines]
                padded_subject = "\n".join(padded_lines)
                self.tool_output(padded_subject, bold=True)
            else:
                self.tool_output(subject, bold=True)

        style = self._get_style()

        if self.yes is True:
            res = "n" if explicit_yes_required else "y"
        elif self.yes is False:
            res = "n"
        elif group and group.preference:
            res = group.preference
            self.user_input(f"{question}{res}", log_only=False)
        else:
            while True:
                if self.prompt_session:
                    res = self.prompt_session.prompt(
                        question,
                        style=style,
                    )
                else:
                    res = input(question)

                if not res:
                    res = default  # empty input selects the default
                res = res.strip().lower()

                # Any non-empty prefix of a valid response is accepted.
                if res and any(valid.startswith(res) for valid in valid_responses):
                    break

                error_message = f"Please answer with one of: {', '.join(valid_responses)}"
                self.tool_error(error_message)

        res = res.lower()[0]

        if res == "d" and allow_never:
            self.never_prompts.add(question_id)
            hist = f"{question.strip()} {res}"
            self.append_chat_history(hist, linebreak=True, blockquote=True)
            return False

        if explicit_yes_required:
            is_yes = res == "y"
        else:
            is_yes = res in ("y", "a")

        is_all = res == "a" and group is not None and not explicit_yes_required
        is_skip = res == "s" and group is not None

        if group:
            if is_all:
                group.preference = "all"
            elif is_skip:
                group.preference = "skip"

        hist = f"{question.strip()} {res}"
        self.append_chat_history(hist, linebreak=True, blockquote=True)

        return is_yes

    def prompt_ask(self, question, default="", subject=None):
        """Ask a free-text question and return the answer.

        With ``self.yes`` set to ``True``/``False`` the answer is ``"yes"``/
        ``"no"`` without prompting. *default* pre-fills the prompt_toolkit
        prompt; it is not used by the plain ``input()`` fallback.
        """
        self.num_user_asks += 1

        if subject:
            self.tool_output()
            self.tool_output(subject, bold=True)

        style = self._get_style()

        if self.yes is True:
            res = "yes"
        elif self.yes is False:
            res = "no"
        elif self.prompt_session:
            res = self.prompt_session.prompt(question + " ", default=default, style=style)
        else:
            res = input(question + " ")

        hist = f"{question.strip()} {res.strip()}"
        self.append_chat_history(hist, linebreak=True, blockquote=True)
        if self.yes in (True, False):
            self.tool_output(hist)

        return res

    def _tool_message(self, message="", strip=True, color=None):
        """Log *message* to the chat history as a blockquote and print it in *color*."""
        if message.strip():
            if "\n" in message:
                for line in message.splitlines():
                    self.append_chat_history(line, linebreak=True, blockquote=True, strip=strip)
            else:
                # Forward strip here too; otherwise append_chat_history would
                # strip the text even when the caller passed strip=False.
                self.append_chat_history(message, linebreak=True, blockquote=True, strip=strip)

        message = Text(message)
        style = dict(style=color) if self.pretty and color else dict()
        self.console.print(message, **style)

    def tool_error(self, message="", strip=True):
        """Report an error message and count it in ``num_error_outputs``."""
        self.num_error_outputs += 1
        self._tool_message(message, strip, self.tool_error_color)

    def tool_warning(self, message="", strip=True):
        """Report a warning message."""
        self._tool_message(message, strip, self.tool_warning_color)

    def tool_output(self, *messages, log_only=False, bold=False):
        """Log the space-joined *messages* to the chat history and, unless *log_only*, print them.

        *bold* is rendered as reverse video when pretty output is on.
        """
        if messages:
            hist = " ".join(messages).strip()
            self.append_chat_history(hist, linebreak=True, blockquote=True)

        if log_only:
            return

        messages = list(map(Text, messages))
        style = dict()
        if self.pretty:
            if self.tool_output_color:
                style["color"] = self.tool_output_color
            style["reverse"] = bold

        style = RichStyle(**style)
        self.console.print(*messages, style=style)

    def get_assistant_mdstream(self):
        """Return a ``MarkdownStream`` configured with the assistant color and code theme."""
        mdargs = dict(style=self.assistant_output_color, code_theme=self.code_theme)
        return MarkdownStream(mdargs=mdargs)

    def assistant_output(self, message, pretty=None):
        """Print an assistant message: Markdown when pretty, otherwise plain text.

        *pretty* overrides ``self.pretty`` when it is not ``None``.
        """
        # Coder will force pretty off if fence is not triple-backticks
        if pretty is None:
            pretty = self.pretty

        if pretty:
            show_resp = Markdown(
                message, style=self.assistant_output_color, code_theme=self.code_theme
            )
        else:
            show_resp = Text(message or "<no response>")

        self.console.print(show_resp)

    def print(self, message=""):
        """Print *message* with the built-in ``print``."""
        print(message)

    def append_chat_history(self, text, linebreak=False, blockquote=False, strip=True):
        """Append *text* to the chat history file, if one is configured.

        Args:
            text: Text to append; a trailing newline is added when missing.
            linebreak: End the text with a space and newline.
            blockquote: Prefix the text with ``"> "``.
            strip: Strip surrounding whitespace before decorating.

        If the file cannot be written, chat-history logging is turned off for
        the rest of the session and the failure is reported once.
        """
        if blockquote:
            if strip:
                text = text.strip()
            text = "> " + text
        if linebreak:
            if strip:
                text = text.rstrip()
            text = text + " \n"
        if not text.endswith("\n"):
            text += "\n"

        if self.chat_history_file is None:
            return

        try:
            with self.chat_history_file.open("a", encoding=self.encoding, errors="ignore") as f:
                f.write(text)
        except OSError as err:
            # Disable logging BEFORE reporting: tool_error() itself appends to
            # the chat history, so the old order recursed without bound.
            failed_path = self.chat_history_file
            self.chat_history_file = None
            self.tool_error(
                f"Warning: Unable to write to chat history file {failed_path}. {err}"
            )

    def format_files_for_input(self, rel_fnames, rel_read_only_fnames):
        """Return the file list shown above the prompt, one file per line.

        Read-only files come first, marked with ``R``. Files from *rel_fnames*
        that are also read-only are not repeated. Names are shortened to the
        bare file name, with the distinguishing directory in parentheses when
        several files share a name. *rel_read_only_fnames* may be ``None``.
        """
        read_only = list(rel_read_only_fnames or [])
        read_only_set = set(read_only)

        # Include read-only files in the id computation so a read-only file
        # that is missing from rel_fnames does not raise KeyError below.
        known = set(rel_fnames)
        all_fnames = list(rel_fnames) + [name for name in read_only if name not in known]
        minimal_unique_fileids = self.compute_minimal_fileids(all_fnames)

        # Format the filename for display in the prompt, with disambiguating path
        # in parentheses, if needed.
        def format_minimal_fileid(fname):
            pth = Path(minimal_unique_fileids[fname])
            if len(pth.parts) > 1:
                return f"{pth.name} ({'/'.join(pth.parts[:-1])})"
            return pth.name

        read_only_files = [f" R {format_minimal_fileid(path)}" for path in read_only]
        editable_files = [
            f" {format_minimal_fileid(path)}" for path in rel_fnames if path not in read_only_set
        ]

        return "\n".join(read_only_files + editable_files) + "\n"

    def compute_minimal_fileids(self, rel_fnames):
        """Map each path to the shortest suffix that distinguishes it from same-named files.

        Files sharing a base name have their common leading directories
        removed; a file whose base name is unique maps to the bare name.
        Returns a dict from each entry of *rel_fnames* to its shortened path string.
        """
        # First pass: group files with the same name
        grouped_fnames = defaultdict(list)
        for full_path in rel_fnames:
            pth = Path(full_path)
            grouped_fnames[pth.name].append(list(pth.parts))

        # Second pass: compute the shared prefix of each group of files.
        shared_prefixes = {}
        for fname, paths in grouped_fnames.items():
            shared_prefix = []
            # Stop once any path is down to just its file name.
            while all(len(path) > 1 for path in paths):
                next_part = paths[0][0]
                if not all(path[0] == next_part for path in paths):
                    break
                shared_prefix.append(next_part)
                paths = [path[1:] for path in paths]
            shared_prefixes[fname] = Path(*shared_prefix)

        # Third pass: subtract the shared prefix from the full path to get the minimal unique id
        minimal_unique_ids = {}
        for full_path in rel_fnames:
            pth = Path(full_path)
            minimal_unique_ids[full_path] = str(pth.relative_to(shared_prefixes[pth.name]))

        return minimal_unique_ids