diff --git a/src/skytable_py/connection.py b/src/skytable_py/connection.py index 839e03c..c00f059 100644 --- a/src/skytable_py/connection.py +++ b/src/skytable_py/connection.py @@ -13,57 +13,115 @@ # See the License for the specific language governing permissions and # limitations under the License. + +import logging +from typing import Optional from asyncio import StreamReader, StreamWriter from .query import Query from .protocol import Protocol from .response import Response + class Connection: """ - A database connection to a Skytable instance + A database connection to a Skytable instance. + Handles sending queries and receiving responses using the Skytable protocol. """ def __init__(self, reader: StreamReader, writer: StreamWriter) -> None: - self._reader = reader - self._writer = writer - self._protocol = Protocol() + """ + Initialize a new Connection. + :param reader: Asyncio StreamReader for the connection. + :param writer: Asyncio StreamWriter for the connection. + """ + self._reader: StreamReader = reader + self._writer: StreamWriter = writer + self._protocol: Protocol = Protocol() + self._logger = logging.getLogger("skytable.connection") + - async def _write_all(self, bytes: bytes): - self._write(bytes) + async def _write_all(self, data: bytes) -> None: + """ + Write all bytes to the connection and flush. + :param data: Bytes to write. + """ + self._write(data) await self._flush() - def _write(self, bytes: bytes) -> None: - self._writer.write(bytes) + + def _write(self, data: bytes) -> None: + """ + Write bytes to the connection buffer. + :param data: Bytes to write. + """ + self._writer.write(data) + def __buffer(self) -> bytes: + """ + Return the current buffer up to the cursor. + """ return self.buffer[:self._cursor] - async def _flush(self): + + async def _flush(self) -> None: + """ + Flush the write buffer. + """ await self._writer.drain() - async def _read_exact(self, count) -> bytes: + + async def _read_exact(self, count: int) -> bytes: + """ + Read exactly `count` bytes from the connection. + :param count: Number of bytes to read. + :return: Bytes read. + """ return await self._reader.readexactly(count) - async def close(self): + + async def close(self) -> None: """ - Close this connection + Close this connection. """ self._writer.close() await self._writer.wait_closed() + async def run_simple_query(self, query: Query) -> Response: + """ + Send a simple query to the Skytable server and return the response. + :param query: Query object to send. + :return: Response object from the server. + :raises RuntimeError: If no response is received from the server. + """ query_window_str = str(query._q_window) total_packet_size = len(query_window_str) + 1 + len(query._buffer) - # write metaframe metaframe = f"S{str(total_packet_size)}\n{query_window_str}\n" + # Improved: robust SELECT detection (case-insensitive, ignores leading whitespace) + try: + query_str = query._buffer.decode(errors="ignore").lstrip().upper() + except Exception: + query_str = "" + is_select = query_str.startswith("SELECT") + if is_select: + self._logger.debug("Metaframe sent: %r", metaframe.encode()) + self._logger.debug("Query buffer sent: %r", query._buffer) await self._write_all(metaframe.encode()) - # write dataframe await self._write_all(query._buffer) # read response while True: new_block = await self._reader.read(1024) + if not new_block: + if is_select: + self._logger.debug("No data received (connection closed or empty response)") + break + if is_select: + self._logger.debug("Received block: %r", new_block) self._protocol.push_additional_bytes(new_block) resp = self._protocol.parse() if resp: return resp + # If we exit the loop without returning, raise an error + raise RuntimeError("No response received from server.") diff --git a/src/skytable_py/protocol.py b/src/skytable_py/protocol.py index 596a5fc..a410a9f 100644 --- a/src/skytable_py/protocol.py +++ b/src/skytable_py/protocol.py @@ -19,38 +19,77 @@ ErrorCode, Row, Response + +import logging +from typing import Optional + class Protocol: - def __init__(self, buffer=bytes()) -> None: - self._buffer = buffer - self._cursor = 0 + """ + Implements the Skytable wire protocol for parsing server responses. + Handles parsing of all supported Skytable types and structures. + """ + def __init__(self, buffer: bytes = bytes()) -> None: + """ + Initialize a new Protocol parser. + :param buffer: Optional initial buffer to parse from. + """ + self._buffer: bytes = buffer + self._cursor: int = 0 + self._logger = logging.getLogger("skytable.protocol") def push_additional_bytes(self, additional_bytes: bytes) -> None: + """ + Add more bytes to the internal buffer for parsing. + :param additional_bytes: Bytes to append. + """ self._buffer = self._buffer + additional_bytes def __step(self) -> int: + """ + Advance the cursor by one and return the byte at the new position. + :return: The byte value at the cursor. + """ ret = self.__buf()[0] self.__increment_cursor() return ret def __decrement(self) -> None: + """ + Move the cursor back by one. + """ self._cursor -= 1 def __increment_cursor_by(self, by: int) -> None: + """ + Move the cursor forward by `by` positions. + """ self._cursor += by def __increment_cursor(self) -> None: + """ + Move the cursor forward by one. + """ self.__increment_cursor_by(1) def __buf(self) -> bytes: + """ + Return the buffer from the current cursor position. + """ return self._buffer[self._cursor:] def __remaining(self) -> int: + """ + Return the number of bytes remaining in the buffer. + """ return len(self.__buf()) def __is_eof(self) -> bool: + """ + Return True if the cursor is at the end of the buffer. + """ return self.__remaining() == 0 - def parse_next_int(self, stop_symbol='\n') -> Union[None, int]: + def parse_next_int(self, stop_symbol: str = '\n') -> Optional[int]: i = 0 integer = 0 stop = False @@ -75,15 +114,15 @@ def parse_next_int(self, stop_symbol='\n') -> Union[None, int]: self.__increment_cursor() # for LF return integer - def parse_next_string(self) -> Union[None, Value]: + def parse_next_string(self) -> Optional[Value]: strlen = self.parse_next_int() - if strlen: + if strlen is not None: # Allow zero-length strings (empty strings) if self.__remaining() >= strlen: string = self.__buf()[:strlen].decode() self.__increment_cursor_by(strlen) return Value(string) - def parse_next_binary(self) -> Union[None, Value]: + def parse_next_binary(self) -> Optional[Value]: binlen = self.parse_next_int() if binlen: if self.__remaining() >= binlen: @@ -91,7 +130,7 @@ def parse_next_binary(self) -> Union[None, Value]: self.__increment_cursor_by(binlen) return Value(blob) - def parse_boolean(self) -> Union[None, Value]: + def parse_boolean(self) -> Optional[Value]: # boolean if self.__is_eof(): self.__decrement() # move back to type symbol @@ -102,7 +141,7 @@ def parse_boolean(self) -> Union[None, Value]: raise ProtocolException("received invalid data") return Value(True) if byte == 1 else Value(False) - def parse_uint(self, type_symbol: int) -> Union[None, Value]: + def parse_uint(self, type_symbol: int) -> Optional[Value]: # uint integer = self.parse_next_int() if integer: @@ -117,7 +156,7 @@ def parse_uint(self, type_symbol: int) -> Union[None, Value]: else: self.__decrement() # move back to type symbol - def parse_sint(self, type_symbol: int) -> Union[None, Value]: + def parse_sint(self, type_symbol: int) -> Optional[Value]: # sint if self.__is_eof(): self.__decrement() # move back to type symbol @@ -144,7 +183,7 @@ def parse_sint(self, type_symbol: int) -> Union[None, Value]: if is_negative: self.__decrement() # move back to starting position of this integer - def parse_float(self, type_symbol: int) -> Union[None, Value]: + def parse_float(self, type_symbol: int) -> Optional[Value]: if self.__is_eof(): self.__decrement() # move back to type symbol return None @@ -168,7 +207,7 @@ def parse_float(self, type_symbol: int) -> Union[None, Value]: if is_negative: self.__decrement() - def parse_error_code(self) -> Union[None, ErrorCode]: + def parse_error_code(self) -> Optional[ErrorCode]: if self.__remaining() < 2: self.__decrement() # type symbol else: @@ -176,7 +215,7 @@ def parse_error_code(self) -> Union[None, ErrorCode]: self.__increment_cursor_by(2) return ErrorCode(int.from_bytes([a, b], byteorder="little", signed=False)) - def parse_list(self) -> Union[None, Value]: + def parse_list(self) -> Optional[Value]: cursor_start = self._cursor - 1 list_len = self.parse_next_int() if list_len is None: @@ -192,23 +231,39 @@ def parse_list(self) -> Union[None, Value]: return None return Value(items) - def parse_row(self) -> Union[None, Row]: + def parse_row(self) -> Optional[Row]: + """ + Parse a single row from the buffer, handling Skytable 0.8.3+ row format. + :return: Row object or None if parsing fails. + """ cursor_start = self._cursor - 1 column_count = self.parse_next_int() + self._logger.debug("[parse_row] Column count: %r", column_count) if column_count is None: self.__decrement() # type symbol + self._logger.debug("[parse_row] Column count is None, returning None") return None columns = [] while len(columns) != column_count: - column = self.parse_next_element() - if column: - columns.append(column) + # Skytable 0.8.3+ adds \r separator before each column + if not self.__is_eof() and self.__buf()[0] == ord('\r'): + self._logger.debug("[parse_row] Skipping \\r separator") + self.__increment_cursor() # Skip the \r separator + + # In Skytable 0.8.3+, row columns are strings without type symbols + # Format: length\ndata (no type byte) + column_value = self.parse_next_string() + self._logger.debug("[parse_row] Parsed column %d/%d: %r", len(columns) + 1, column_count, column_value) + if column_value: + columns.append(column_value) else: + self._logger.debug("[parse_row] Column value is None at column %d, returning None", len(columns) + 1) self._cursor = cursor_start return None + self._logger.debug("[parse_row] Successfully parsed %d columns, returning Row", len(columns)) return Row(columns) - def parse_rows(self) -> Union[None, List[Row]]: + def parse_rows(self) -> Optional[List[Row]]: cursor_start = self._cursor - 1 row_count = self.parse_next_int() rows = [] @@ -221,12 +276,12 @@ def parse_rows(self) -> Union[None, List[Row]]: return None return rows - def parse(self) -> Response: + def parse(self) -> Optional[Response]: e = self.parse_next_element() if e: return Response(e) - def parse_next_element(self) -> Union[None, Value, Empty, ErrorCode]: + def parse_next_element(self) -> Optional[Union[Value, Empty, ErrorCode]]: if self.__is_eof(): return None type_symbol = self.__step()