feat: statistical chunker improvements
simjak committed Jul 3, 2024
1 parent d3d4b16 commit 0b1c896
Showing 3 changed files with 268 additions and 1,437 deletions.
1,522 changes: 138 additions & 1,384 deletions docs/00-chunkers-intro.ipynb

Large diffs are not rendered by default.

135 changes: 82 additions & 53 deletions semantic_chunkers/chunkers/statistical.py
@@ -1,3 +1,4 @@
import asyncio
from dataclasses import dataclass
from typing import Any, List

@@ -10,7 +11,7 @@
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.sentence import RegexSplitter
from semantic_chunkers.utils.logger import logger
from semantic_chunkers.utils.text import tiktoken_length
from semantic_chunkers.utils.text import async_retry_with_timeout, tiktoken_length, time_it


@dataclass
@@ -54,7 +55,6 @@ def __init__(
enable_statistics=False,
):
super().__init__(name=name, encoder=encoder, splitter=splitter)
self.calculated_threshold: float
self.encoder = encoder
self.threshold_adjustment = threshold_adjustment
self.dynamic_threshold = dynamic_threshold
@@ -67,6 +67,7 @@ def __init__(
self.statistics: ChunkStatistics
self.DEFAULT_THRESHOLD = 0.5

@time_it
def _chunk(
self, splits: List[Any], batch_size: int = 64, enforce_max_tokens: bool = False
) -> List[Chunk]:
@@ -99,44 +100,58 @@ def _chunk(
splits = [split for split in new_splits if split and split.strip()]

chunks = []
last_split = None
last_chunk: Chunk | None = None
for i in tqdm(range(0, len(splits), batch_size)):
batch_splits = splits[i : i + batch_size]
if last_split is not None:
batch_splits = last_split.splits + batch_splits
if last_chunk is not None:
batch_splits = last_chunk.splits + batch_splits

encoded_splits = self._encode_documents(batch_splits)
similarities = self._calculate_similarity_scores(encoded_splits)

if self.dynamic_threshold:
self._find_optimal_threshold(batch_splits, similarities)
calculated_threshold = self._find_optimal_threshold(
batch_splits, similarities
)
else:
self.calculated_threshold = (
calculated_threshold = (
self.encoder.score_threshold
if self.encoder.score_threshold
else self.DEFAULT_THRESHOLD
)
split_indices = self._find_split_indices(similarities=similarities)
split_indices = self._find_split_indices(
similarities=similarities, calculated_threshold=calculated_threshold
)

doc_chunks = self._split_documents(
batch_splits, split_indices, similarities
docs=batch_splits,
split_indices=split_indices,
similarities=similarities,
)

if len(doc_chunks) > 1:
chunks.extend(doc_chunks[:-1])
last_split = doc_chunks[-1]
last_chunk = doc_chunks[-1]
else:
last_split = doc_chunks[0]
last_chunk = doc_chunks[0]

if self.plot_chunks:
self.plot_similarity_scores(similarities, split_indices, doc_chunks)
self.plot_similarity_scores(
similarities=similarities,
split_indices=split_indices,
chunks=doc_chunks,
calculated_threshold=calculated_threshold,
)

if self.enable_statistics:
print(self.statistics)

if last_split:
chunks.append(last_split)
if last_chunk:
chunks.append(last_chunk)

return chunks
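
For reference, the carry-over that replaces last_split with last_chunk works like the minimal sketch below: the final, possibly undersized chunk of each batch is fed back into the next batch rather than emitted immediately. The chunk_batch function here is a made-up stand-in for the real similarity-based grouping, purely for illustration.

from typing import List, Optional


def chunk_batch(splits: List[str]) -> List[List[str]]:
    # Made-up stand-in: group splits into fixed pairs; the real chunker
    # groups by similarity scores and a threshold instead.
    return [splits[i : i + 2] for i in range(0, len(splits), 2)]


def chunk_all(splits: List[str], batch_size: int = 4) -> List[List[str]]:
    chunks: List[List[str]] = []
    last_chunk: Optional[List[str]] = None
    for i in range(0, len(splits), batch_size):
        batch = splits[i : i + batch_size]
        if last_chunk is not None:
            # Prepend the previous batch's trailing chunk so it can merge with
            # the start of this batch instead of being emitted short.
            batch = last_chunk + batch
        doc_chunks = chunk_batch(batch)
        chunks.extend(doc_chunks[:-1])
        last_chunk = doc_chunks[-1]
    if last_chunk:
        chunks.append(last_chunk)
    return chunks


print(chunk_all([f"s{i}" for i in range(7)]))
# [['s0', 's1'], ['s2', 's3'], ['s4', 's5'], ['s6']]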

@time_it
async def _async_chunk(
self, splits: List[Any], batch_size: int = 64, enforce_max_tokens: bool = False
) -> List[Chunk]:
@@ -168,45 +183,50 @@ async def _async_chunk(

splits = [split for split in new_splits if split and split.strip()]

chunks = []
last_split = None
for i in tqdm(range(0, len(splits), batch_size)):
batch_splits = splits[i : i + batch_size]
if last_split is not None:
batch_splits = last_split.splits + batch_splits
chunks: list[Chunk] = []

# Step 1: Define process_batch as a separate coroutine function for parallel execution
async def _process_batch(batch_splits: List[str]):
encoded_splits = await self._async_encode_documents(batch_splits)
return batch_splits, encoded_splits

# Step 2: Create tasks for parallel execution
tasks = []
for i in range(0, len(splits), batch_size):
batch_splits = splits[i : i + batch_size]
tasks.append(_process_batch(batch_splits))

# Step 3: Await tasks and collect results
encoded_split_results = await asyncio.gather(*tasks)

# Step 4: Sequentially process results
for batch_splits, encoded_splits in encoded_split_results:
similarities = self._calculate_similarity_scores(encoded_splits)
if self.dynamic_threshold:
self._find_optimal_threshold(batch_splits, similarities)
calculated_threshold = self._find_optimal_threshold(
batch_splits, similarities
)
else:
self.calculated_threshold = (
calculated_threshold = (
self.encoder.score_threshold
if self.encoder.score_threshold
else self.DEFAULT_THRESHOLD
)
split_indices = self._find_split_indices(similarities=similarities)
doc_chunks = self._split_documents(
batch_splits, split_indices, similarities
split_indices = self._find_split_indices(
similarities=similarities, calculated_threshold=calculated_threshold
)

if len(doc_chunks) > 1:
chunks.extend(doc_chunks[:-1])
last_split = doc_chunks[-1]
else:
last_split = doc_chunks[0]

if self.plot_chunks:
self.plot_similarity_scores(similarities, split_indices, doc_chunks)

if self.enable_statistics:
print(self.statistics)

if last_split:
chunks.append(last_split)
doc_chunks: list[Chunk] = self._split_documents(
docs=batch_splits,
split_indices=split_indices,
similarities=similarities,
)

chunks.extend(doc_chunks)
return chunks
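
For clarity, the new gather-based flow reduces to the pattern below: encode all batches concurrently, then post-process the results in order. This is a self-contained sketch with a dummy encode coroutine standing in for the real async encoder; it is illustrative only. Note also that, unlike _chunk, this async path no longer carries the last chunk of one batch into the next; each batch's chunks are appended directly.

import asyncio
from typing import List, Tuple


async def encode(batch: List[str]) -> List[List[float]]:
    # Dummy stand-in for an async embedding call: one fake vector per split.
    await asyncio.sleep(0.01)
    return [[float(len(s))] for s in batch]


async def encode_batches(
    splits: List[str], batch_size: int = 64
) -> List[Tuple[List[str], List[List[float]]]]:
    async def process_batch(batch: List[str]):
        return batch, await encode(batch)

    # One task per batch, awaited together: the I/O-bound encoding calls
    # overlap, while gather keeps results in batch order for sequential
    # downstream processing.
    tasks = [
        process_batch(splits[i : i + batch_size])
        for i in range(0, len(splits), batch_size)
    ]
    return await asyncio.gather(*tasks)


print(asyncio.run(encode_batches([f"sentence {i}" for i in range(10)], batch_size=4)))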

@time_it
def __call__(self, docs: List[str], batch_size: int = 64) -> List[List[Chunk]]:
"""Split documents into smaller chunks based on semantic similarity.
@@ -235,6 +255,7 @@ def __call__(self, docs: List[str], batch_size: int = 64) -> List[List[Chunk]]:
raise ValueError("The document must be a string.")
return all_chunks

@time_it
async def acall(self, docs: List[str], batch_size: int = 64) -> List[List[Chunk]]:
"""Split documents into smaller chunks based on semantic similarity.
@@ -263,6 +284,7 @@ async def acall(self, docs: List[str], batch_size: int = 64) -> List[List[Chunk]
raise ValueError("The document must be a string.")
return all_chunks

@time_it
def _encode_documents(self, docs: List[str]) -> np.ndarray:
"""
Encodes a list of documents into embeddings. If the number of documents
@@ -286,6 +308,8 @@ def _encode_documents(self, docs: List[str]) -> np.ndarray:

return np.array(embeddings)

@async_retry_with_timeout(retries=3, timeout=5)
@time_it
async def _async_encode_documents(self, docs: List[str]) -> np.ndarray:
"""
Encodes a list of documents into embeddings. If the number of documents
@@ -321,14 +345,16 @@ def _calculate_similarity_scores(self, encoded_docs: np.ndarray) -> List[float]:
raw_similarities.append(curr_sim_score)
return raw_similarities

def _find_split_indices(self, similarities: List[float]) -> List[int]:
def _find_split_indices(
self, similarities: List[float], calculated_threshold: float
) -> List[int]:
split_indices = []
for idx, score in enumerate(similarities):
logger.debug(f"Similarity score at index {idx}: {score}")
if score < self.calculated_threshold:
if score < calculated_threshold:
logger.debug(
f"Adding to split_indices due to score < threshold: "
f"{score} < {self.calculated_threshold}"
f"{score} < {calculated_threshold}"
)
# Chunk after the document at idx
split_indices.append(idx + 1)
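
With the threshold now passed in explicitly instead of read from self.calculated_threshold, the split rule is easy to exercise in isolation; the toy scores below are made up for illustration.

from typing import List


def find_split_indices(similarities: List[float], calculated_threshold: float) -> List[int]:
    # Same rule as above: split *after* any position whose similarity drops
    # below the threshold.
    return [idx + 1 for idx, score in enumerate(similarities) if score < calculated_threshold]


print(find_split_indices([0.9, 0.4, 0.8, 0.3, 0.7], calculated_threshold=0.5))
# [2, 4] -> chunk boundaries after the 2nd and 4th documents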
@@ -348,11 +374,14 @@ def _find_optimal_threshold(self, docs: List[str], similarity_scores: List[float

iteration = 0
median_tokens = 0
calculated_threshold = 0
while low <= high:
self.calculated_threshold = (low + high) / 2
split_indices = self._find_split_indices(similarity_scores)
calculated_threshold = (low + high) / 2
split_indices = self._find_split_indices(
similarity_scores, calculated_threshold
)
logger.debug(
f"Iteration {iteration}: Trying threshold: {self.calculated_threshold}"
f"Iteration {iteration}: Trying threshold: {calculated_threshold}"
)

# Calculate the token counts for each split using the cumulative sums
@@ -376,20 +405,20 @@ def _find_optimal_threshold(self, docs: List[str], similarity_scores: List[float
logger.debug("Median tokens in target range. Stopping iteration.")
break
elif median_tokens < self.min_split_tokens:
high = self.calculated_threshold - self.threshold_adjustment
high = calculated_threshold - self.threshold_adjustment
logger.debug(f"Iteration {iteration}: Adjusting high to {high}")
else:
low = self.calculated_threshold + self.threshold_adjustment
low = calculated_threshold + self.threshold_adjustment
logger.debug(f"Iteration {iteration}: Adjusting low to {low}")
iteration += 1

logger.debug(
f"Optimal threshold {self.calculated_threshold} found "
f"Optimal threshold {calculated_threshold} found "
f"with median tokens ({median_tokens}) in target range "
f"({self.min_split_tokens}-{self.max_split_tokens})."
)

return self.calculated_threshold
return calculated_threshold
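
The search above is a plain binary search on the threshold, now returned instead of stored on self. A stripped-down sketch of the same idea, with word counts standing in for tiktoken_length and all numbers illustrative:

import statistics
from typing import List


def find_optimal_threshold(
    docs: List[str],
    similarities: List[float],
    min_tokens: int = 10,
    max_tokens: int = 30,
    adjustment: float = 0.01,
) -> float:
    low, high = min(similarities), max(similarities)
    threshold = (low + high) / 2
    while low <= high:
        threshold = (low + high) / 2
        # Candidate boundaries for this threshold.
        indices = [i + 1 for i, s in enumerate(similarities) if s < threshold]
        bounds = [0] + indices + [len(docs)]
        # Size of each resulting split (word counts instead of tiktoken).
        counts = [
            sum(len(d.split()) for d in docs[a:b])
            for a, b in zip(bounds, bounds[1:])
            if a < b
        ]
        median = statistics.median(counts)
        if min_tokens <= median <= max_tokens:
            break  # median split size already in the target range
        elif median < min_tokens:
            high = threshold - adjustment  # splits too small: lower the threshold
        else:
            low = threshold + adjustment  # splits too large: raise the threshold
    return threshold


docs = [f"sentence number {i} with a few words" for i in range(12)]
sims = [0.9, 0.2, 0.8, 0.85, 0.3, 0.9, 0.75, 0.25, 0.8, 0.9, 0.4]
print(round(find_optimal_threshold(docs, sims), 3))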

def _split_documents(
self, docs: List[str], split_indices: List[int], similarities: List[float]
@@ -440,7 +469,7 @@ def _split_documents(
)
logger.debug(
f"Chunk finalized with {current_tokens_count} tokens due to "
f"threshold {self.calculated_threshold}."
f"threshold {triggered_score}."
)
current_split, current_tokens_count = [], 0
chunks_by_threshold += 1
@@ -528,6 +557,7 @@ def plot_similarity_scores(
similarities: List[float],
split_indices: List[int],
chunks: list[Chunk],
calculated_threshold: float,
):
try:
from matplotlib import pyplot as plt
@@ -550,7 +580,7 @@ def plot_similarity_scores(
label="Chunk" if split_index == split_indices[0] else "",
)
axs[0].axhline(
y=self.calculated_threshold,
y=calculated_threshold,
color="g",
linestyle="-.",
label="Threshold Similarity Score",
@@ -569,8 +599,7 @@ def plot_similarity_scores(
axs[0].set_xlabel("Document Segment Index")
axs[0].set_ylabel("Similarity Score")
axs[0].set_title(
f"Threshold: {self.calculated_threshold} |"
f" Window Size: {self.window_size}",
f"Threshold: {calculated_threshold} |" f" Window Size: {self.window_size}",
loc="right",
fontsize=10,
)
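From the caller's side the public API is unchanged by this commit. A hedged usage sketch follows; the import paths and OpenAIEncoder follow the project's documentation at the time and may differ in other versions, and the document text is made up.

import asyncio

from semantic_router.encoders import OpenAIEncoder  # requires OPENAI_API_KEY
from semantic_chunkers import StatisticalChunker

encoder = OpenAIEncoder()
chunker = StatisticalChunker(encoder=encoder, dynamic_threshold=True)

docs = ["A long document goes here. It is split wherever the topic shifts."]

# Synchronous path: thresholds are now computed per batch inside _chunk.
chunks = chunker(docs=docs, batch_size=64)

# Asynchronous path: batches are encoded concurrently via asyncio.gather,
# with retries and timeouts from async_retry_with_timeout.
chunks_async = asyncio.run(chunker.acall(docs=docs, batch_size=64))
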
48 changes: 48 additions & 0 deletions semantic_chunkers/utils/text.py
@@ -1,7 +1,55 @@
import asyncio
from functools import wraps
import tiktoken
import time
from semantic_chunkers.utils.logger import logger


def tiktoken_length(text: str) -> int:
tokenizer = tiktoken.get_encoding("cl100k_base")
tokens = tokenizer.encode(text, disallowed_special=())
return len(tokens)


def time_it(func):
async def async_wrapper(*args, **kwargs):
start_time = time.time()
result = await func(*args, **kwargs) # Await the async function
end_time = time.time()
logger.debug(f"{func.__name__} duration: {end_time - start_time:.2f} seconds")
return result

def sync_wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs) # Call the sync function directly
end_time = time.time()
logger.debug(f"{func.__name__} duration: {end_time - start_time:.2f} seconds")
return result

if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper


def async_retry_with_timeout(retries=3, timeout=10):
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
for attempt in range(retries):
try:
return await asyncio.wait_for(func(*args, **kwargs), timeout)
except asyncio.TimeoutError:
logger.warning(
f"Timeout on attempt {attempt+1} for {func.__name__}"
)
except Exception as e:
logger.error(
f"Exception on attempt {attempt+1} for {func.__name__}: {e}"
)
if attempt == retries - 1:
raise
else:
await asyncio.sleep(2**attempt) # Exponential backoff
return wrapper
return decorator
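
Stacked the same way as on _async_encode_documents above, the two helpers combine retry/timeout handling with timing. The fetch_embeddings coroutine below is a hypothetical stand-in, not part of the library; it assumes semantic_chunkers is installed so the helpers can be imported.

import asyncio

from semantic_chunkers.utils.text import async_retry_with_timeout, time_it


@async_retry_with_timeout(retries=3, timeout=5)
@time_it
async def fetch_embeddings(docs):
    # Hypothetical stand-in for a remote embedding call.
    await asyncio.sleep(0.1)
    return [[0.0] * 8 for _ in docs]


embeddings = asyncio.run(fetch_embeddings(["first sentence", "second sentence"]))
print(len(embeddings))  # 2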
