set mdargs on init

This commit is contained in:
Paul Gauthier 2024-01-23 09:58:57 -08:00
parent da131da427
commit 580c52bd85
2 changed files with 17 additions and 18 deletions

View file

@ -726,7 +726,8 @@ class Coder:
def show_send_output_stream(self, completion):
if self.show_pretty():
mdstream = MarkdownStream()
mdargs = dict(style=self.assistant_output_color, code_theme=self.code_theme)
mdstream = MarkdownStream(mdargs=mdargs)
else:
mdstream = None
@ -773,8 +774,7 @@ class Coder:
if not show_resp:
return
mdargs = dict(style=self.assistant_output_color, code_theme=self.code_theme)
mdstream.update(show_resp, mdargs=mdargs, final=final)
mdstream.update(show_resp, final=final)
def render_incremental_response(self, final):
return self.partial_response_content

View file

@ -4,8 +4,8 @@ import io
import time
from rich.console import Console
from rich.markdown import Markdown
from rich.live import Live
from rich.markdown import Markdown
from rich.text import Text
from aider.dump import dump # noqa: F401
@ -63,16 +63,21 @@ class MarkdownStream:
min_delay = 0.050
live_window = 6
def __init__(self):
def __init__(self, mdargs=None):
self.printed = []
self.live = Live(Text(''), refresh_per_second=1./self.min_delay)
if mdargs:
self.mdargs = mdargs
else:
self.mdargs = dict()
self.live = Live(Text(""), refresh_per_second=1.0 / self.min_delay)
self.live.start()
def __del__(self):
self.live.stop()
def update(self, text, final=False, mdargs=None):
def update(self, text, final=False):
now = time.time()
if not final and now - self.when < self.min_delay:
return
@ -81,10 +86,7 @@ class MarkdownStream:
string_io = io.StringIO()
console = Console(file=string_io, force_terminal=True)
if not mdargs:
mdargs = dict()
markdown = Markdown(text, **mdargs)
markdown = Markdown(text, **self.mdargs)
console.print(markdown)
output = string_io.getvalue()
@ -96,7 +98,6 @@ class MarkdownStream:
num_lines -= self.live_window
if final or num_lines > 0:
num_printed = len(self.printed)
show = num_lines - num_printed
@ -105,25 +106,23 @@ class MarkdownStream:
return
show = lines[num_printed:num_lines]
show = ''.join(show)
show = "".join(show)
show = Text.from_ansi(show)
self.live.console.print(show)
self.printed = lines[:num_lines]
if final:
self.live.update(Text(''))
self.live.update(Text(""))
self.live.stop()
else:
rest = lines[num_lines:]
rest = ''.join(rest)
#rest = '...\n' + rest
rest = "".join(rest)
# rest = '...\n' + rest
rest = Text.from_ansi(rest)
self.live.update(rest)
if __name__ == "__main__":
_text = 5 * _text