ensure IO obeys pretty flag, catch UnicodeDecodeError on launch and disable pretty

This commit is contained in:
Paul Gauthier 2024-09-02 11:30:48 -07:00
parent 3bf403ba05
commit 4063015560
2 changed files with 40 additions and 15 deletions

View file

@ -262,6 +262,13 @@ class InputOutput:
with open(str(filename), "w", encoding=self.encoding) as f:
f.write(content)
def rule(self):
if self.pretty:
style = dict(style=self.user_input_color) if self.user_input_color else dict()
self.console.rule(**style)
else:
print()
def get_input(
self,
root,
@ -271,11 +278,7 @@ class InputOutput:
abs_read_only_fnames=None,
edit_format=None,
):
if self.pretty:
style = dict(style=self.user_input_color) if self.user_input_color else dict()
self.console.rule(**style)
else:
print()
self.rule()
rel_fnames = list(rel_fnames)
show = ""
@ -288,7 +291,7 @@ class InputOutput:
inp = ""
multiline_input = False
if self.user_input_color:
if self.user_input_color and self.pretty:
style = Style.from_dict(
{
"": self.user_input_color,
@ -380,8 +383,12 @@ class InputOutput:
log_file.write(content + "\n")
def user_input(self, inp, log_only=True):
if not log_only and self.pretty:
style = dict(style=self.user_input_color) if self.user_input_color else dict()
if not log_only:
if self.pretty and self.user_input_color:
style = dict(style=self.user_input_color)
else:
style = dict()
self.console.print(Text(inp), **style)
prefix = "####"
@ -527,7 +534,10 @@ class InputOutput:
self.append_chat_history(hist, linebreak=True, blockquote=True)
message = Text(message)
style = dict(style=self.tool_error_color) if self.tool_error_color else dict()
if self.pretty and self.tool_error_color:
style = dict(style=self.tool_error_color)
else:
style = dict()
self.console.print(message, **style)
def tool_output(self, *messages, log_only=False, bold=False):
@ -536,12 +546,18 @@ class InputOutput:
hist = f"{hist.strip()}"
self.append_chat_history(hist, linebreak=True, blockquote=True)
if not log_only:
messages = list(map(Text, messages))
style = dict(color=self.tool_output_color) if self.tool_output_color else dict()
if log_only:
return
messages = list(map(Text, messages))
style = dict()
if self.pretty:
if self.tool_output_color:
style["color"] = self.tool_output_color
style["reverse"] = bold
style = RichStyle(**style)
self.console.print(*messages, style=style)
style = RichStyle(**style)
self.console.print(*messages, style=style)
def append_chat_history(self, text, linebreak=False, blockquote=False, strip=True):
if blockquote:

View file

@ -398,6 +398,16 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
editingmode=editing_mode,
)
try:
io.tool_output()
io.rule()
except UnicodeEncodeError as err:
if io.pretty:
io.pretty = False
io.tool_error("Terminal does not support pretty output (UnicodeDecodeError)")
else:
raise err
if args.gui and not return_coder:
if not check_streamlit_install(io):
return
@ -588,7 +598,6 @@ def main(argv=None, input=None, output=None, force_git_root=None, return_coder=F
if return_coder:
return coder
io.tool_output()
coder.show_announcements()
if args.show_prompts: