diff --git a/semantic_chunkers/__init__.py b/semantic_chunkers/__init__.py index 560e3a2..f41c50f 100644 --- a/semantic_chunkers/__init__.py +++ b/semantic_chunkers/__init__.py @@ -1,13 +1,18 @@ -from semantic_chunkers.chunkers import BaseChunker -from semantic_chunkers.chunkers import ConsecutiveChunker -from semantic_chunkers.chunkers import CumulativeChunker -from semantic_chunkers.chunkers import StatisticalChunker +from semantic_chunkers.chunkers import ( + BaseChunker, + ConsecutiveChunker, + CumulativeChunker, + StatisticalChunker, +) +from semantic_chunkers.splitters import BaseSplitter, RegexSplitter __all__ = [ "BaseChunker", "ConsecutiveChunker", "CumulativeChunker", "StatisticalChunker", + "BaseSplitter", + "RegexSplitter", ] __version__ = "0.0.5" diff --git a/semantic_chunkers/chunkers/base.py b/semantic_chunkers/chunkers/base.py index 7517c61..4351e57 100644 --- a/semantic_chunkers/chunkers/base.py +++ b/semantic_chunkers/chunkers/base.py @@ -5,12 +5,13 @@ from semantic_router.encoders.base import BaseEncoder from semantic_chunkers.schema import Chunk -from semantic_chunkers.splitters.sentence import regex_splitter +from semantic_chunkers.splitters.base import BaseSplitter class BaseChunker(BaseModel): name: str encoder: BaseEncoder + splitter: BaseSplitter class Config: extra = Extra.allow @@ -19,7 +20,7 @@ def __call__(self, docs: List[str]) -> List[List[Chunk]]: raise NotImplementedError("Subclasses must implement this method") def _split(self, doc: str) -> List[str]: - return regex_splitter(doc) + return self.splitter(doc) def _chunk(self, splits: List[Any]) -> List[Chunk]: raise NotImplementedError("Subclasses must implement this method") diff --git a/semantic_chunkers/chunkers/consecutive.py b/semantic_chunkers/chunkers/consecutive.py index 91d7134..1b5664f 100644 --- a/semantic_chunkers/chunkers/consecutive.py +++ b/semantic_chunkers/chunkers/consecutive.py @@ -6,6 +6,8 @@ from semantic_router.encoders.base import BaseEncoder from semantic_chunkers.schema import Chunk from semantic_chunkers.chunkers.base import BaseChunker +from semantic_chunkers.splitters.base import BaseSplitter +from semantic_chunkers.splitters.sentence import RegexSplitter class ConsecutiveChunker(BaseChunker): @@ -16,10 +18,11 @@ class ConsecutiveChunker(BaseChunker): def __init__( self, encoder: BaseEncoder, + splitter: BaseSplitter = RegexSplitter(), name: str = "consecutive_chunker", score_threshold: float = 0.45, ): - super().__init__(name=name, encoder=encoder) + super().__init__(name=name, encoder=encoder, splitter=splitter) encoder.score_threshold = score_threshold self.score_threshold = score_threshold diff --git a/semantic_chunkers/chunkers/cumulative.py b/semantic_chunkers/chunkers/cumulative.py index 8987b04..973952e 100644 --- a/semantic_chunkers/chunkers/cumulative.py +++ b/semantic_chunkers/chunkers/cumulative.py @@ -6,6 +6,8 @@ from semantic_router.encoders import BaseEncoder from semantic_chunkers.schema import Chunk from semantic_chunkers.chunkers.base import BaseChunker +from semantic_chunkers.splitters.base import BaseSplitter +from semantic_chunkers.splitters.sentence import RegexSplitter class CumulativeChunker(BaseChunker): @@ -17,10 +19,11 @@ class CumulativeChunker(BaseChunker): def __init__( self, encoder: BaseEncoder, + splitter: BaseSplitter = RegexSplitter(), name: str = "cumulative_chunker", score_threshold: float = 0.45, ): - super().__init__(name=name, encoder=encoder) + super().__init__(name=name, encoder=encoder, splitter=splitter) encoder.score_threshold = score_threshold self.score_threshold = score_threshold diff --git a/semantic_chunkers/chunkers/statistical.py b/semantic_chunkers/chunkers/statistical.py index a6997ba..808414e 100644 --- a/semantic_chunkers/chunkers/statistical.py +++ b/semantic_chunkers/chunkers/statistical.py @@ -6,6 +6,8 @@ from semantic_router.encoders.base import BaseEncoder from semantic_chunkers.schema import Chunk from semantic_chunkers.chunkers.base import BaseChunker +from semantic_chunkers.splitters.base import BaseSplitter +from semantic_chunkers.splitters.sentence import RegexSplitter from semantic_chunkers.utils.text import tiktoken_length from semantic_chunkers.utils.logger import logger @@ -39,6 +41,7 @@ class StatisticalChunker(BaseChunker): def __init__( self, encoder: BaseEncoder, + splitter: BaseSplitter = RegexSplitter(), name="statistical_chunker", threshold_adjustment=0.01, dynamic_threshold: bool = True, @@ -49,7 +52,7 @@ def __init__( plot_chunks=False, enable_statistics=False, ): - super().__init__(name=name, encoder=encoder) + super().__init__(name=name, encoder=encoder, splitter=splitter) self.calculated_threshold: float self.encoder = encoder self.threshold_adjustment = threshold_adjustment diff --git a/semantic_chunkers/splitters/__init__.py b/semantic_chunkers/splitters/__init__.py index e69de29..c6d858a 100644 --- a/semantic_chunkers/splitters/__init__.py +++ b/semantic_chunkers/splitters/__init__.py @@ -0,0 +1,8 @@ +from semantic_chunkers.splitters.base import BaseSplitter +from semantic_chunkers.splitters.sentence import RegexSplitter + + +__all__ = [ + "BaseSplitter", + "RegexSplitter", +] diff --git a/semantic_chunkers/splitters/base.py b/semantic_chunkers/splitters/base.py new file mode 100644 index 0000000..7969c45 --- /dev/null +++ b/semantic_chunkers/splitters/base.py @@ -0,0 +1,11 @@ +from typing import List + +from pydantic.v1 import BaseModel, Extra + + +class BaseSplitter(BaseModel): + class Config: + extra = Extra.allow + + def __call__(self, doc: str) -> List[str]: + raise NotImplementedError("Subclasses must implement this method") diff --git a/semantic_chunkers/splitters/sentence.py b/semantic_chunkers/splitters/sentence.py index 9e75adc..cd8b2b3 100644 --- a/semantic_chunkers/splitters/sentence.py +++ b/semantic_chunkers/splitters/sentence.py @@ -1,7 +1,10 @@ import regex +from typing import List +from semantic_chunkers.splitters.base import BaseSplitter -def regex_splitter(text: str) -> list[str]: + +class RegexSplitter(BaseSplitter): """ Enhanced regex pattern to split a given text into sentences more accurately. @@ -11,13 +14,8 @@ def regex_splitter(text: str) -> list[str]: - Decimal numbers and dates. - Ellipses and other punctuation marks used in informal text. - Removing control characters and format characters. - - Args: - text (str): The text to split into sentences. - - Returns: - list: A list of sentences extracted from the text. """ + regex_pattern = r""" # Negative lookbehind for word boundary, word char, dot, word char (?<!\b\w\.\w.) @@ -51,6 +49,8 @@ def regex_splitter(text: str) -> list[str]: # Matches and removes control characters and format characters [\p{Cc}\p{Cf}]+ """ - sentences = regex.split(regex_pattern, text, flags=regex.VERBOSE) - sentences = [sentence.strip() for sentence in sentences if sentence.strip()] - return sentences + + def __call__(self, doc: str) -> List[str]: + sentences = regex.split(self.regex_pattern, doc, flags=regex.VERBOSE) + sentences = [sentence.strip() for sentence in sentences if sentence.strip()] + return sentences diff --git a/tests/unit/test_splitters.py b/tests/unit/test_splitters.py index 8584dbb..b37e891 100644 --- a/tests/unit/test_splitters.py +++ b/tests/unit/test_splitters.py @@ -6,6 +6,7 @@ from semantic_router.encoders.base import BaseEncoder from semantic_router.encoders.cohere import CohereEncoder from semantic_chunkers import BaseChunker +from semantic_chunkers import BaseSplitter from semantic_chunkers import ConsecutiveChunker from semantic_chunkers import CumulativeChunker @@ -106,7 +107,13 @@ def base_splitter_instance(): mock_encoder = Mock(spec=BaseEncoder) mock_encoder.name = "mock_encoder" mock_encoder.score_threshold = 0.5 - return BaseChunker(name="test_splitter", encoder=mock_encoder, score_threshold=0.5) + mock_splitter = Mock(spec=BaseSplitter) + return BaseChunker( + name="test_splitter", + encoder=mock_encoder, + splitter=mock_splitter, + score_threshold=0.5, + ) def test_base_splitter_call_not_implemented(base_splitter_instance):