chore: regex chunker
simjak committed Jul 19, 2024
1 parent 30182ba commit 54a65ff
Showing 10 changed files with 390 additions and 18 deletions.
148 changes: 140 additions & 8 deletions docs/00-chunkers-intro.ipynb

Large diffs are not rendered by default.

148 changes: 140 additions & 8 deletions docs/02-chunkers-async.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion semantic_chunkers/__init__.py
@@ -2,6 +2,7 @@
     BaseChunker,
     ConsecutiveChunker,
     CumulativeChunker,
+    RegexChunker,
     StatisticalChunker,
 )
 from semantic_chunkers.splitters import BaseSplitter, RegexSplitter
@@ -13,6 +14,7 @@
"StatisticalChunker",
"RegexSplitter",
"BaseSplitter",
"RegexChunker",
]

__version__ = "0.0.8"
__version__ = "0.0.9"
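With these two hunks, the new chunker is exported from the package root and the version is bumped to 0.0.9. A minimal import sketch (not part of the commit itself):

    from semantic_chunkers import RegexChunker  # now available at the top level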
2 changes: 2 additions & 0 deletions semantic_chunkers/chunkers/__init__.py
@@ -1,11 +1,13 @@
 from semantic_chunkers.chunkers.base import BaseChunker
 from semantic_chunkers.chunkers.consecutive import ConsecutiveChunker
 from semantic_chunkers.chunkers.cumulative import CumulativeChunker
+from semantic_chunkers.chunkers.regex import RegexChunker
 from semantic_chunkers.chunkers.statistical import StatisticalChunker

 __all__ = [
     "BaseChunker",
     "ConsecutiveChunker",
     "CumulativeChunker",
     "StatisticalChunker",
+    "RegexChunker",
 ]
2 changes: 1 addition & 1 deletion semantic_chunkers/chunkers/base.py
@@ -10,7 +10,7 @@

 class BaseChunker(BaseModel):
     name: str
-    encoder: BaseEncoder
+    encoder: BaseEncoder | None
     splitter: BaseSplitter

     class Config:
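The hunk above loosens `encoder` to be optional at the base so that encoder-free chunkers, like the `RegexChunker` added below (which passes `encoder=None`), satisfy the model; the encoder-dependent subclasses that follow re-declare the field as required. A minimal sketch of this pydantic narrowing pattern, using stand-in types rather than the library's own:

    from typing import Optional

    from pydantic import BaseModel


    class Encoder:  # stand-in for BaseEncoder
        pass


    class BaseSketch(BaseModel):
        encoder: Optional[Encoder]  # may be None at the base

        class Config:
            arbitrary_types_allowed = True  # permit the custom Encoder type


    class NarrowedSketch(BaseSketch):
        encoder: Encoder  # re-declared: None no longer validates


    BaseSketch(encoder=None)            # accepted
    NarrowedSketch(encoder=Encoder())   # accepted; encoder=None would raise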
2 changes: 2 additions & 0 deletions semantic_chunkers/chunkers/consecutive.py
@@ -15,6 +15,8 @@ class ConsecutiveChunker(BaseChunker):
Called "consecutive sim chunker" because we check the similarities of consecutive document embeddings (compare ith to i+1th document embedding).
"""

encoder: BaseEncoder

def __init__(
self,
encoder: BaseEncoder,
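The docstring above is the whole algorithm in one line: embed each split, compare the i-th embedding with the (i+1)-th, and start a new chunk where similarity drops. A schematic NumPy sketch of that rule (illustrative only; the function name and threshold are assumptions, not the library's internals):

    import numpy as np


    def consecutive_split_points(embeddings: np.ndarray, threshold: float = 0.8) -> list[int]:
        # Cosine similarity between each embedding and the next one;
        # a new chunk starts wherever similarity falls below the threshold.
        norms = np.linalg.norm(embeddings, axis=1)
        sims = (embeddings[:-1] * embeddings[1:]).sum(axis=1) / (norms[:-1] * norms[1:])
        return [i + 1 for i, s in enumerate(sims) if s < threshold]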
2 changes: 2 additions & 0 deletions semantic_chunkers/chunkers/cumulative.py
@@ -16,6 +16,8 @@ class CumulativeChunker(BaseChunker):
     embeddings of cumulative concatenated documents with the next document.
     """

+    encoder: BaseEncoder
+
     def __init__(
         self,
         encoder: BaseEncoder,
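The cumulative variant described above instead compares the embedding of everything accumulated so far with the next document. A schematic sketch (illustrative only; `embed` and `cos_sim` are assumed callables, not library APIs):

    def cumulative_chunks(splits, embed, cos_sim, threshold=0.8):
        # Compare the embedding of the concatenated current chunk with
        # the embedding of the next split; close the chunk on a drop.
        if not splits:
            return []
        chunks, current = [], [splits[0]]
        for nxt in splits[1:]:
            if cos_sim(embed(" ".join(current)), embed(nxt)) < threshold:
                chunks.append(current)
                current = [nxt]
            else:
                current.append(nxt)
        chunks.append(current)
        return chunks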
50 changes: 50 additions & 0 deletions semantic_chunkers/chunkers/regex.py
@@ -0,0 +1,50 @@
+import asyncio
+from typing import List
+
+from semantic_chunkers.chunkers.base import BaseChunker
+from semantic_chunkers.schema import Chunk
+from semantic_chunkers.splitters import RegexSplitter
+from semantic_chunkers.utils import text
+
+
+class RegexChunker(BaseChunker):
+    def __init__(self, max_chunk_tokens: int = 300):
+        super().__init__(name="regex_chunker", encoder=None, splitter=RegexSplitter())
+        self.max_chunk_tokens = max_chunk_tokens
+
+    def __call__(self, docs: list[str]) -> List[List[Chunk]]:
+        chunks = []
+        current_chunk = Chunk(
+            splits=[],
+            metadata={},
+        )
+        current_chunk.token_count = 0
+
+        for doc in docs:
+            regex_splitter = RegexSplitter()
+            sentences = regex_splitter(doc)
+            for sentence in sentences:
+                sentence_token_count = text.tiktoken_length(sentence)
+
+                if (
+                    current_chunk.token_count + sentence_token_count
+                    > self.max_chunk_tokens
+                ):
+                    chunks.append(current_chunk)
+                    current_chunk = Chunk(splits=[])
+                    current_chunk.token_count = 0
+
+                current_chunk.splits.append(sentence)
+                if current_chunk.token_count is None:
+                    current_chunk.token_count = 0
+                current_chunk.token_count += sentence_token_count
+
+        # Last chunk
+        if current_chunk.splits:
+            chunks.append(current_chunk)
+
+        return [chunks]
+
+    async def acall(self, docs: list[str]) -> List[List[Chunk]]:
+        chunks = await asyncio.to_thread(self.__call__, docs)
+        return chunks
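Going by the signatures in this new file, usage would look roughly like the following (a sketch: the sample text is arbitrary, and the nested result shape follows the `List[List[Chunk]]` annotation above):

    import asyncio

    from semantic_chunkers import RegexChunker

    chunker = RegexChunker(max_chunk_tokens=50)
    result = chunker(["First sentence. Second sentence. Third sentence."])
    for chunk in result[0]:  # the outer list wraps the whole call
        print(chunk.splits, chunk.token_count)

    # The async variant offloads the same work to a thread via asyncio.to_thread:
    result = asyncio.run(chunker.acall(["Another doc. With more sentences."]))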
2 changes: 2 additions & 0 deletions semantic_chunkers/chunkers/statistical.py
@@ -44,6 +44,8 @@ def __str__(self):


 class StatisticalChunker(BaseChunker):
+    encoder: BaseEncoder
+
     def __init__(
         self,
         encoder: BaseEncoder,
48 changes: 48 additions & 0 deletions tests/unit/test_regex_chunker.py
@@ -0,0 +1,48 @@
+import asyncio
+import unittest
+
+from semantic_chunkers.chunkers.regex import RegexChunker
+from semantic_chunkers.schema import Chunk
+from semantic_chunkers.utils import text
+
+
+class TestRegexChunker(unittest.TestCase):
+    def setUp(self):
+        self.chunker = RegexChunker(max_chunk_tokens=10)
+
+    def test_call(self):
+        docs = ["This is a test. This is only a test."]
+        chunks_list = self.chunker(docs)
+        chunks = chunks_list[0]
+
+        self.assertIsInstance(chunks, list)
+        self.assertTrue(all(isinstance(chunk, Chunk) for chunk in chunks))
+        self.assertGreater(len(chunks), 0)
+        self.assertTrue(
+            all(
+                text.tiktoken_length(chunk.content) <= self.chunker.max_chunk_tokens
+                for chunk in chunks
+            )
+        )
+
+    def test_acall(self):
+        docs = ["This is a test. This is only a test."]
+
+        async def run_test():
+            chunks_list = await self.chunker.acall(docs)
+            chunks = chunks_list[0]
+            self.assertIsInstance(chunks, list)
+            self.assertTrue(all(isinstance(chunk, Chunk) for chunk in chunks))
+            self.assertGreater(len(chunks), 0)
+            self.assertTrue(
+                all(
+                    text.tiktoken_length(chunk.content) <= self.chunker.max_chunk_tokens
+                    for chunk in chunks
+                )
+            )
+
+        asyncio.run(run_test())
+
+
+if __name__ == "__main__":
+    unittest.main()
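Because the module ends in a `__main__` guard, the suite runs directly with `python tests/unit/test_regex_chunker.py`, or through the standard runner with `python -m unittest tests.unit.test_regex_chunker`.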
