switch to mdstream

This commit is contained in:
Paul Gauthier 2024-01-23 09:20:31 -08:00
parent e50a0e8b09
commit b143bc56ac
2 changed files with 18 additions and 17 deletions

View file

@ -13,13 +13,13 @@ from pathlib import Path
import openai import openai
from jsonschema import Draft7Validator from jsonschema import Draft7Validator
from rich.console import Console, Text from rich.console import Console, Text
from rich.live import Live
from rich.markdown import Markdown from rich.markdown import Markdown
from aider import models, prompts, utils from aider import models, prompts, utils
from aider.commands import Commands from aider.commands import Commands
from aider.history import ChatSummary from aider.history import ChatSummary
from aider.io import InputOutput from aider.io import InputOutput
from aider.mdstream import MarkdownStream
from aider.repo import GitRepo from aider.repo import GitRepo
from aider.repomap import RepoMap from aider.repomap import RepoMap
from aider.sendchat import send_with_retries from aider.sendchat import send_with_retries
@ -725,14 +725,12 @@ class Coder:
self.io.tool_output(tokens) self.io.tool_output(tokens)
def show_send_output_stream(self, completion): def show_send_output_stream(self, completion):
live = None
if self.show_pretty(): if self.show_pretty():
live = Live(vertical_overflow="scroll") mdstream = MarkdownStream()
else:
mdstream = None
try: try:
if live:
live.start()
for chunk in completion: for chunk in completion:
if len(chunk.choices) == 0: if len(chunk.choices) == 0:
continue continue
@ -762,22 +760,21 @@ class Coder:
text = None text = None
if self.show_pretty(): if self.show_pretty():
self.live_incremental_response(live, False) self.live_incremental_response(mdstream, False)
elif text: elif text:
sys.stdout.write(text) sys.stdout.write(text)
sys.stdout.flush() sys.stdout.flush()
finally: finally:
if live: if mdstream:
self.live_incremental_response(live, True) self.live_incremental_response(mdstream, True)
live.stop()
def live_incremental_response(self, live, final): def live_incremental_response(self, mdstream, final):
show_resp = self.render_incremental_response(final) show_resp = self.render_incremental_response(final)
if not show_resp: if not show_resp:
return return
md = Markdown(show_resp, style=self.assistant_output_color, code_theme=self.code_theme) mdargs = dict(style=self.assistant_output_color, code_theme=self.code_theme)
live.update(md) mdstream.update(show_resp, mdargs=mdargs, final=final)
def render_incremental_response(self, final): def render_incremental_response(self, final):
return self.partial_response_content return self.partial_response_content

View file

@ -60,7 +60,7 @@ class MarkdownStream:
self.printed = [] self.printed = []
self.when = 0 self.when = 0
def update(self, text, final=False, min_delay=0.100): def update(self, text, final=False, min_delay=0.100, mdargs=None):
now = time.time() now = time.time()
if not final and now - self.when < min_delay: if not final and now - self.when < min_delay:
return return
@ -69,7 +69,11 @@ class MarkdownStream:
string_io = io.StringIO() string_io = io.StringIO()
console = Console(file=string_io, force_terminal=True) console = Console(file=string_io, force_terminal=True)
markdown = Markdown(text) if not mdargs:
mdargs = dict()
markdown = Markdown(text, **mdargs)
console.print(markdown) console.print(markdown)
output = string_io.getvalue() output = string_io.getvalue()
@ -78,7 +82,7 @@ class MarkdownStream:
if not final: if not final:
num_lines -= 4 num_lines -= 4
if num_lines <= 1: if num_lines <= 1 and not final:
return return
num_printed = len(self.printed) num_printed = len(self.printed)
@ -111,6 +115,6 @@ if __name__ == "__main__":
pm = MarkdownStream() pm = MarkdownStream()
for i in range(6, len(_text)): for i in range(6, len(_text)):
pm.update(_text[:i]) pm.update(_text[:i])
# time.sleep(0.001) time.sleep(0.001)
pm.update(_text, final=True) pm.update(_text, final=True)