diff --git a/tests/coordination/test_coordination_client.py b/tests/coordination/test_coordination_client.py index 98fb6768..0525c590 100644 --- a/tests/coordination/test_coordination_client.py +++ b/tests/coordination/test_coordination_client.py @@ -1,13 +1,17 @@ +import asyncio + import pytest import ydb -from ydb import aio +from ydb import aio, StatusCode from ydb.coordination import ( NodeConfig, ConsistencyMode, RateLimiterCountersMode, CoordinationClient, + CreateSemaphoreResult, + DescribeLockResult, ) @@ -93,3 +97,96 @@ async def test_coordination_node_lifecycle_async(self, aio_connection): with pytest.raises(ydb.SchemeError): await client.describe_node(node_path) + + async def test_coordination_lock_full_lifecycle(self, aio_connection): + client = aio.CoordinationClient(aio_connection) + + node_path = "/local/test_lock_full_lifecycle" + + try: + await client.delete_node(node_path) + except ydb.SchemeError: + pass + + await client.create_node( + node_path, + NodeConfig( + session_grace_period_millis=1000, + attach_consistency_mode=ConsistencyMode.STRICT, + read_consistency_mode=ConsistencyMode.STRICT, + rate_limiter_counters_mode=RateLimiterCountersMode.UNSET, + self_check_period_millis=0, + ), + ) + + lock = client.lock("test_lock", node_path) + + create_resp: CreateSemaphoreResult = await lock.create(init_limit=1, init_data=b"init-data") + assert create_resp.status == StatusCode.SUCCESS + + describe_resp: DescribeLockResult = await lock.describe() + assert describe_resp.status == StatusCode.SUCCESS + assert describe_resp.name == "test_lock" + assert describe_resp.data == b"init-data" + assert describe_resp.count == 0 + assert describe_resp.ephemeral is False + assert list(describe_resp.owners) == [] + assert list(describe_resp.waiters) == [] + + update_resp = await lock.update(new_data=b"updated-data") + assert update_resp.status == StatusCode.SUCCESS + + describe_resp2: DescribeLockResult = await lock.describe() + assert describe_resp2.status == StatusCode.SUCCESS + assert describe_resp2.name == "test_lock" + assert describe_resp2.data == b"updated-data" + assert describe_resp2.count == 0 + assert describe_resp2.ephemeral is False + assert list(describe_resp2.owners) == [] + assert list(describe_resp2.waiters) == [] + + lock2_started = asyncio.Event() + lock2_acquired = asyncio.Event() + + async def second_lock_task(): + lock2_started.set() + async with client.lock("test_lock", node_path): + lock2_acquired.set() + await asyncio.sleep(0.5) + + async with client.lock("test_lock", node_path) as lock1: + assert lock1._stream is not None + assert lock1._stream.session_id is not None + + resp: DescribeLockResult = await lock1.describe() + assert resp.status == StatusCode.SUCCESS + assert resp.name == "test_lock" + assert resp.data == b"updated-data" + assert resp.count == 1 + assert resp.ephemeral is False + assert len(list(resp.owners)) == 1 + assert list(resp.waiters) == [] + + t2 = asyncio.create_task(second_lock_task()) + await lock2_started.wait() + + await asyncio.sleep(0.5) + + assert lock1._stream is not None + + await asyncio.wait_for(lock2_acquired.wait(), timeout=5) + await asyncio.wait_for(t2, timeout=5) + + async with client.lock("test_lock", node_path) as lock3: + assert lock3._stream is not None + assert lock3._stream.session_id is not None + + resp3: DescribeLockResult = await lock3.describe() + assert resp3.status == StatusCode.SUCCESS + assert resp3.count == 1 + + delete_resp = await lock.delete() + assert delete_resp.status == StatusCode.SUCCESS + + describe_after_delete: DescribeLockResult = await lock.describe() + assert describe_after_delete.status == StatusCode.NOT_FOUND diff --git a/ydb/_apis.py b/ydb/_apis.py index 97f64b90..595550b2 100644 --- a/ydb/_apis.py +++ b/ydb/_apis.py @@ -143,9 +143,9 @@ class QueryService(object): class CoordinationService(object): Stub = ydb_coordination_v1_pb2_grpc.CoordinationServiceStub - - Session = "Session" CreateNode = "CreateNode" AlterNode = "AlterNode" DropNode = "DropNode" DescribeNode = "DescribeNode" + SessionRequest = "SessionRequest" + Session = "Session" diff --git a/ydb/_grpc/grpcwrapper/ydb_coordination.py b/ydb/_grpc/grpcwrapper/ydb_coordination.py index 176e4e02..e5d6808e 100644 --- a/ydb/_grpc/grpcwrapper/ydb_coordination.py +++ b/ydb/_grpc/grpcwrapper/ydb_coordination.py @@ -1,7 +1,6 @@ import typing from dataclasses import dataclass -from .ydb_coordination_public_types import NodeConfig if typing.TYPE_CHECKING: from ..v4.protos import ydb_coordination_pb2 @@ -14,7 +13,7 @@ @dataclass class CreateNodeRequest(IToProto): path: str - config: typing.Optional[NodeConfig] + config: typing.Any def to_proto(self) -> ydb_coordination_pb2.CreateNodeRequest: cfg_proto = self.config.to_proto() if self.config else None @@ -27,7 +26,7 @@ def to_proto(self) -> ydb_coordination_pb2.CreateNodeRequest: @dataclass class AlterNodeRequest(IToProto): path: str - config: NodeConfig + config: typing.Any def to_proto(self) -> ydb_coordination_pb2.AlterNodeRequest: cfg_proto = self.config.to_proto() if self.config else None @@ -55,3 +54,184 @@ def to_proto(self) -> ydb_coordination_pb2.DropNodeRequest: return ydb_coordination_pb2.DropNodeRequest( path=self.path, ) + + +@dataclass +class SessionStart(IToProto): + path: str + timeout_millis: int + description: str = "" + session_id: int = 0 + seq_no: int = 0 + protection_key: bytes = b"" + + def to_proto(self) -> ydb_coordination_pb2.SessionRequest: + return ydb_coordination_pb2.SessionRequest( + session_start=ydb_coordination_pb2.SessionRequest.SessionStart( + path=self.path, + session_id=self.session_id, + timeout_millis=self.timeout_millis, + description=self.description, + seq_no=self.seq_no, + protection_key=self.protection_key, + ) + ) + + +@dataclass +class SessionStop(IToProto): + def to_proto(self) -> ydb_coordination_pb2.SessionRequest: + return ydb_coordination_pb2.SessionRequest(session_stop=ydb_coordination_pb2.SessionRequest.SessionStop()) + + +@dataclass +class Ping(IToProto): + opaque: int = 0 + + def to_proto(self) -> ydb_coordination_pb2.SessionRequest: + return ydb_coordination_pb2.SessionRequest( + ping=ydb_coordination_pb2.SessionRequest.PingPong(opaque=self.opaque) + ) + + +@dataclass +class CreateSemaphore(IToProto): + name: str + req_id: int + limit: int + data: bytes = b"" + + def to_proto(self) -> ydb_coordination_pb2.SessionRequest: + return ydb_coordination_pb2.SessionRequest( + create_semaphore=ydb_coordination_pb2.SessionRequest.CreateSemaphore( + req_id=self.req_id, name=self.name, limit=self.limit, data=self.data + ) + ) + + +@dataclass +class UpdateSemaphore(IToProto): + name: str + req_id: int + data: bytes + + def to_proto(self) -> ydb_coordination_pb2.SessionRequest: + return ydb_coordination_pb2.SessionRequest( + update_semaphore=ydb_coordination_pb2.SessionRequest.UpdateSemaphore( + req_id=self.req_id, name=self.name, data=self.data + ) + ) + + +@dataclass +class DeleteSemaphore(IToProto): + name: str + req_id: int + force: bool = False + + def to_proto(self) -> ydb_coordination_pb2.SessionRequest: + return ydb_coordination_pb2.SessionRequest( + delete_semaphore=ydb_coordination_pb2.SessionRequest.DeleteSemaphore( + req_id=self.req_id, name=self.name, force=self.force + ) + ) + + +@dataclass +class AcquireSemaphore(IToProto): + name: str + req_id: int + count: int = 1 + timeout_millis: int = 0 + data: bytes = b"" + ephemeral: bool = False + + def to_proto(self) -> ydb_coordination_pb2.SessionRequest: + return ydb_coordination_pb2.SessionRequest( + acquire_semaphore=ydb_coordination_pb2.SessionRequest.AcquireSemaphore( + req_id=self.req_id, + name=self.name, + timeout_millis=self.timeout_millis, + count=self.count, + data=self.data, + ephemeral=self.ephemeral, + ) + ) + + +@dataclass +class ReleaseSemaphore(IToProto): + name: str + req_id: int + + def to_proto(self) -> ydb_coordination_pb2.SessionRequest: + return ydb_coordination_pb2.SessionRequest( + release_semaphore=ydb_coordination_pb2.SessionRequest.ReleaseSemaphore(req_id=self.req_id, name=self.name) + ) + + +@dataclass +class DescribeSemaphore(IToProto): + include_owners: bool + include_waiters: bool + name: str + req_id: int + watch_data: bool + watch_owners: bool + + def to_proto(self) -> ydb_coordination_pb2.SessionRequest: + return ydb_coordination_pb2.SessionRequest( + describe_semaphore=ydb_coordination_pb2.SessionRequest.DescribeSemaphore( + include_owners=self.include_owners, + include_waiters=self.include_waiters, + name=self.name, + req_id=self.req_id, + watch_data=self.watch_data, + watch_owners=self.watch_owners, + ) + ) + + +@dataclass +class FromServer: + raw: ydb_coordination_pb2.SessionResponse + + @staticmethod + def from_proto(resp: ydb_coordination_pb2.SessionResponse) -> "FromServer": + return FromServer(raw=resp) + + def __getattr__(self, name: str): + return getattr(self.raw, name) + + @property + def session_started(self) -> typing.Optional[ydb_coordination_pb2.SessionResponse.SessionStarted]: + s = self.raw.session_started + return s if s.session_id else None + + @property + def opaque(self) -> typing.Optional[int]: + if self.raw.HasField("ping"): + return self.raw.ping.opaque + return None + + @property + def acquire_semaphore_result(self) -> typing.Optional[ydb_coordination_pb2.SessionResponse.AcquireSemaphoreResult]: + return self.raw.acquire_semaphore_result if self.raw.HasField("acquire_semaphore_result") else None + + @property + def create_semaphore_result(self) -> typing.Optional[ydb_coordination_pb2.SessionResponse.CreateSemaphoreResult]: + return self.raw.create_semaphore_result if self.raw.HasField("create_semaphore_result") else None + + @property + def delete_semaphore_result(self) -> typing.Optional[ydb_coordination_pb2.SessionResponse.DeleteSemaphoreResult]: + return self.raw.delete_semaphore_result if self.raw.HasField("delete_semaphore_result") else None + + @property + def update_semaphore_result(self) -> typing.Optional[ydb_coordination_pb2.SessionResponse.UpdateSemaphoreResult]: + return self.raw.update_semaphore_result if self.raw.HasField("update_semaphore_result") else None + + @property + def describe_semaphore_result( + self, + ) -> typing.Optional[ydb_coordination_pb2.SessionResponse.DescribeSemaphoreResult]: + return self.raw.describe_semaphore_result if self.raw.HasField("describe_semaphore_result") else None diff --git a/ydb/_grpc/grpcwrapper/ydb_coordination_public_types.py b/ydb/_grpc/grpcwrapper/ydb_coordination_public_types.py index a3580974..1112cd4b 100644 --- a/ydb/_grpc/grpcwrapper/ydb_coordination_public_types.py +++ b/ydb/_grpc/grpcwrapper/ydb_coordination_public_types.py @@ -2,7 +2,6 @@ from enum import IntEnum import typing - if typing.TYPE_CHECKING: from ..v4.protos import ydb_coordination_pb2 else: @@ -55,3 +54,60 @@ def from_proto(msg: ydb_coordination_pb2.DescribeNodeResponse) -> "NodeConfig": result = ydb_coordination_pb2.DescribeNodeResult() msg.operation.result.Unpack(result) return NodeConfig.from_proto(result.config) + + +@dataclass +class AcquireSemaphoreResult: + req_id: int + acquired: bool + status: int + + @staticmethod + def from_proto(msg: ydb_coordination_pb2.SessionResponse.AcquireSemaphoreResult) -> "AcquireSemaphoreResult": + return AcquireSemaphoreResult( + req_id=msg.req_id, + acquired=msg.acquired, + status=msg.status, + ) + + +@dataclass +class CreateSemaphoreResult: + req_id: int + status: int + + @staticmethod + def from_proto(msg: ydb_coordination_pb2.SessionResponse.CreateSemaphoreResult) -> "CreateSemaphoreResult": + return CreateSemaphoreResult( + req_id=msg.req_id, + status=msg.status, + ) + + +@dataclass +class DescribeLockResult: + req_id: int + status: int + watch_added: bool + count: int + data: bytes + ephemeral: bool + limit: int + name: str + owners: list + waiters: list + + @staticmethod + def from_proto(msg: ydb_coordination_pb2.SessionResponse.DescribeSemaphoreResult) -> "DescribeLockResult": + return DescribeLockResult( + req_id=msg.req_id, + status=msg.status, + watch_added=msg.watch_added, + count=msg.semaphore_description.count, + data=msg.semaphore_description.data, + ephemeral=msg.semaphore_description.ephemeral, + limit=msg.semaphore_description.limit, + name=msg.semaphore_description.name, + owners=msg.semaphore_description.owners, + waiters=msg.semaphore_description.waiters, + ) diff --git a/ydb/aio/__init__.py b/ydb/aio/__init__.py index d38d9e73..9747666f 100644 --- a/ydb/aio/__init__.py +++ b/ydb/aio/__init__.py @@ -1,4 +1,4 @@ from .driver import Driver # noqa from .table import SessionPool, retry_operation # noqa from .query import QuerySessionPool, QuerySession, QueryTxContext # noqa -from .coordination_client import CoordinationClient # noqa +from .coordination import CoordinationClient # noqa diff --git a/ydb/aio/coordination/__init__.py b/ydb/aio/coordination/__init__.py new file mode 100644 index 00000000..f6d48237 --- /dev/null +++ b/ydb/aio/coordination/__init__.py @@ -0,0 +1,5 @@ +__all__ = [ + "CoordinationClient", +] + +from .client import CoordinationClient diff --git a/ydb/aio/coordination/client.py b/ydb/aio/coordination/client.py new file mode 100644 index 00000000..90d86df4 --- /dev/null +++ b/ydb/aio/coordination/client.py @@ -0,0 +1,45 @@ +from typing import Optional + +from ydb._grpc.grpcwrapper.ydb_coordination import ( + CreateNodeRequest, + DescribeNodeRequest, + AlterNodeRequest, + DropNodeRequest, +) +from ydb._grpc.grpcwrapper.ydb_coordination_public_types import NodeConfig +from ydb.coordination.base_coordination_client import BaseCoordinationClient + +from ydb.aio.coordination.lock import CoordinationLock + + +class CoordinationClient(BaseCoordinationClient): + def __init__(self, driver): + super().__init__(driver) + self._driver = driver + + async def create_node(self, path: str, config: Optional[NodeConfig] = None, settings=None): + return await self._call_create( + CreateNodeRequest(path=path, config=config).to_proto(), + settings=settings, + ) + + async def describe_node(self, path: str, settings=None) -> NodeConfig: + return await self._call_describe( + DescribeNodeRequest(path=path).to_proto(), + settings=settings, + ) + + async def alter_node(self, path: str, new_config: NodeConfig, settings=None): + return await self._call_alter( + AlterNodeRequest(path=path, config=new_config).to_proto(), + settings=settings, + ) + + async def delete_node(self, path: str, settings=None): + return await self._call_delete( + DropNodeRequest(path=path).to_proto(), + settings=settings, + ) + + def lock(self, lock_name: str, node_path: str): + return CoordinationLock(self, lock_name, node_path=node_path) diff --git a/ydb/aio/coordination/lock.py b/ydb/aio/coordination/lock.py new file mode 100644 index 00000000..ec1f6998 --- /dev/null +++ b/ydb/aio/coordination/lock.py @@ -0,0 +1,207 @@ +import asyncio +from typing import Optional + +from ydb import issues +from ydb._grpc.grpcwrapper.ydb_coordination import ( + AcquireSemaphore, + ReleaseSemaphore, + UpdateSemaphore, + DescribeSemaphore, + CreateSemaphore, + DeleteSemaphore, + FromServer, +) +from ydb._grpc.grpcwrapper.ydb_coordination_public_types import CreateSemaphoreResult, DescribeLockResult +from ydb.aio.coordination.stream import CoordinationStream +from ydb.aio.coordination.reconnector import CoordinationReconnector + + +class CoordinationLock: + def __init__( + self, + client, + name: str, + node_path: Optional[str] = None, + count: int = 1, + timeout_millis: int = 30000, + ): + self._client = client + self._driver = client._driver + self._name = name + self._node_path = node_path + + self._req_id: Optional[int] = None + self._count: int = count + self._timeout_millis: int = timeout_millis + self._next_req_id: int = 1 + + self._request_queue: asyncio.Queue = asyncio.Queue() + self._stream: Optional[CoordinationStream] = None + + self._reconnector = CoordinationReconnector( + driver=self._driver, + request_queue=self._request_queue, + node_path=self._node_path, + timeout_millis=self._timeout_millis, + ) + + self._wait_timeout: float = self._timeout_millis / 1000.0 + + def next_req_id(self) -> int: + r = self._next_req_id + self._next_req_id += 1 + return r + + async def send(self, req): + if self._stream is None: + raise issues.Error("Stream is not started yet") + await self._stream.send(req) + + async def _ensure_session(self): + if self._stream is not None and self._stream.session_id is not None: + return + + if not self._node_path: + raise issues.Error("node_path is not set for CoordinationLock") + + self._reconnector.start() + await self._reconnector.wait_ready() + + self._stream = self._reconnector.get_stream() + + async def _wait_for_response(self, req_id: int, *, kind: str): + try: + while True: + resp = await asyncio.wait_for( + self._stream._incoming_queue.get(), + timeout=self._wait_timeout, + ) + fs = FromServer.from_proto(resp) + + if kind == "acquire": + r = fs.acquire_semaphore_result + elif kind == "describe": + r = fs.describe_semaphore_result + elif kind == "create": + r = fs.create_semaphore_result + elif kind == "update": + r = fs.update_semaphore_result + elif kind == "delete": + r = fs.delete_semaphore_result + else: + r = None + + if r and r.req_id == req_id: + return r + + except asyncio.TimeoutError: + action = { + "acquire": "acquisition", + "describe": "describe", + "update": "update", + "delete": "delete", + "create": "create", + }.get(kind, "operation") + + raise issues.Error(f"Timeout waiting for lock {self._name} {action}") + + async def __aenter__(self): + await self._ensure_session() + + req_id = self.next_req_id() + self._req_id = req_id + + req = AcquireSemaphore( + req_id=req_id, + name=self._name, + count=self._count, + ephemeral=False, + timeout_millis=self._timeout_millis, + ).to_proto() + + await self.send(req) + + resp = await self._wait_for_response(req_id, kind="acquire") + if resp.acquired: + return self + else: + raise issues.Error(f"Failed to acquire lock: {resp.issues}") + + async def __aexit__(self, exc_type, exc, tb): + if self._req_id is not None: + try: + req = ReleaseSemaphore( + req_id=self._req_id, + name=self._name, + ).to_proto() + await self.send(req) + except issues.Error: + pass + + await self._reconnector.stop() + self._stream = None + self._node_path = None + self._req_id = None + + async def acquire(self): + return await self.__aenter__() + + async def release(self): + await self.__aexit__(None, None, None) + + async def create(self, init_limit, init_data): + await self._ensure_session() + + req_id = self.next_req_id() + + req = CreateSemaphore(req_id=req_id, name=self._name, limit=init_limit, data=init_data).to_proto() + + await self.send(req) + + resp = await self._wait_for_response(req_id, kind="create") + return CreateSemaphoreResult.from_proto(resp) + + async def delete(self): + await self._ensure_session() + + req_id = self.next_req_id() + + req = DeleteSemaphore( + req_id=req_id, + name=self._name, + ).to_proto() + + await self.send(req) + + resp = await self._wait_for_response(req_id, kind="delete") + return resp + + async def describe(self): + await self._ensure_session() + + req_id = self.next_req_id() + + req = DescribeSemaphore( + req_id=req_id, + name=self._name, + include_owners=True, + include_waiters=True, + watch_data=False, + watch_owners=False, + ).to_proto() + + await self.send(req) + + resp = await self._wait_for_response(req_id, kind="describe") + return DescribeLockResult.from_proto(resp) + + async def update(self, new_data): + await self._ensure_session() + + req_id = self.next_req_id() + req = UpdateSemaphore(req_id=req_id, name=self._name, data=new_data).to_proto() + + await self.send(req) + + resp = await self._wait_for_response(req_id, kind="update") + return resp diff --git a/ydb/aio/coordination/reconnector.py b/ydb/aio/coordination/reconnector.py new file mode 100644 index 00000000..87ae3893 --- /dev/null +++ b/ydb/aio/coordination/reconnector.py @@ -0,0 +1,104 @@ +import asyncio +import contextlib +from typing import Optional + +from ydb.aio.coordination.stream import CoordinationStream + + +class CoordinationReconnector: + def __init__( + self, + driver, + request_queue: asyncio.Queue, + node_path: str, + timeout_millis: int, + ): + self._driver = driver + self._request_queue = request_queue + self._node_path = node_path + self._timeout_millis = timeout_millis + + self._task: Optional[asyncio.Task] = None + self._stream: Optional[CoordinationStream] = None + + self._ready = asyncio.Event() + self._stopped = False + + self._first_error: asyncio.Future = asyncio.get_running_loop().create_future() + self._state_changed = asyncio.Event() + + def start(self): + if self._stopped: + return + if self._task is None or self._task.done(): + self._task = asyncio.create_task(self._connection_loop()) + + async def stop(self): + self._stopped = True + + if self._task: + self._task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._task + self._task = None + + if self._stream: + await self._stream.close() + self._stream = None + + self._ready.clear() + + async def wait_ready(self): + await self._ready.wait() + + def get_stream(self) -> CoordinationStream: + if self._stream is None or self._stream.session_id is None: + raise RuntimeError("Coordination stream is not ready") + return self._stream + + async def _connection_loop(self): + attempt = 0 + backoff = 0.1 + + while not self._stopped: + try: + stream = CoordinationStream( + self._driver, + self._request_queue, + ) + + await stream.start_session( + self._node_path, + self._timeout_millis, + ) + + self._stream = stream + self._ready.set() + + await asyncio.wait( + stream._background_tasks, + return_when=asyncio.FIRST_EXCEPTION, + ) + + except asyncio.CancelledError: + break + + except Exception as exc: + self._ready.clear() + self._stream = None + + if not self._first_error.done(): + self._first_error.set_result(exc) + self._state_changed.set() + + if self._stopped: + break + + await asyncio.sleep(backoff) + attempt += 1 + backoff = min(backoff * 2, 3.0) + + finally: + if self._stream: + await self._stream.close() + self._stream = None diff --git a/ydb/aio/coordination/stream.py b/ydb/aio/coordination/stream.py new file mode 100644 index 00000000..4554ada4 --- /dev/null +++ b/ydb/aio/coordination/stream.py @@ -0,0 +1,121 @@ +import asyncio +import contextlib +from typing import Set, Optional + +import ydb +from ydb import issues, _apis +from ydb._grpc.grpcwrapper.ydb_coordination import FromServer, Ping, SessionStart + + +class CoordinationStream: + def __init__(self, driver: "ydb.aio.Driver", request_queue: asyncio.Queue): + self._driver = driver + self._request_queue = request_queue + self._stream = None + self._closed: bool = False + self._background_tasks: Set[asyncio.Task] = set() + self._incoming_queue: asyncio.Queue = asyncio.Queue() + self._state_changed = asyncio.Event() + self._first_error: asyncio.Future = asyncio.get_running_loop().create_future() + self.session_id: Optional[int] = None + self._started: bool = False + + async def start_session(self, path: str, timeout_millis: int): + if self._started: + raise issues.Error("CoordinationStream already started") + + await self.send( + SessionStart( + path=path, + session_id=0, + timeout_millis=timeout_millis, + ).to_proto() + ) + + await self._start_internal() + + async def _start_internal(self): + if self._started: + raise issues.Error("CoordinationStream already started") + self._started = True + + async def request_gen(): + while not self._closed: + req = await self._request_queue.get() + yield req + + self._stream = await self._driver( + request_gen(), + _apis.CoordinationService.Stub, + _apis.CoordinationService.Session, + ) + + try: + async for resp in self._stream: + fs = FromServer.from_proto(resp) + if fs.session_started: + self.session_id = fs.session_started + break + except Exception as exc: + self._set_first_error(exc) + raise + + self._background_tasks.add(asyncio.create_task(self._reader_loop())) + + async def _reader_loop(self): + try: + async for resp in self._stream: + ping_opaque = FromServer.from_proto(resp).opaque + if ping_opaque: + await self.send(Ping(ping_opaque).to_proto()) + else: + self._incoming_queue.put_nowait(resp) + self._state_changed.set() + except Exception as exc: + self._set_first_error(exc) + + async def send(self, req): + self._check_error() + if self._closed: + raise issues.Error("Stream closed") + await self._request_queue.put(req) + + def receive_nowait(self): + self._check_error() + if self._incoming_queue.empty(): + return None + return self._incoming_queue.get_nowait() + + async def close(self): + if self._closed: + return + self._closed = True + + for task in self._background_tasks: + task.cancel() + + with contextlib.suppress(asyncio.CancelledError): + await asyncio.gather(*self._background_tasks) + + if self._stream: + try: + self._stream.close() + except Exception: + pass + + self.session_id = None + self._state_changed.set() + + def _set_first_error(self, exc: Exception): + if not self._first_error.done(): + self._first_error.set_result(exc) + self._state_changed.set() + + def _get_first_error(self): + if self._first_error.done(): + return self._first_error.result() + + def _check_error(self): + err = self._get_first_error() + if err: + raise err diff --git a/ydb/coordination/__init__.py b/ydb/coordination/__init__.py index 55834e89..1e280ee7 100644 --- a/ydb/coordination/__init__.py +++ b/ydb/coordination/__init__.py @@ -5,6 +5,16 @@ ConsistencyMode, RateLimiterCountersMode, DescribeResult, + CreateSemaphoreResult, + DescribeLockResult, ) -__all__ = ["CoordinationClient", "NodeConfig", "ConsistencyMode", "RateLimiterCountersMode", "DescribeResult"] +__all__ = [ + "CoordinationClient", + "NodeConfig", + "ConsistencyMode", + "RateLimiterCountersMode", + "DescribeResult", + "CreateSemaphoreResult", + "DescribeLockResult", +]