feat: regex chunker #17

Merged
merged 7 commits on Jul 19, 2024
Changes from 4 commits
148 changes: 140 additions & 8 deletions docs/00-chunkers-intro.ipynb

Large diffs are not rendered by default.

148 changes: 140 additions & 8 deletions docs/02-chunkers-async.ipynb

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions semantic_chunkers/__init__.py
@@ -2,6 +2,7 @@
BaseChunker,
ConsecutiveChunker,
CumulativeChunker,
RegexChunker,
StatisticalChunker,
)
from semantic_chunkers.splitters import BaseSplitter, RegexSplitter
@@ -11,8 +12,9 @@
"ConsecutiveChunker",
"CumulativeChunker",
"StatisticalChunker",
"BaseSplitter",
"RegexSplitter",
"BaseSplitter",
"RegexChunker",
]

__version__ = "0.0.8"
__version__ = "0.0.9"
2 changes: 2 additions & 0 deletions semantic_chunkers/chunkers/__init__.py
@@ -1,11 +1,13 @@
from semantic_chunkers.chunkers.base import BaseChunker
from semantic_chunkers.chunkers.consecutive import ConsecutiveChunker
from semantic_chunkers.chunkers.cumulative import CumulativeChunker
from semantic_chunkers.chunkers.regex import RegexChunker
from semantic_chunkers.chunkers.statistical import StatisticalChunker

__all__ = [
"BaseChunker",
"ConsecutiveChunker",
"CumulativeChunker",
"StatisticalChunker",
"RegexChunker",
]
4 changes: 2 additions & 2 deletions semantic_chunkers/chunkers/base.py
@@ -1,4 +1,4 @@
from typing import Any, List
from typing import Any, List, Optional

from colorama import Fore, Style
from pydantic.v1 import BaseModel, Extra
@@ -10,7 +10,7 @@

class BaseChunker(BaseModel):
name: str
encoder: BaseEncoder
encoder: Optional[BaseEncoder]
splitter: BaseSplitter

class Config:
4 changes: 3 additions & 1 deletion semantic_chunkers/chunkers/consecutive.py
@@ -7,14 +7,16 @@
from semantic_chunkers.chunkers.base import BaseChunker
from semantic_chunkers.schema import Chunk
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.sentence import RegexSplitter
from semantic_chunkers.splitters.regex import RegexSplitter


class ConsecutiveChunker(BaseChunker):
"""
Called "consecutive sim chunker" because we check the similarities of consecutive document embeddings (compare ith to i+1th document embedding).
"""

encoder: BaseEncoder

def __init__(
self,
encoder: BaseEncoder,
4 changes: 3 additions & 1 deletion semantic_chunkers/chunkers/cumulative.py
@@ -7,7 +7,7 @@
from semantic_chunkers.chunkers.base import BaseChunker
from semantic_chunkers.schema import Chunk
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.sentence import RegexSplitter
from semantic_chunkers.splitters.regex import RegexSplitter


class CumulativeChunker(BaseChunker):
@@ -16,6 +16,8 @@ class CumulativeChunker(BaseChunker):
embeddings of cumulative concatenated documents with the next document.
"""

encoder: BaseEncoder

def __init__(
self,
encoder: BaseEncoder,
57 changes: 57 additions & 0 deletions semantic_chunkers/chunkers/regex.py
@@ -0,0 +1,57 @@
import asyncio
from typing import List, Union

import regex

from semantic_chunkers.chunkers.base import BaseChunker
from semantic_chunkers.schema import Chunk
from semantic_chunkers.splitters import RegexSplitter
from semantic_chunkers.utils import text


class RegexChunker(BaseChunker):
def __init__(
self,
max_chunk_tokens: int = 300,
delimiters: List[Union[str, regex.Pattern]] = [],
):
super().__init__(name="regex_chunker", encoder=None, splitter=RegexSplitter())
self.max_chunk_tokens = max_chunk_tokens
self.delimiters = delimiters

def __call__(self, docs: list[str]) -> List[List[Chunk]]:
chunks = []
current_chunk = Chunk(
splits=[],
metadata={},
)
current_chunk.token_count = 0

for doc in docs:
regex_splitter = RegexSplitter()
sentences = regex_splitter(doc, delimiters=self.delimiters)
for sentence in sentences:
sentence_token_count = text.tiktoken_length(sentence)

if (
current_chunk.token_count + sentence_token_count
> self.max_chunk_tokens
):
chunks.append(current_chunk)
current_chunk = Chunk(splits=[])
current_chunk.token_count = 0

current_chunk.splits.append(sentence)
if current_chunk.token_count is None:
current_chunk.token_count = 0
current_chunk.token_count += sentence_token_count

# Last chunk
if current_chunk.splits:
chunks.append(current_chunk)

return [chunks]

async def acall(self, docs: list[str]) -> List[List[Chunk]]:
chunks = await asyncio.to_thread(self.__call__, docs)
return chunks
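A usage sketch for the new chunker, based on the code above (the document text is made up):

```python
import asyncio

from semantic_chunkers import RegexChunker

chunker = RegexChunker(max_chunk_tokens=20)
docs = ["First sentence. Second sentence. Third sentence. Fourth sentence."]

# __call__ returns List[List[Chunk]]; with a single document, index [0]
# holds its chunks, each capped at max_chunk_tokens tiktoken tokens.
for chunk in chunker(docs)[0]:
    print(chunk.splits, chunk.token_count)

# acall runs the same logic in a worker thread via asyncio.to_thread.
chunks = asyncio.run(chunker.acall(docs))
```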
8 changes: 5 additions & 3 deletions semantic_chunkers/chunkers/statistical.py
@@ -1,6 +1,6 @@
import asyncio
from dataclasses import dataclass
from typing import Any, List
from typing import Any, List, Optional

import numpy as np
from semantic_router.encoders.base import BaseEncoder
@@ -9,7 +9,7 @@
from semantic_chunkers.chunkers.base import BaseChunker
from semantic_chunkers.schema import Chunk
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.sentence import RegexSplitter
from semantic_chunkers.splitters.regex import RegexSplitter
from semantic_chunkers.utils.logger import logger
from semantic_chunkers.utils.text import (
async_retry_with_timeout,
@@ -44,6 +44,8 @@ def __str__(self):


class StatisticalChunker(BaseChunker):
encoder: BaseEncoder

def __init__(
self,
encoder: BaseEncoder,
@@ -104,7 +106,7 @@ def _chunk(
splits = [split for split in new_splits if split and split.strip()]

chunks = []
last_chunk: Chunk | None = None
last_chunk: Optional[Chunk] = None
for i in tqdm(range(0, len(splits), batch_size)):
batch_splits = splits[i : i + batch_size]
if last_chunk is not None:
2 changes: 1 addition & 1 deletion semantic_chunkers/splitters/__init__.py
@@ -1,5 +1,5 @@
from semantic_chunkers.splitters.base import BaseSplitter
from semantic_chunkers.splitters.sentence import RegexSplitter
from semantic_chunkers.splitters.regex import RegexSplitter

__all__ = [
"BaseSplitter",
semantic_chunkers/splitters/regex.py (renamed from sentence.py)
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Union

import regex

@@ -8,13 +8,6 @@
class RegexSplitter(BaseSplitter):
"""
Enhanced regex pattern to split a given text into sentences more accurately.

The enhanced regex pattern includes handling for:
- Direct speech and quotations.
- Abbreviations, initials, and acronyms.
- Decimal numbers and dates.
- Ellipses and other punctuation marks used in informal text.
- Removing control characters and format characters.
"""

regex_pattern = r"""
@@ -49,9 +42,36 @@ class RegexSplitter(BaseSplitter):
|
# Matches and removes control characters and format characters
[\p{Cc}\p{Cf}]+
# OR
|
# Splits after punctuation marks followed by another punctuation mark
(?<=[\.!?])(?=[\.!?])
# OR
|
# Splits after exclamation or question marks followed by whitespace or end of string
(?<=[!?])(?=\s|$)
"""

def __call__(self, doc: str) -> List[str]:
sentences = regex.split(self.regex_pattern, doc, flags=regex.VERBOSE)
sentences = [sentence.strip() for sentence in sentences if sentence.strip()]
def __call__(
self, doc: str, delimiters: List[Union[str, regex.Pattern]] = []
) -> List[str]:
# Ensure the regex pattern is applied last
delimiters.append(regex.compile(self.regex_pattern, flags=regex.VERBOSE))

sentences = [doc]
for delimiter in delimiters:
sentences_for_next_delimiter = []
for sentence in sentences:
if isinstance(delimiter, regex.Pattern):
sub_sentences = delimiter.split(sentence)
split_char = "" # No single character to append for regex pattern
else:
sub_sentences = sentence.split(delimiter)
split_char = delimiter
for i, sub_sentence in enumerate(sub_sentences):
if i < len(sub_sentences) - 1:
sub_sentence += split_char
if sub_sentence.strip():
sentences_for_next_delimiter.append(sub_sentence.strip())
sentences = sentences_for_next_delimiter
return sentences
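Based on the new `__call__` signature and the tests below, custom delimiters are applied in the order given, and the built-in sentence regex is always applied last; a short sketch (example strings are made up):

```python
from semantic_chunkers import RegexSplitter

splitter = RegexSplitter()

# Split on blank lines first, then on sentence boundaries.
parts = splitter(
    "First paragraph.\n\nSecond paragraph. Still the second paragraph.",
    delimiters=["\n\n"],
)
# -> ["First paragraph.", "Second paragraph.", "Still the second paragraph."]

# With no delimiters, only the built-in sentence-boundary regex is used.
sentences = splitter("This is a sentence. And another one! Yet another?")
# -> ["This is a sentence.", "And another one!", "Yet another?"]
```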
File renamed without changes.
48 changes: 48 additions & 0 deletions tests/unit/test_regex_chunker.py
@@ -0,0 +1,48 @@
import asyncio
import unittest

from semantic_chunkers.chunkers.regex import RegexChunker
from semantic_chunkers.schema import Chunk
from semantic_chunkers.utils import text


class TestRegexChunker(unittest.TestCase):
def setUp(self):
self.chunker = RegexChunker(max_chunk_tokens=10)

def test_call(self):
docs = ["This is a test. This is only a test."]
chunks_list = self.chunker(docs)
chunks = chunks_list[0]

self.assertIsInstance(chunks, list)
self.assertTrue(all(isinstance(chunk, Chunk) for chunk in chunks))
self.assertGreater(len(chunks), 0)
self.assertTrue(
all(
text.tiktoken_length(chunk.content) <= self.chunker.max_chunk_tokens
for chunk in chunks
)
)

def test_acall(self):
docs = ["This is a test. This is only a test."]

async def run_test():
chunks_list = await self.chunker.acall(docs)
chunks = chunks_list[0]
self.assertIsInstance(chunks, list)
self.assertTrue(all(isinstance(chunk, Chunk) for chunk in chunks))
self.assertGreater(len(chunks), 0)
self.assertTrue(
all(
text.tiktoken_length(chunk.content) <= self.chunker.max_chunk_tokens
for chunk in chunks
)
)

asyncio.run(run_test())


if __name__ == "__main__":
unittest.main()
55 changes: 55 additions & 0 deletions tests/unit/test_regex_splitter.py
@@ -0,0 +1,55 @@
import unittest

from semantic_chunkers.splitters.regex import RegexSplitter


class TestRegexSplitter(unittest.TestCase):
def setUp(self):
self.splitter = RegexSplitter()

def test_split_by_double_newline(self):
doc = "This is the first paragraph.\n\nThis is the second paragraph."
expected = ["This is the first paragraph.", "This is the second paragraph."]
result = self.splitter(doc, delimiters=["\n\n"])
self.assertEqual(result, expected)

def test_split_by_single_newline(self):
doc = "This is the first line.\nThis is the second line."
expected = ["This is the first line.", "This is the second line."]
result = self.splitter(doc, delimiters=["\n"])
self.assertEqual(result, expected)

def test_split_by_period(self):
doc = "This is the first sentence. This is the second sentence."
expected = ["This is the first sentence.", "This is the second sentence."]
result = self.splitter(doc, delimiters=["."])
self.assertEqual(result, expected)

def test_complex_split(self):
doc = """
First paragraph.\n\nSecond paragraph.\nThird line in second paragraph. Fourth line.\n\nFifth paragraph."""
expected = [
"First paragraph.",
"Second paragraph.",
"Third line in second paragraph.",
"Fourth line.",
"Fifth paragraph.",
]
result = self.splitter(doc, delimiters=["\n\n", "\n", "."])
self.assertEqual(result, expected)

def test_custom_delimiters(self):
doc = "First part|Second part|Third part"
expected = ["First part|", "Second part|", "Third part"]
result = self.splitter(doc, delimiters=["|"])
self.assertEqual(result, expected)

def test_regex_split(self):
doc = "This is a sentence. And another one! Yet another?"
expected = ["This is a sentence.", "And another one!", "Yet another?"]
result = self.splitter(doc)
self.assertEqual(result, expected)


if __name__ == "__main__":
unittest.main()