diff --git a/llmvm/client/printing.py b/llmvm/client/printing.py index 593bc6e..04f6419 100644 --- a/llmvm/client/printing.py +++ b/llmvm/client/printing.py @@ -15,6 +15,7 @@ from rich.console import Console from rich.markdown import Markdown +from rich.syntax import Syntax from rich.theme import Theme from llmvm.common.container import Container @@ -115,6 +116,12 @@ def __init__(self, file=sys.stderr): self.token_color = Container.get_config_variable('client_stream_token_color', default='bright_black') self.thinking_token_color = Container.get_config_variable('client_stream_thinking_token_color', default='cyan') + # state for rich rendering while streaming + self.in_code_block = False + self.code_lang = '' + self.code_lines: list[str] = [] + self.paragraph_lines: list[str] = [] + async def display_image(self, image_bytes): if len(image_bytes) < 10: return @@ -211,6 +218,74 @@ async def write(self, node: AstNode): self.buffer += string self.console.print(string, end='', style=f"{token_color}", highlight=False) + # if stop tokens flush remaining buffer + if isinstance(node, TokenStopNode) or isinstance(node, StreamingStopNode): + await self._flush_buffer() + return + + await self._process_buffer() + + async def _process_buffer(self): + while True: + if self.in_code_block: + if '```' in self.buffer: + before, self.buffer = self.buffer.split('```', 1) + self.code_lines.append(before) + await self._flush_code_block() + self.in_code_block = False + else: + self.code_lines.append(self.buffer) + self.buffer = '' + break + else: + if '```' in self.buffer: + before, self.buffer = self.buffer.split('```', 1) + self.paragraph_lines.append(before) + await self._flush_paragraph() + if '\n' in self.buffer: + lang_line, rest = self.buffer.split('\n', 1) + self.code_lang = lang_line.strip() + self.buffer = rest + else: + self.code_lang = self.buffer.strip() + self.buffer = '' + self.in_code_block = True + elif '\n\n' in self.buffer: + para, self.buffer = self.buffer.split('\n\n', 1) + self.paragraph_lines.append(para) + await self._flush_paragraph() + else: + break + + async def _flush_markdown(self, text: str): + if text: + self.console.print(Markdown(text), end='') + + async def _flush_paragraph(self): + if self.paragraph_lines: + await self._flush_markdown('\n'.join(self.paragraph_lines) + '\n') + self.paragraph_lines = [] + + async def _flush_code_block(self): + code = '\n'.join(self.code_lines) + self.code_lines = [] + lang = self.code_lang if self.code_lang else 'text' + self.code_lang = '' + if code: + syntax = Syntax(code, lang, theme="monokai", background_color="default", word_wrap=True, padding=0) + self.console.print(syntax) + + async def _flush_buffer(self): + if self.in_code_block: + self.code_lines.append(self.buffer) + await self._flush_code_block() + self.in_code_block = False + else: + self.paragraph_lines.append(self.buffer) + await self._flush_paragraph() + self.buffer = '' + + class ConsolePrinter: def __init__(self, file=sys.stdout):