diff --git a/vectordb_bench/backend/runner/rate_runner.py b/vectordb_bench/backend/runner/rate_runner.py index 077c449d4..111597c85 100644 --- a/vectordb_bench/backend/runner/rate_runner.py +++ b/vectordb_bench/backend/runner/rate_runner.py @@ -3,6 +3,7 @@ import multiprocessing as mp import time from concurrent.futures import ThreadPoolExecutor +from copy import deepcopy from vectordb_bench import config from vectordb_bench.backend.clients import api @@ -33,17 +34,23 @@ def __init__( self.executing_futures = [] self.sig_idx = 0 - def send_insert_task(self, db: api.VectorDB, emb: list[list[float]], metadata: list[str], retry_idx: int = 0): - _, error = db.insert_embeddings(emb, metadata) - if error is not None: - log.warning(f"Insert Failed, try_idx={retry_idx}, Exception: {error}") - retry_idx += 1 - if retry_idx <= config.MAX_INSERT_RETRY: - time.sleep(retry_idx) - self.send_insert_task(db, emb=emb, metadata=metadata, retry_idx=retry_idx) - else: - msg = f"Insert failed and retried more than {config.MAX_INSERT_RETRY} times" - raise RuntimeError(msg) from None + def send_insert_task(self, emb: list[list[float]], metadata: list[str]): + + def _insert_embeddings(db: api.VectorDB, emb: list[list[float]], metadata: list[str], retry_idx: int = 0): + _, error = db.insert_embeddings(emb, metadata) + if error is not None: + log.warning(f"Insert Failed, try_idx={retry_idx}, Exception: {error}") + retry_idx += 1 + if retry_idx <= config.MAX_INSERT_RETRY: + time.sleep(retry_idx) + _insert_embeddings(db, emb=emb, metadata=metadata, retry_idx=retry_idx) + else: + msg = f"Insert failed and retried more than {config.MAX_INSERT_RETRY} times" + raise RuntimeError(msg) from None + + db_copy = deepcopy(self.db) + with db_copy.init(): + _insert_embeddings(db_copy, emb, metadata, retry_idx=0) @time_it def run_with_rate(self, q: mp.Queue): @@ -54,7 +61,7 @@ def submit_by_rate() -> bool: rate = self.batch_rate for data in self.dataset: emb, metadata = get_data(data, self.normalize) - self.executing_futures.append(executor.submit(self.send_insert_task, self.db, emb, metadata)) + self.executing_futures.append(executor.submit(self.send_insert_task, emb, metadata)) rate -= 1 if rate == 0: @@ -89,35 +96,35 @@ def check_and_send_signal(wait_interval: float, finished: bool = False): raise e from None time_per_batch = config.TIME_PER_BATCH - with self.db.init(): - start_time = time.perf_counter() - round_idx = 0 - while True: - if len(self.executing_futures) > 200: - log.warning("Skip data insertion this round. There are 200+ unfinished insertion tasks.") - else: - finished, elapsed_time = submit_by_rate() - if finished is True: - log.info( - f"End of dataset, left unfinished={len(self.executing_futures)}, num_round={round_idx}" - ) - break - if elapsed_time >= 1.5: - log.warning( - f"Submit insert tasks took {elapsed_time}s, expected 1s, " - f"indicating potential resource limitations on the client machine.", - ) - - check_and_send_signal(wait_interval=0.001, finished=False) - dur = time.perf_counter() - start_time - round_idx * time_per_batch - if dur < time_per_batch: - time.sleep(time_per_batch - dur) - round_idx += 1 - - # wait for all tasks in executing_futures to complete - while len(self.executing_futures) > 0: - check_and_send_signal(wait_interval=1, finished=True) - round_idx += 1 - - log.info(f"Finish all streaming insertion, num_round={round_idx}") + start_time = time.perf_counter() + round_idx = 0 + + while True: + if len(self.executing_futures) > 200: + log.warning("Skip data insertion this round. There are 200+ unfinished insertion tasks.") + else: + finished, elapsed_time = submit_by_rate() + if finished is True: + log.info( + f"End of dataset, left unfinished={len(self.executing_futures)}, num_round={round_idx}" + ) + break + if elapsed_time >= 1.5: + log.warning( + f"Submit insert tasks took {elapsed_time}s, expected 1s, " + f"indicating potential resource limitations on the client machine.", + ) + + check_and_send_signal(wait_interval=0.001, finished=False) + dur = time.perf_counter() - start_time - round_idx * time_per_batch + if dur < time_per_batch: + time.sleep(time_per_batch - dur) + round_idx += 1 + + # wait for all tasks in executing_futures to complete + while len(self.executing_futures) > 0: + check_and_send_signal(wait_interval=1, finished=True) + round_idx += 1 + + log.info(f"Finish all streaming insertion, num_round={round_idx}")