set mdargs on init

This commit is contained in:
Paul Gauthier 2024-01-23 09:58:57 -08:00
parent da131da427
commit 580c52bd85
2 changed files with 17 additions and 18 deletions

View file

@ -726,7 +726,8 @@ class Coder:
def show_send_output_stream(self, completion): def show_send_output_stream(self, completion):
if self.show_pretty(): if self.show_pretty():
mdstream = MarkdownStream() mdargs = dict(style=self.assistant_output_color, code_theme=self.code_theme)
mdstream = MarkdownStream(mdargs=mdargs)
else: else:
mdstream = None mdstream = None
@ -773,8 +774,7 @@ class Coder:
if not show_resp: if not show_resp:
return return
mdargs = dict(style=self.assistant_output_color, code_theme=self.code_theme) mdstream.update(show_resp, final=final)
mdstream.update(show_resp, mdargs=mdargs, final=final)
def render_incremental_response(self, final): def render_incremental_response(self, final):
return self.partial_response_content return self.partial_response_content

View file

@ -4,8 +4,8 @@ import io
import time import time
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown
from rich.live import Live from rich.live import Live
from rich.markdown import Markdown
from rich.text import Text from rich.text import Text
from aider.dump import dump # noqa: F401 from aider.dump import dump # noqa: F401
@ -63,16 +63,21 @@ class MarkdownStream:
min_delay = 0.050 min_delay = 0.050
live_window = 6 live_window = 6
def __init__(self): def __init__(self, mdargs=None):
self.printed = [] self.printed = []
self.live = Live(Text(''), refresh_per_second=1./self.min_delay) if mdargs:
self.mdargs = mdargs
else:
self.mdargs = dict()
self.live = Live(Text(""), refresh_per_second=1.0 / self.min_delay)
self.live.start() self.live.start()
def __del__(self): def __del__(self):
self.live.stop() self.live.stop()
def update(self, text, final=False, mdargs=None): def update(self, text, final=False):
now = time.time() now = time.time()
if not final and now - self.when < self.min_delay: if not final and now - self.when < self.min_delay:
return return
@ -81,10 +86,7 @@ class MarkdownStream:
string_io = io.StringIO() string_io = io.StringIO()
console = Console(file=string_io, force_terminal=True) console = Console(file=string_io, force_terminal=True)
if not mdargs: markdown = Markdown(text, **self.mdargs)
mdargs = dict()
markdown = Markdown(text, **mdargs)
console.print(markdown) console.print(markdown)
output = string_io.getvalue() output = string_io.getvalue()
@ -96,7 +98,6 @@ class MarkdownStream:
num_lines -= self.live_window num_lines -= self.live_window
if final or num_lines > 0: if final or num_lines > 0:
num_printed = len(self.printed) num_printed = len(self.printed)
show = num_lines - num_printed show = num_lines - num_printed
@ -105,25 +106,23 @@ class MarkdownStream:
return return
show = lines[num_printed:num_lines] show = lines[num_printed:num_lines]
show = ''.join(show) show = "".join(show)
show = Text.from_ansi(show) show = Text.from_ansi(show)
self.live.console.print(show) self.live.console.print(show)
self.printed = lines[:num_lines] self.printed = lines[:num_lines]
if final: if final:
self.live.update(Text('')) self.live.update(Text(""))
self.live.stop() self.live.stop()
else: else:
rest = lines[num_lines:] rest = lines[num_lines:]
rest = ''.join(rest) rest = "".join(rest)
#rest = '...\n' + rest # rest = '...\n' + rest
rest = Text.from_ansi(rest) rest = Text.from_ansi(rest)
self.live.update(rest) self.live.update(rest)
if __name__ == "__main__": if __name__ == "__main__":
_text = 5 * _text _text = 5 * _text