Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 72 additions & 14 deletions src/skytable_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,57 +13,115 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
from typing import Optional
from asyncio import StreamReader, StreamWriter
from .query import Query
from .protocol import Protocol
from .response import Response



class Connection:
"""
A database connection to a Skytable instance
A database connection to a Skytable instance.
Handles sending queries and receiving responses using the Skytable protocol.
"""

def __init__(self, reader: StreamReader, writer: StreamWriter) -> None:
self._reader = reader
self._writer = writer
self._protocol = Protocol()
"""
Initialize a new Connection.
:param reader: Asyncio StreamReader for the connection.
:param writer: Asyncio StreamWriter for the connection.
"""
self._reader: StreamReader = reader
self._writer: StreamWriter = writer
self._protocol: Protocol = Protocol()
self._logger = logging.getLogger("skytable.connection")


async def _write_all(self, bytes: bytes):
self._write(bytes)
async def _write_all(self, data: bytes) -> None:
"""
Write all bytes to the connection and flush.
:param data: Bytes to write.
"""
self._write(data)
await self._flush()

def _write(self, bytes: bytes) -> None:
self._writer.write(bytes)

def _write(self, data: bytes) -> None:
"""
Write bytes to the connection buffer.
:param data: Bytes to write.
"""
self._writer.write(data)


def __buffer(self) -> bytes:
"""
Return the current buffer up to the cursor.
"""
return self.buffer[:self._cursor]

async def _flush(self):

async def _flush(self) -> None:
"""
Flush the write buffer.
"""
await self._writer.drain()

async def _read_exact(self, count) -> bytes:

async def _read_exact(self, count: int) -> bytes:
"""
Read exactly `count` bytes from the connection.
:param count: Number of bytes to read.
:return: Bytes read.
"""
return await self._reader.readexactly(count)

async def close(self):

async def close(self) -> None:
"""
Close this connection
Close this connection.
"""
self._writer.close()
await self._writer.wait_closed()


async def run_simple_query(self, query: Query) -> Response:
"""
Send a simple query to the Skytable server and return the response.
:param query: Query object to send.
:return: Response object from the server.
:raises RuntimeError: If no response is received from the server.
"""
query_window_str = str(query._q_window)
total_packet_size = len(query_window_str) + 1 + len(query._buffer)
# write metaframe
metaframe = f"S{str(total_packet_size)}\n{query_window_str}\n"
# Improved: robust SELECT detection (case-insensitive, ignores leading whitespace)
try:
query_str = query._buffer.decode(errors="ignore").lstrip().upper()
except Exception:
query_str = ""
is_select = query_str.startswith("SELECT")
if is_select:
self._logger.debug("Metaframe sent: %r", metaframe.encode())
self._logger.debug("Query buffer sent: %r", query._buffer)
await self._write_all(metaframe.encode())
# write dataframe
await self._write_all(query._buffer)
# read response
while True:
new_block = await self._reader.read(1024)
if not new_block:
if is_select:
self._logger.debug("No data received (connection closed or empty response)")
break
if is_select:
self._logger.debug("Received block: %r", new_block)
self._protocol.push_additional_bytes(new_block)
resp = self._protocol.parse()
if resp:
return resp
# If we exit the loop without returning, raise an error
raise RuntimeError("No response received from server.")
95 changes: 75 additions & 20 deletions src/skytable_py/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,38 +19,77 @@
ErrorCode, Row, Response



import logging
from typing import Optional

class Protocol:
def __init__(self, buffer=bytes()) -> None:
self._buffer = buffer
self._cursor = 0
"""
Implements the Skytable wire protocol for parsing server responses.
Handles parsing of all supported Skytable types and structures.
"""
def __init__(self, buffer: bytes = bytes()) -> None:
"""
Initialize a new Protocol parser.
:param buffer: Optional initial buffer to parse from.
"""
self._buffer: bytes = buffer
self._cursor: int = 0
self._logger = logging.getLogger("skytable.protocol")

def push_additional_bytes(self, additional_bytes: bytes) -> None:
"""
Add more bytes to the internal buffer for parsing.
:param additional_bytes: Bytes to append.
"""
self._buffer = self._buffer + additional_bytes

def __step(self) -> int:
"""
Advance the cursor by one and return the byte at the new position.
:return: The byte value at the cursor.
"""
ret = self.__buf()[0]
self.__increment_cursor()
return ret

def __decrement(self) -> None:
"""
Move the cursor back by one.
"""
self._cursor -= 1

def __increment_cursor_by(self, by: int) -> None:
"""
Move the cursor forward by `by` positions.
"""
self._cursor += by

def __increment_cursor(self) -> None:
"""
Move the cursor forward by one.
"""
self.__increment_cursor_by(1)

def __buf(self) -> bytes:
"""
Return the buffer from the current cursor position.
"""
return self._buffer[self._cursor:]

def __remaining(self) -> int:
"""
Return the number of bytes remaining in the buffer.
"""
return len(self.__buf())

def __is_eof(self) -> bool:
"""
Return True if the cursor is at the end of the buffer.
"""
return self.__remaining() == 0

def parse_next_int(self, stop_symbol='\n') -> Union[None, int]:
def parse_next_int(self, stop_symbol: str = '\n') -> Optional[int]:
i = 0
integer = 0
stop = False
Expand All @@ -75,23 +114,23 @@ def parse_next_int(self, stop_symbol='\n') -> Union[None, int]:
self.__increment_cursor() # for LF
return integer

def parse_next_string(self) -> Union[None, Value]:
def parse_next_string(self) -> Optional[Value]:
strlen = self.parse_next_int()
if strlen:
if strlen is not None: # Allow zero-length strings (empty strings)
if self.__remaining() >= strlen:
string = self.__buf()[:strlen].decode()
self.__increment_cursor_by(strlen)
return Value(string)

def parse_next_binary(self) -> Union[None, Value]:
def parse_next_binary(self) -> Optional[Value]:
binlen = self.parse_next_int()
if binlen:
if self.__remaining() >= binlen:
blob = self.__buf()[:binlen]
self.__increment_cursor_by(binlen)
return Value(blob)

def parse_boolean(self) -> Union[None, Value]:
def parse_boolean(self) -> Optional[Value]:
# boolean
if self.__is_eof():
self.__decrement() # move back to type symbol
Expand All @@ -102,7 +141,7 @@ def parse_boolean(self) -> Union[None, Value]:
raise ProtocolException("received invalid data")
return Value(True) if byte == 1 else Value(False)

def parse_uint(self, type_symbol: int) -> Union[None, Value]:
def parse_uint(self, type_symbol: int) -> Optional[Value]:
# uint
integer = self.parse_next_int()
if integer:
Expand All @@ -117,7 +156,7 @@ def parse_uint(self, type_symbol: int) -> Union[None, Value]:
else:
self.__decrement() # move back to type symbol

def parse_sint(self, type_symbol: int) -> Union[None, Value]:
def parse_sint(self, type_symbol: int) -> Optional[Value]:
# sint
if self.__is_eof():
self.__decrement() # move back to type symbol
Expand All @@ -144,7 +183,7 @@ def parse_sint(self, type_symbol: int) -> Union[None, Value]:
if is_negative:
self.__decrement() # move back to starting position of this integer

def parse_float(self, type_symbol: int) -> Union[None, Value]:
def parse_float(self, type_symbol: int) -> Optional[Value]:
if self.__is_eof():
self.__decrement() # move back to type symbol
return None
Expand All @@ -168,15 +207,15 @@ def parse_float(self, type_symbol: int) -> Union[None, Value]:
if is_negative:
self.__decrement()

def parse_error_code(self) -> Union[None, ErrorCode]:
def parse_error_code(self) -> Optional[ErrorCode]:
if self.__remaining() < 2:
self.__decrement() # type symbol
else:
a, b = self.__buf()
self.__increment_cursor_by(2)
return ErrorCode(int.from_bytes([a, b], byteorder="little", signed=False))

def parse_list(self) -> Union[None, Value]:
def parse_list(self) -> Optional[Value]:
cursor_start = self._cursor - 1
list_len = self.parse_next_int()
if list_len is None:
Expand All @@ -192,23 +231,39 @@ def parse_list(self) -> Union[None, Value]:
return None
return Value(items)

def parse_row(self) -> Union[None, Row]:
def parse_row(self) -> Optional[Row]:
"""
Parse a single row from the buffer, handling Skytable 0.8.3+ row format.
:return: Row object or None if parsing fails.
"""
cursor_start = self._cursor - 1
column_count = self.parse_next_int()
self._logger.debug("[parse_row] Column count: %r", column_count)
if column_count is None:
self.__decrement() # type symbol
self._logger.debug("[parse_row] Column count is None, returning None")
return None
columns = []
while len(columns) != column_count:
column = self.parse_next_element()
if column:
columns.append(column)
# Skytable 0.8.3+ adds \r separator before each column
if not self.__is_eof() and self.__buf()[0] == ord('\r'):
self._logger.debug("[parse_row] Skipping \\r separator")
self.__increment_cursor() # Skip the \r separator

# In Skytable 0.8.3+, row columns are strings without type symbols
# Format: length\ndata (no type byte)
column_value = self.parse_next_string()
self._logger.debug("[parse_row] Parsed column %d/%d: %r", len(columns) + 1, column_count, column_value)
if column_value:
columns.append(column_value)
else:
self._logger.debug("[parse_row] Column value is None at column %d, returning None", len(columns) + 1)
self._cursor = cursor_start
return None
self._logger.debug("[parse_row] Successfully parsed %d columns, returning Row", len(columns))
return Row(columns)

def parse_rows(self) -> Union[None, List[Row]]:
def parse_rows(self) -> Optional[List[Row]]:
cursor_start = self._cursor - 1
row_count = self.parse_next_int()
rows = []
Expand All @@ -221,12 +276,12 @@ def parse_rows(self) -> Union[None, List[Row]]:
return None
return rows

def parse(self) -> Response:
def parse(self) -> Optional[Response]:
e = self.parse_next_element()
if e:
return Response(e)

def parse_next_element(self) -> Union[None, Value, Empty, ErrorCode]:
def parse_next_element(self) -> Optional[Union[Value, Empty, ErrorCode]]:
if self.__is_eof():
return None
type_symbol = self.__step()
Expand Down