refactor: extract common code from tool_warning and tool_error

This commit is contained in:
Paul Gauthier (aider) 2024-09-03 12:41:05 -07:00
parent b4ba159f27
commit b088627fcc

View file

@ -524,45 +524,25 @@ class InputOutput:
return res
def _tool_message(self, message="", strip=True, color=None):
if message.strip():
if "\n" in message:
for line in message.splitlines():
self.append_chat_history(line, linebreak=True, blockquote=True, strip=strip)
else:
hist = message.strip() if strip else message
self.append_chat_history(hist, linebreak=True, blockquote=True)
message = Text(message)
style = dict(style=color) if self.pretty and color else dict()
self.console.print(message, **style)
def tool_error(self, message="", strip=True):
self.num_error_outputs += 1
if message.strip():
if "\n" in message:
for line in message.splitlines():
self.append_chat_history(line, linebreak=True, blockquote=True, strip=strip)
else:
if strip:
hist = message.strip()
else:
hist = message
self.append_chat_history(hist, linebreak=True, blockquote=True)
message = Text(message)
if self.pretty and self.tool_error_color:
style = dict(style=self.tool_error_color)
else:
style = dict()
self.console.print(message, **style)
self._tool_message(message, strip, self.tool_error_color)
def tool_warning(self, message="", strip=True):
if message.strip():
if "\n" in message:
for line in message.splitlines():
self.append_chat_history(line, linebreak=True, blockquote=True, strip=strip)
else:
if strip:
hist = message.strip()
else:
hist = message
self.append_chat_history(hist, linebreak=True, blockquote=True)
message = Text(message)
if self.pretty and self.tool_warning_color:
style = dict(style=self.tool_warning_color)
else:
style = dict()
self.console.print(message, **style)
self._tool_message(message, strip, self.tool_warning_color)
def tool_output(self, *messages, log_only=False, bold=False):
if messages: