diff --git a/llmvm/client/printing.py b/llmvm/client/printing.py index 593bc6e..f3cc061 100644 --- a/llmvm/client/printing.py +++ b/llmvm/client/printing.py @@ -15,6 +15,7 @@ from rich.console import Console from rich.markdown import Markdown +from rich.syntax import Syntax from rich.theme import Theme from llmvm.common.container import Container @@ -112,8 +113,17 @@ def __init__(self, file=sys.stderr): self.buffer = '' self.console = Console(file=file) self.markdown_mode = False - self.token_color = Container.get_config_variable('client_stream_token_color', default='bright_black') - self.thinking_token_color = Container.get_config_variable('client_stream_thinking_token_color', default='cyan') + self.token_color = Container.get_config_variable( + 'client_stream_token_color', default='bright_black' + ) + self.thinking_token_color = Container.get_config_variable( + 'client_stream_thinking_token_color', default='cyan' + ) + + # state for line-by-line rich rendering + self.in_code_block = False + self.code_lang = '' + self.code_lines: list[str] = [] async def display_image(self, image_bytes): if len(image_bytes) < 10: @@ -186,6 +196,14 @@ async def display_image(self, image_bytes): except Exception as e: return + def _erase_lines(self, count: int): + for _ in range(count): + self.console.file.write("\x1b[1A\r\x1b[2K") + self.console.file.flush() + + def _erase_last_line(self): + self._erase_lines(1) + async def write(self, node: AstNode): if logging.level <= 20: # INFO token_color = self.token_color @@ -211,6 +229,82 @@ async def write(self, node: AstNode): self.buffer += string self.console.print(string, end='', style=f"{token_color}", highlight=False) + # if stop tokens flush remaining buffer + if isinstance(node, TokenStopNode) or isinstance(node, StreamingStopNode): + await self._flush_buffer() + return + + await self._process_buffer() + + async def _process_buffer(self): + while True: + if self.in_code_block: + if '```' in self.buffer: + before, self.buffer = self.buffer.split('```', 1) + self.code_lines.append(before) + await self._flush_code_block() + self.in_code_block = False + elif '\n' in self.buffer: + line, self.buffer = self.buffer.split('\n', 1) + self.code_lines.append(line) + else: + break + else: + if '```' in self.buffer: + before, self.buffer = self.buffer.split('```', 1) + while '\n' in before: + line, before = before.split('\n', 1) + await self._flush_markdown(line + '\n') + if before: + # remaining text before the code block without newline + self.buffer = before + self.buffer + if '\n' in self.buffer: + lang_line, rest = self.buffer.split('\n', 1) + self.code_lang = lang_line.strip() + self.buffer = rest + else: + self.code_lang = self.buffer.strip() + self.buffer = '' + self.in_code_block = True + elif '\n' in self.buffer: + line, self.buffer = self.buffer.split('\n', 1) + await self._flush_markdown(line + '\n') + else: + break + + async def _flush_markdown(self, text: str): + if text: + self._erase_last_line() + self.console.print(Markdown(text.rstrip('\n'))) + + async def _flush_code_block(self): + code = '\n'.join(self.code_lines) + lines = len(self.code_lines) + 1 + self.code_lines = [] + lang = self.code_lang if self.code_lang else 'text' + self.code_lang = '' + if code: + self._erase_lines(lines) + syntax = Syntax( + code, + lang, + theme="monokai", + background_color="default", + word_wrap=True, + padding=0, + ) + self.console.print(syntax) + + async def _flush_buffer(self): + if self.in_code_block: + self.code_lines.append(self.buffer) + await self._flush_code_block() + self.in_code_block = False + else: + await self._flush_markdown(self.buffer) + self.buffer = '' + + class ConsolePrinter: def __init__(self, file=sys.stdout):