Merge pull request #1402 from caseymcc/io_assistant_output

Modify output from Assistant and Commands to go through InputOutput
This commit is contained in:
paul-gauthier 2024-09-10 15:08:10 -07:00 committed by GitHub
commit d1384e9d5f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 34 additions and 32 deletions

View file

@ -18,16 +18,12 @@ from datetime import datetime
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from pathlib import Path from pathlib import Path
from rich.console import Console, Text
from rich.markdown import Markdown
from aider import __version__, models, prompts, urls, utils from aider import __version__, models, prompts, urls, utils
from aider.commands import Commands from aider.commands import Commands
from aider.history import ChatSummary from aider.history import ChatSummary
from aider.io import ConfirmGroup, InputOutput from aider.io import ConfirmGroup, InputOutput
from aider.linter import Linter from aider.linter import Linter
from aider.llm import litellm from aider.llm import litellm
from aider.mdstream import MarkdownStream
from aider.repo import ANY_GIT_ERROR, GitRepo from aider.repo import ANY_GIT_ERROR, GitRepo
from aider.repomap import RepoMap from aider.repomap import RepoMap
from aider.run_cmd import run_cmd from aider.run_cmd import run_cmd
@ -241,8 +237,6 @@ class Coder:
dry_run=False, dry_run=False,
map_tokens=1024, map_tokens=1024,
verbose=False, verbose=False,
assistant_output_color="blue",
code_theme="default",
stream=True, stream=True,
use_git=True, use_git=True,
cur_messages=None, cur_messages=None,
@ -315,17 +309,10 @@ class Coder:
self.auto_commits = auto_commits self.auto_commits = auto_commits
self.dirty_commits = dirty_commits self.dirty_commits = dirty_commits
self.assistant_output_color = assistant_output_color
self.code_theme = code_theme
self.dry_run = dry_run self.dry_run = dry_run
self.pretty = self.io.pretty self.pretty = self.io.pretty
if self.pretty:
self.console = Console()
else:
self.console = Console(force_terminal=False, no_color=True)
self.main_model = main_model self.main_model = main_model
if cache_prompts and self.main_model.cache_control: if cache_prompts and self.main_model.cache_control:
@ -1107,11 +1094,7 @@ class Coder:
utils.show_messages(messages, functions=self.functions) utils.show_messages(messages, functions=self.functions)
self.multi_response_content = "" self.multi_response_content = ""
if self.show_pretty() and self.stream: self.mdstream=self.io.assistant_output("", self.stream)
mdargs = dict(style=self.assistant_output_color, code_theme=self.code_theme)
self.mdstream = MarkdownStream(mdargs=mdargs)
else:
self.mdstream = None
retry_delay = 0.125 retry_delay = 0.125
@ -1463,14 +1446,7 @@ class Coder:
raise Exception("No data found in LLM response!") raise Exception("No data found in LLM response!")
show_resp = self.render_incremental_response(True) show_resp = self.render_incremental_response(True)
if self.show_pretty(): self.io.assistant_output(show_resp)
show_resp = Markdown(
show_resp, style=self.assistant_output_color, code_theme=self.code_theme
)
else:
show_resp = Text(show_resp or "<no response>")
self.io.console.print(show_resp)
if ( if (
hasattr(completion.choices[0], "finish_reason") hasattr(completion.choices[0], "finish_reason")

View file

@ -562,8 +562,7 @@ class Commands:
"HEAD", "HEAD",
) )
# don't use io.tool_output() because we don't want to log or further colorize self.io.print(diff)
print(diff)
def quote_fname(self, fname): def quote_fname(self, fname):
if " " in fname and '"' not in fname: if " " in fname and '"' not in fname:
@ -1030,9 +1029,9 @@ class Commands:
if text: if text:
self.io.add_to_input_history(text) self.io.add_to_input_history(text)
print() self.io.print()
self.io.user_input(text, log_only=False) self.io.user_input(text, log_only=False)
print() self.io.print()
return text return text

View file

@ -17,6 +17,8 @@ from pygments.token import Token
from rich.console import Console from rich.console import Console
from rich.style import Style as RichStyle from rich.style import Style as RichStyle
from rich.text import Text from rich.text import Text
from rich.markdown import Markdown
from aider.mdstream import MarkdownStream
from .dump import dump # noqa: F401 from .dump import dump # noqa: F401
from .utils import is_image_file from .utils import is_image_file
@ -176,6 +178,8 @@ class InputOutput:
tool_output_color=None, tool_output_color=None,
tool_error_color="red", tool_error_color="red",
tool_warning_color="#FFA500", tool_warning_color="#FFA500",
assistant_output_color="blue",
code_theme="default",
encoding="utf-8", encoding="utf-8",
dry_run=False, dry_run=False,
llm_history_file=None, llm_history_file=None,
@ -190,6 +194,8 @@ class InputOutput:
self.tool_output_color = tool_output_color if pretty else None self.tool_output_color = tool_output_color if pretty else None
self.tool_error_color = tool_error_color if pretty else None self.tool_error_color = tool_error_color if pretty else None
self.tool_warning_color = tool_warning_color if pretty else None self.tool_warning_color = tool_warning_color if pretty else None
self.assistant_output_color = assistant_output_color
self.code_theme = code_theme
self.input = input self.input = input
self.output = output self.output = output
@ -580,6 +586,27 @@ class InputOutput:
style = RichStyle(**style) style = RichStyle(**style)
self.console.print(*messages, style=style) self.console.print(*messages, style=style)
def assistant_output(self, message, stream=False):
mdStream = None
show_resp = message
if self.pretty:
if stream:
mdargs = dict(style=self.assistant_output_color, code_theme=self.code_theme)
mdStream = MarkdownStream(mdargs=mdargs)
else:
show_resp = Markdown(
message, style=self.assistant_output_color, code_theme=self.code_theme
)
else:
show_resp = Text(message or "<no response>")
self.console.print(show_resp)
return mdStream
def print(self, message=""):
print(message)
def append_chat_history(self, text, linebreak=False, blockquote=False, strip=True): def append_chat_history(self, text, linebreak=False, blockquote=False, strip=True):
if blockquote: if blockquote:
if strip: if strip:

View file

@ -401,6 +401,8 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
user_input_color=args.user_input_color, user_input_color=args.user_input_color,
tool_output_color=args.tool_output_color, tool_output_color=args.tool_output_color,
tool_error_color=args.tool_error_color, tool_error_color=args.tool_error_color,
assistant_output_color=args.assistant_output_color,
code_theme=args.code_theme,
dry_run=args.dry_run, dry_run=args.dry_run,
encoding=args.encoding, encoding=args.encoding,
llm_history_file=args.llm_history_file, llm_history_file=args.llm_history_file,
@ -584,8 +586,6 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
dry_run=args.dry_run, dry_run=args.dry_run,
map_tokens=args.map_tokens, map_tokens=args.map_tokens,
verbose=args.verbose, verbose=args.verbose,
assistant_output_color=args.assistant_output_color,
code_theme=args.code_theme,
stream=args.stream, stream=args.stream,
use_git=args.git, use_git=args.git,
restore_chat_history=args.restore_chat_history, restore_chat_history=args.restore_chat_history,