diff --git a/python/vineyard/core/builder.py b/python/vineyard/core/builder.py index 36cc46bb1..b6e4353cd 100644 --- a/python/vineyard/core/builder.py +++ b/python/vineyard/core/builder.py @@ -156,6 +156,7 @@ def put( builder: Optional[BuilderContext] = None, persist: bool = False, name: Optional[str] = None, + as_async: bool = False, **kwargs ): """Put python value to vineyard. @@ -185,16 +186,22 @@ def put( name: str, optional If given, the name will be automatically associated with the resulted object. Note that only take effect when the object is persisted. + as_async: bool, optional + If true, which means the object will be put to vineyard asynchronously. + Thus we need to use the passed builder. kw: User-specific argument that will be passed to the builder. Returns: ObjectID: The result object id will be returned. """ - if builder is not None: + if builder is not None and not as_async: return builder(client, value, **kwargs) - meta = get_current_builders().run(client, value, **kwargs) + if as_async: + meta = builder.run(client, value, **kwargs) + else: + meta = get_current_builders().run(client, value, **kwargs) # the builders is expected to return an :class:`ObjectMeta`, or an # :class:`Object` (in the `bytes_builder` and `memoryview` builder). diff --git a/python/vineyard/core/client.py b/python/vineyard/core/client.py index 2bb0b1c2f..9328585f0 100644 --- a/python/vineyard/core/client.py +++ b/python/vineyard/core/client.py @@ -42,6 +42,7 @@ from vineyard._C import VineyardException from vineyard._C import _connect from vineyard.core.builder import BuilderContext +from vineyard.core.builder import get_current_builders from vineyard.core.builder import put from vineyard.core.resolver import ResolverContext from vineyard.core.resolver import get @@ -168,6 +169,7 @@ def __init__( session: int = None, username: str = None, password: str = None, + max_workers: int = 8, config: str = None, ): """Connects to the vineyard IPC socket and RPC socket. 
@property
def put_thread_pool(self) -> ThreadPoolExecutor:
    """Return the thread pool used for asynchronous put.

    The executor is created on first access, with ``self._max_workers``
    as its worker limit, and cached in ``self._put_thread_pool`` so that
    later accesses return the same object.
    """
    pool = self._put_thread_pool
    if pool is not None:
        return pool
    pool = ThreadPoolExecutor(max_workers=self._max_workers)
    self._put_thread_pool = pool
    return pool
@_apply_docstring(put)
def put(
    self,
    value: Any,
    builder: Optional[BuilderContext] = None,
    persist: bool = False,
    name: Optional[str] = None,
    as_async: bool = False,
    **kwargs,
):
    # The docstring is supplied by ``_apply_docstring(put)``, so the notes
    # for this wrapper are kept in comments instead of a docstring.
    #
    # as_async=False: run ``_put_internal`` in the calling thread and return
    #   its result.
    # as_async=True: submit ``_put_internal`` to ``self.put_thread_pool`` and
    #   return the resulting future; the outcome is additionally printed by
    #   a done-callback.
    if not as_async:
        return self._put_internal(value, builder, persist, name, **kwargs)

    def _default_callback(future):
        # Invoked when the submitted put finishes. ``future.result()``
        # re-raises whatever the put raised, so failures are printed here
        # as well as being available to the caller through the future.
        try:
            result = future.result()
            if isinstance(result, ObjectID):
                print(f"Successfully put object {result}", flush=True)
            elif isinstance(result, ObjectMeta):
                print(f"Successfully put object {result.id}", flush=True)
        except Exception as e:
            print(f"Failed to put object: {e}", flush=True)

    # Pick the builder context before submitting the task. The check is an
    # explicit ``is None`` (matching the module-level ``put``) so that a
    # builder passed by the caller is never swapped for the global one just
    # because it happens to evaluate as falsy.
    current_builder = builder
    if current_builder is None:
        current_builder = get_current_builders()

    future = self.put_thread_pool.submit(
        self._put_internal,
        value,
        current_builder,
        persist,
        name,
        as_async=True,
        **kwargs,
    )
    future.add_done_callback(_default_callback)
    return future
time: ", end_time - start_time) + + def consumer(vineyard_client): + start_time = time.time() + client = vineyard_client.fork() + for i in range(object_nums): + object_id = client.get_name(name="test" + str(i), wait=True) + client.get(object_id) + end_time = time.time() + print("Consumer time: ", end_time - start_time) + + producer_thread = Thread(target=producer, args=(vineyard_client,)) + consumer_thread = Thread(target=consumer, args=(vineyard_client,)) + + start_time = time.time() + + producer_thread.start() + consumer_thread.start() + + producer_thread.join() + consumer_thread.join() + + end_time = time.time() + print("Total time: ", end_time - start_time)