ragas/src/ragas/testset/transforms/splitters/headline.py (23 changes: 14 additions & 9 deletions)
@@ -1,8 +1,13 @@
 import typing as t
 from dataclasses import dataclass
+import tiktoken

 from ragas.testset.graph import Node, NodeType, Relationship
 from ragas.testset.transforms.base import Splitter
+from ragas.utils import num_tokens_from_string
+
+
+DEFAULT_TOKENIZER = tiktoken.get_encoding("o200k_base")


 @dataclass
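For context on the new module-level tokenizer: tiktoken encodes text into a list of integer token IDs and decodes such a list back into text, so every length check below counts subword tokens rather than whitespace-separated words. A minimal sketch of the round-trip (the example string is illustrative, not from the PR):

    import tiktoken

    enc = tiktoken.get_encoding("o200k_base")  # same encoding as DEFAULT_TOKENIZER

    ids = enc.encode("Headline splitting with subword tokenizers")
    assert enc.decode(ids) == "Headline splitting with subword tokenizers"

    # Subword token counts are usually higher than whitespace word counts,
    # which is why the old chunk.split() heuristic tended to under-count.
    print(len(ids))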
@@ -15,27 +20,27 @@ def adjust_chunks(self, chunks):
         current_chunk = ""

         for chunk in chunks:
-            chunk_tokens = chunk.split()
+            chunk_tokens = DEFAULT_TOKENIZER.encode(chunk)

             # Split chunks that are over max_tokens
             while len(chunk_tokens) > self.max_tokens:
-                adjusted_chunks.append(" ".join(chunk_tokens[: self.max_tokens]))
-                chunk_tokens = chunk_tokens[self.max_tokens :]
+                adjusted_chunks.append(DEFAULT_TOKENIZER.decode(chunk_tokens[:self.max_tokens]))
+                chunk_tokens = chunk_tokens[self.max_tokens:]

             # Handle chunks that are under min_tokens
+            chunk_str = DEFAULT_TOKENIZER.decode(chunk_tokens)
             if len(chunk_tokens) < self.min_tokens:
                 if current_chunk:
-                    current_chunk += " " + " ".join(chunk_tokens)
-                    if len(current_chunk.split()) >= self.min_tokens:
+                    current_chunk += " " + chunk_str
+                    if num_tokens_from_string(current_chunk, encoding_name=DEFAULT_TOKENIZER.name) >= self.min_tokens:
                         adjusted_chunks.append(current_chunk)
                         current_chunk = ""
                 else:
-                    current_chunk = " ".join(chunk_tokens)
+                    current_chunk = chunk_str
             else:
                 if current_chunk:
                     adjusted_chunks.append(current_chunk)
                     current_chunk = ""
-                adjusted_chunks.append(" ".join(chunk_tokens))
+                adjusted_chunks.append(chunk_str)

         # Append any remaining chunk
         if current_chunk:
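The heart of the new loop is a fixed-size window over token IDs: an oversized chunk is sliced every max_tokens tokens and each slice is decoded back to text, so the original spacing and newlines survive, where the old " ".join(chunk.split()) round-trip collapsed them. A simplified standalone sketch of that slicing idea, assuming the same o200k_base encoding (max_tokens=50 is an arbitrary illustrative value, and the real loop hands the final short remainder to the min_tokens logic instead of emitting it directly):

    import tiktoken

    enc = tiktoken.get_encoding("o200k_base")
    max_tokens = 50  # illustrative; the real value is the dataclass field

    def split_oversized(chunk: str) -> list[str]:
        # Slice the token sequence into windows of at most max_tokens
        # and decode each window back into text.
        ids = enc.encode(chunk)
        return [enc.decode(ids[i : i + max_tokens]) for i in range(0, len(ids), max_tokens)]

    pieces = split_oversized("some long section text " * 40)
    # For ASCII input the decoded windows concatenate back exactly.
    assert "".join(pieces) == "some long section text " * 40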
@@ -52,7 +57,7 @@ async def split(self, node: Node) -> t.Tuple[t.List[Node], t.List[Relationship]]
         if headlines is None:
             raise ValueError("'headlines' property not found in this node")

-        if len(text.split()) < self.min_tokens:
+        if num_tokens_from_string(text, encoding_name=DEFAULT_TOKENIZER.name) < self.min_tokens:
             return [node], []
         # create the chunks for the different sections
         indices = [0]
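Both length gates now route through ragas.utils.num_tokens_from_string. Its body isn't shown in this diff; a plausible, hypothetical reconstruction consistent with the call sites (a string plus a tiktoken encoding_name keyword) would be:

    import tiktoken

    def num_tokens_from_string(string: str, encoding_name: str) -> int:
        # Hypothetical sketch of the helper: look up the named encoding
        # and count the tokens it produces for the input string.
        encoding = tiktoken.get_encoding(encoding_name)
        return len(encoding.encode(string))

Since DEFAULT_TOKENIZER.name is "o200k_base", the check in split would then be equivalent to len(DEFAULT_TOKENIZER.encode(text)) < self.min_tokens.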