feat: Add mermaid.live link generation for markdown diagrams

This commit is contained in:
Mitchell Gordon (aider) 2024-11-25 13:41:58 -05:00 committed by Mitchell Gordon
parent c395be252e
commit 30f47b72da

View file

@ -2,6 +2,10 @@
import io
import time
import base64
import json
import zlib
import re
from rich.console import Console
from rich.live import Live
@ -71,11 +75,40 @@ class MarkdownStream:
self.mdargs = mdargs
else:
self.mdargs = dict()
self.mermaid_pattern = re.compile(r'```mermaid\n(.*?)\n```', re.DOTALL)
# Initialize rich Live display with empty text
self.live = Live(Text(""), refresh_per_second=1.0 / self.min_delay)
self.live.start()
def _generate_mermaid_link(self, graph_markdown):
"""Generate a mermaid.live link for the given graph markdown"""
def js_string_to_byte(data):
return bytes(data, 'ascii')
def js_bytes_to_string(data):
return data.decode('ascii')
def js_btoa(data):
return base64.b64encode(data)
def pako_deflate(data):
compress = zlib.compressobj(9, zlib.DEFLATED, 15, 8, zlib.Z_DEFAULT_STRATEGY)
compressed_data = compress.compress(data)
compressed_data += compress.flush()
return compressed_data
j_graph = {
"code": graph_markdown,
"mermaid": {"theme": "default"}
}
byte_str = js_string_to_byte(json.dumps(j_graph))
deflated = pako_deflate(byte_str)
d_encode = js_btoa(deflated)
link = 'http://mermaid.live/view#pako:' + js_bytes_to_string(d_encode)
return link
def _render_markdown_to_lines(self, text):
"""Render markdown text to a list of lines.
@ -134,6 +167,24 @@ class MarkdownStream:
# Set min_delay to render time plus a small buffer
self.min_delay = min(max(render_time * 10, 1.0 / 20), 2)
# Process mermaid diagrams
processed_text = text
for match in self.mermaid_pattern.finditer(text):
mermaid_code = match.group(1)
link = self._generate_mermaid_link(mermaid_code)
# Add the link after the mermaid block
diagram_end = match.end()
processed_text = (
processed_text[:diagram_end] +
f"\n\n[View diagram]({link})\n" +
processed_text[diagram_end:]
)
string_io = io.StringIO()
console = Console(file=string_io, force_terminal=True)
markdown = Markdown(processed_text, **self.mdargs)
num_lines = len(lines)
# How many lines have "left" the live window and are now considered stable?