Refactored MarkdownStream to use rich.live for real-time updates and improved text rendering.

This commit is contained in:
Paul Gauthier 2024-01-23 09:50:23 -08:00
parent b143bc56ac
commit 24f1e01177

View file

@ -5,6 +5,8 @@ import time
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
from rich.live import Live
from rich.text import Text
from aider.dump import dump # noqa: F401 from aider.dump import dump # noqa: F401
@ -56,13 +58,20 @@ def showit(lines):
class MarkdownStream: class MarkdownStream:
live = None
when = 0
min_delay = 0.050
live_window = 6
def __init__(self): def __init__(self):
self.printed = [] self.printed = []
self.when = 0
def update(self, text, final=False, min_delay=0.100, mdargs=None): self.live = Live(Text(''), refresh_per_second=1./self.min_delay)
self.live.start()
def update(self, text, final=False, mdargs=None):
now = time.time() now = time.time()
if not final and now - self.when < min_delay: if not final and now - self.when < self.min_delay:
return return
self.when = now self.when = now
@ -79,34 +88,37 @@ class MarkdownStream:
lines = output.splitlines(keepends=True) lines = output.splitlines(keepends=True)
num_lines = len(lines) num_lines = len(lines)
if not final: if not final:
num_lines -= 4 num_lines -= self.live_window
if num_lines <= 1 and not final: if final or num_lines > 0:
return
num_printed = len(self.printed) num_printed = len(self.printed)
""" show = num_lines - num_printed
if lines[:num_printed] != self.printed:
dump(repr(text))
print('xxx')
print(''.join(self.printed))
print('xxx')
print(''.join(lines))
print('xxx')
sys.exit()
"""
show = num_lines - num_printed if show <= 0:
return
if show <= 0: show = lines[num_printed:num_lines]
return show = ''.join(show)
show = Text.from_ansi(show)
self.live.console.print(show)
self.printed = lines[:num_lines]
if final:
self.live.update(Text(''))
self.live.stop()
else:
rest = lines[num_lines:]
rest = ''.join(rest)
#rest = '...\n' + rest
rest = Text.from_ansi(rest)
self.live.update(rest)
show = lines[num_printed:num_lines]
print("".join(show), end="")
self.printed = lines[:num_lines]
if __name__ == "__main__": if __name__ == "__main__":
@ -115,6 +127,6 @@ if __name__ == "__main__":
pm = MarkdownStream() pm = MarkdownStream()
for i in range(6, len(_text)): for i in range(6, len(_text)):
pm.update(_text[:i]) pm.update(_text[:i])
time.sleep(0.001) time.sleep(0.01)
pm.update(_text, final=True) pm.update(_text, final=True)