Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: splitters module #10

Merged
merged 5 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions semantic_chunkers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from semantic_chunkers.chunkers import BaseChunker
from semantic_chunkers.chunkers import ConsecutiveChunker
from semantic_chunkers.chunkers import CumulativeChunker
from semantic_chunkers.chunkers import StatisticalChunker
from semantic_chunkers.chunkers import (
BaseChunker,
ConsecutiveChunker,
CumulativeChunker,
StatisticalChunker,
)
from semantic_chunkers.splitters import BaseSplitter, RegexSplitter

__all__ = [
"BaseChunker",
"ConsecutiveChunker",
"CumulativeChunker",
"StatisticalChunker",
"BaseSplitter",
"RegexSplitter",
]

__version__ = "0.0.5"
5 changes: 3 additions & 2 deletions semantic_chunkers/chunkers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@

from semantic_router.encoders.base import BaseEncoder
from semantic_chunkers.schema import Chunk
from semantic_chunkers.splitters.sentence import regex_splitter
from semantic_chunkers.splitters.base import BaseSplitter


class BaseChunker(BaseModel):
name: str
encoder: BaseEncoder
splitter: BaseSplitter

class Config:
extra = Extra.allow
Expand All @@ -19,7 +20,7 @@ def __call__(self, docs: List[str]) -> List[List[Chunk]]:
raise NotImplementedError("Subclasses must implement this method")

def _split(self, doc: str) -> List[str]:
return regex_splitter(doc)
return self.splitter(doc)

def _chunk(self, splits: List[Any]) -> List[Chunk]:
raise NotImplementedError("Subclasses must implement this method")
Expand Down
5 changes: 4 additions & 1 deletion semantic_chunkers/chunkers/consecutive.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from semantic_router.encoders.base import BaseEncoder
from semantic_chunkers.schema import Chunk
from semantic_chunkers.chunkers.base import BaseChunker
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.sentence import RegexSplitter


class ConsecutiveChunker(BaseChunker):
Expand All @@ -16,10 +18,11 @@ class ConsecutiveChunker(BaseChunker):
def __init__(
self,
encoder: BaseEncoder,
splitter: BaseSplitter = RegexSplitter(),
name: str = "consecutive_chunker",
score_threshold: float = 0.45,
):
super().__init__(name=name, encoder=encoder)
super().__init__(name=name, encoder=encoder, splitter=splitter)
encoder.score_threshold = score_threshold
self.score_threshold = score_threshold

Expand Down
5 changes: 4 additions & 1 deletion semantic_chunkers/chunkers/cumulative.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from semantic_router.encoders import BaseEncoder
from semantic_chunkers.schema import Chunk
from semantic_chunkers.chunkers.base import BaseChunker
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.sentence import RegexSplitter


class CumulativeChunker(BaseChunker):
Expand All @@ -17,10 +19,11 @@ class CumulativeChunker(BaseChunker):
def __init__(
self,
encoder: BaseEncoder,
splitter: BaseSplitter = RegexSplitter(),
name: str = "cumulative_chunker",
score_threshold: float = 0.45,
):
super().__init__(name=name, encoder=encoder)
super().__init__(name=name, encoder=encoder, splitter=splitter)
encoder.score_threshold = score_threshold
self.score_threshold = score_threshold

Expand Down
5 changes: 4 additions & 1 deletion semantic_chunkers/chunkers/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from semantic_router.encoders.base import BaseEncoder
from semantic_chunkers.schema import Chunk
from semantic_chunkers.chunkers.base import BaseChunker
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.sentence import RegexSplitter
from semantic_chunkers.utils.text import tiktoken_length
from semantic_chunkers.utils.logger import logger

Expand Down Expand Up @@ -39,6 +41,7 @@ class StatisticalChunker(BaseChunker):
def __init__(
self,
encoder: BaseEncoder,
splitter: BaseSplitter = RegexSplitter(),
name="statistical_chunker",
threshold_adjustment=0.01,
dynamic_threshold: bool = True,
Expand All @@ -49,7 +52,7 @@ def __init__(
plot_chunks=False,
enable_statistics=False,
):
super().__init__(name=name, encoder=encoder)
super().__init__(name=name, encoder=encoder, splitter=splitter)
self.calculated_threshold: float
self.encoder = encoder
self.threshold_adjustment = threshold_adjustment
Expand Down
8 changes: 8 additions & 0 deletions semantic_chunkers/splitters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.sentence import RegexSplitter


__all__ = [
"BaseSplitter",
"RegexSplitter",
]
11 changes: 11 additions & 0 deletions semantic_chunkers/splitters/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import List

from pydantic.v1 import BaseModel, Extra


class BaseSplitter(BaseModel):
class Config:
extra = Extra.allow

def __call__(self, doc: str) -> List[str]:
raise NotImplementedError("Subclasses must implement this method")
20 changes: 10 additions & 10 deletions semantic_chunkers/splitters/sentence.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import regex
from typing import List

from semantic_chunkers.splitters.base import BaseSplitter

def regex_splitter(text: str) -> list[str]:

class RegexSplitter(BaseSplitter):
"""
Enhanced regex pattern to split a given text into sentences more accurately.

Expand All @@ -11,13 +14,8 @@ def regex_splitter(text: str) -> list[str]:
- Decimal numbers and dates.
- Ellipses and other punctuation marks used in informal text.
- Removing control characters and format characters.

Args:
text (str): The text to split into sentences.

Returns:
list: A list of sentences extracted from the text.
"""

regex_pattern = r"""
# Negative lookbehind for word boundary, word char, dot, word char
(?<!\b\w\.\w.)
Expand Down Expand Up @@ -51,6 +49,8 @@ def regex_splitter(text: str) -> list[str]:
# Matches and removes control characters and format characters
[\p{Cc}\p{Cf}]+
"""
sentences = regex.split(regex_pattern, text, flags=regex.VERBOSE)
sentences = [sentence.strip() for sentence in sentences if sentence.strip()]
return sentences

def __call__(self, doc: str) -> List[str]:
sentences = regex.split(self.regex_pattern, doc, flags=regex.VERBOSE)
sentences = [sentence.strip() for sentence in sentences if sentence.strip()]
return sentences
9 changes: 8 additions & 1 deletion tests/unit/test_splitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from semantic_router.encoders.base import BaseEncoder
from semantic_router.encoders.cohere import CohereEncoder
from semantic_chunkers import BaseChunker
from semantic_chunkers import BaseSplitter
from semantic_chunkers import ConsecutiveChunker
from semantic_chunkers import CumulativeChunker

Expand Down Expand Up @@ -106,7 +107,13 @@ def base_splitter_instance():
mock_encoder = Mock(spec=BaseEncoder)
mock_encoder.name = "mock_encoder"
mock_encoder.score_threshold = 0.5
return BaseChunker(name="test_splitter", encoder=mock_encoder, score_threshold=0.5)
mock_splitter = Mock(spec=BaseSplitter)
return BaseChunker(
name="test_splitter",
encoder=mock_encoder,
splitter=mock_splitter,
score_threshold=0.5,
)


def test_base_splitter_call_not_implemented(base_splitter_instance):
Expand Down
Loading