@@ -5,6 +5,7 @@
namespace ManagedCode.GraphRag.Benchmarks.Chunking;

[MemoryDiagnoser]
+[HideColumns("Error", "StdDev", "RatioSD")]
public class TokenTextChunkerBenchmarks
{
private TokenTextChunker _chunker = null!;
@@ -36,7 +37,7 @@ public void Setup()
_largeDocument = new[] { new ChunkSlice("doc1", GeneratePlainTextDocument(1_000_000)) };
}

-[Benchmark]
+[Benchmark(Baseline = true)]
public IReadOnlyList<TextChunk> ChunkSmallDocument()
{
return _chunker.Chunk(_smallDocument, _config);
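With `Baseline = true`, BenchmarkDotNet reports every other benchmark in the class as a ratio against `ChunkSmallDocument`, and `[HideColumns]` trims the listed statistics columns from the summary table. A hypothetical runner for trying this locally, since the PR does not include one:

```csharp
// Hypothetical Program.cs; not part of this PR. Assumes the BenchmarkDotNet
// package is referenced by the benchmarks project.
using BenchmarkDotNet.Running;
using ManagedCode.GraphRag.Benchmarks.Chunking;

// Runs every [Benchmark] method in the class and prints the summary table,
// including the Ratio column introduced by Baseline = true.
BenchmarkRunner.Run<TokenTextChunkerBenchmarks>();
```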
44 changes: 30 additions & 14 deletions src/ManagedCode.GraphRag/Chunking/TokenTextChunker.cs
@@ -1,3 +1,5 @@
+using System.Buffers;
+using System.Runtime.InteropServices;
using GraphRag.Config;
using GraphRag.Tokenization;

@@ -12,48 +14,62 @@ public IReadOnlyList<TextChunk> Chunk(IReadOnlyList<ChunkSlice> slices, ChunkingConfig config)

if (slices.Count == 0)
{
-return Array.Empty<TextChunk>();
+return [];
}

var tokenizer = TokenizerRegistry.GetTokenizer(config.EncodingModel);
var flattened = new List<(int SliceIndex, int Token)>();

for (var index = 0; index < slices.Count; index++)
{
var slice = slices[index];
-var encoded = tokenizer.EncodeToIds(slice.Text);
-foreach (var token in encoded)
+var encoded = tokenizer.EncodeToIds(slice.Text.AsSpan());
+for (var i = 0; i < encoded.Count; i++)
{
+var token = encoded[i];
flattened.Add((index, token));
}
}

if (flattened.Count == 0)
{
-return Array.Empty<TextChunk>();
+return [];
}

var chunkSize = Math.Max(1, config.Size);
var overlap = Math.Clamp(config.Overlap, 0, chunkSize - 1);
-var results = new List<TextChunk>();
+var step = chunkSize - overlap;
+var estimatedChunks = (flattened.Count + step - 1) / step;
+var results = new List<TextChunk>(estimatedChunks);
+
+var documentIds = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

var start = 0;
while (start < flattened.Count)
{
var end = Math.Min(flattened.Count, start + chunkSize);
-var chunkTokens = flattened.GetRange(start, end - start);
-var tokenValues = new int[chunkTokens.Count];
-for (var i = 0; i < chunkTokens.Count; i++)
+var chunkTokens = CollectionsMarshal.AsSpan(flattened).Slice(start, end - start);
+var tokenValues = ArrayPool<int>.Shared.Rent(chunkTokens.Length);
+documentIds.Clear();
+
+var lastSliceIndex = -1;
+for (var i = 0; i < chunkTokens.Length; i++)
{
+var sliceIndex = chunkTokens[i].SliceIndex;
tokenValues[i] = chunkTokens[i].Token;
+
+if (sliceIndex != lastSliceIndex)
+{
+documentIds.Add(slices[sliceIndex].DocumentId);
+lastSliceIndex = sliceIndex;
+}
}

-var decoded = tokenizer.Decode(tokenValues);
-var documentIds = chunkTokens
-.Select(tuple => slices[tuple.SliceIndex].DocumentId)
-.Distinct(StringComparer.OrdinalIgnoreCase)
-.ToArray();
-
-results.Add(new TextChunk(documentIds, decoded, tokenValues.Length));
+var decoded = tokenizer.Decode(new ArraySegment<int>(tokenValues, 0, chunkTokens.Length));
+results.Add(new TextChunk(documentIds.ToList(), decoded, chunkTokens.Length));
+
+ArrayPool<int>.Shared.Return(tokenValues);

if (end >= flattened.Count)
{
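The preallocation added above uses simple sliding-window arithmetic: each chunk advances by `step = chunkSize - overlap` tokens, so `ceil(tokenCount / step)` is a cheap upper bound on the chunk count (the loop's advance logic is collapsed in this view, but the overlap tests below pin down the step). A standalone sketch of that estimate, extracted for illustration:

```csharp
// Capacity estimate used to presize the results list; a slight over-estimate
// is harmless. Example: 26 tokens, chunkSize = 10, overlap = 2 gives step = 8
// and an estimate of ceil(26 / 8) = 4, while the loop actually emits three
// windows, [0..10), [8..18), [16..26), because it stops once a window reaches
// the end of the token stream.
static int EstimateChunks(int tokenCount, int chunkSize, int overlap)
{
    chunkSize = Math.Max(1, chunkSize);              // same guard as the diff
    overlap = Math.Clamp(overlap, 0, chunkSize - 1); // keep the step positive
    var step = chunkSize - overlap;
    return (tokenCount + step - 1) / step;           // integer ceiling division
}
```

Note also why the `Decode` call now wraps the rented buffer in an `ArraySegment`: `ArrayPool<int>.Shared.Rent` may hand back an array larger than requested, so only the first `chunkTokens.Length` elements are meaningful, and the same reasoning swaps `tokenValues.Length` for `chunkTokens.Length` in the `TextChunk` constructor. A `try/finally` around the `Return` would additionally guard against a throwing `Decode`, at the cost of some noise.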
144 changes: 144 additions & 0 deletions tests/ManagedCode.GraphRag.Tests/Chunking/TokenTextChunkerTests.cs
@@ -8,6 +8,12 @@ namespace ManagedCode.GraphRag.Tests.Chunking;
public sealed class TokenTextChunkerTests
{
private readonly TokenTextChunker _chunker = new();
+private readonly ChunkingConfig _defaultConfig = new()
+{
+Size = 40,
+Overlap = 10,
+EncodingModel = TokenizerDefaults.DefaultEncoding
+};

[Fact]
public void Chunk_RespectsTokenBudget()
@@ -63,4 +69,142 @@ public void Chunk_CombinesDocumentIdentifiersAcrossSlices()
Assert.Contains(chunks, chunk => chunk.DocumentIds.Contains("doc-1"));
Assert.Contains(chunks, chunk => chunk.DocumentIds.Contains("doc-2"));
}

+[Fact]
+public void Chunk_OverlapProducesSharedTokensBetweenAdjacentChunks()
+{
+var tokenizer = TokenizerRegistry.GetTokenizer(TokenizerDefaults.DefaultEncoding);
+const string text = "The quick brown fox jumps over the lazy dog and continues running through the forest until it reaches the river where it stops to drink some water.";
+var slices = new[] { new ChunkSlice("doc-1", text) };
+
+var config = new ChunkingConfig
+{
+Size = 20,
+Overlap = 5,
+EncodingModel = TokenizerDefaults.DefaultEncoding
+};
+
+var chunks = _chunker.Chunk(slices, config);
+
+Assert.True(chunks.Count >= 2, "Need at least 2 chunks to verify overlap");
+
+for (var i = 0; i < chunks.Count - 1; i++)
+{
+var currentChunkTokens = tokenizer.EncodeToIds(chunks[i].Text);
+var nextChunkTokens = tokenizer.EncodeToIds(chunks[i + 1].Text);
+
+var lastTokensOfCurrent = currentChunkTokens.TakeLast(config.Overlap).ToArray();
+var firstTokensOfNext = nextChunkTokens.Take(config.Overlap).ToArray();
+
+Assert.Equal(lastTokensOfCurrent, firstTokensOfNext);
+}
+}
+
+[Fact]
+public void Chunk_EmptySlicesReturnsEmptyResult()
+{
+var slices = Array.Empty<ChunkSlice>();
+
+var chunks = _chunker.Chunk(slices, _defaultConfig);
+
+Assert.Empty(chunks);
+}
+
+[Fact]
+public void Chunk_SlicesWithEmptyTextReturnsEmptyResult()
+{
+var slices = new[] { new ChunkSlice("doc-1", string.Empty) };
+
+var chunks = _chunker.Chunk(slices, _defaultConfig);
+
+Assert.Empty(chunks);
+}
+
+[Fact]
+public void Chunk_NullSlicesThrowsArgumentNullException()
+{
+Assert.Throws<ArgumentNullException>(() => _chunker.Chunk(null!, _defaultConfig));
+}
+
+[Fact]
+public void Chunk_NullConfigThrowsArgumentNullException()
+{
+var slices = new[] { new ChunkSlice("doc-1", "Some text") };
+
+Assert.Throws<ArgumentNullException>(() => _chunker.Chunk(slices, null!));
+}
+
+[Fact]
+public void Chunk_ZeroOverlapProducesNonOverlappingChunks()
+{
+var tokenizer = TokenizerRegistry.GetTokenizer(TokenizerDefaults.DefaultEncoding);
+const string text = "The quick brown fox jumps over the lazy dog and continues running through the forest until it reaches the river.";
+var slices = new[] { new ChunkSlice("doc-1", text) };
+
+var config = new ChunkingConfig
+{
+Size = 15,
+Overlap = 0,
+EncodingModel = TokenizerDefaults.DefaultEncoding
+};
+
+var chunks = _chunker.Chunk(slices, config);
+Assert.True(chunks.Count >= 2, "Need at least 2 chunks to verify zero overlap");
+
+var allChunkTokens = chunks
+.SelectMany(c => tokenizer.EncodeToIds(c.Text))
+.ToList();
+
+var originalTokens = tokenizer.EncodeToIds(text);
+
+Assert.Equal(originalTokens.Count, allChunkTokens.Count);
+}
+
+[Fact]
+public void Chunk_InputSmallerThanChunkSizeReturnsSingleChunk()
+{
+const string shortText = "Hello world";
+var slices = new[] { new ChunkSlice("doc-1", shortText) };
+
+var config = new ChunkingConfig
+{
+Size = 100,
+Overlap = 10,
+EncodingModel = TokenizerDefaults.DefaultEncoding
+};
+
+var chunks = _chunker.Chunk(slices, config);
+
+Assert.Single(chunks);
+Assert.Equal(shortText, chunks[0].Text);
+}
+
+[Fact]
+public void Chunk_ExactBoundaryProducesExpectedChunkCount()
+{
+var tokenizer = TokenizerRegistry.GetTokenizer(TokenizerDefaults.DefaultEncoding);
+
+const int chunkSize = 10;
+const int overlap = 2;
+const int step = chunkSize - overlap;
+
+var targetTokenCount = step * 3 + overlap;
+var words = Enumerable.Range(0, targetTokenCount * 2).Select(i => "word").ToArray();
+var text = string.Join(" ", words);
+
+var actualTokens = tokenizer.EncodeToIds(text);
+var slices = new[] { new ChunkSlice("doc-1", text) };
+
+var config = new ChunkingConfig
+{
+Size = chunkSize,
+Overlap = overlap,
+EncodingModel = TokenizerDefaults.DefaultEncoding
+};
+
+var chunks = _chunker.Chunk(slices, config);
+
+Assert.True(chunks.Count >= 2, "Should produce multiple chunks");
+Assert.All(chunks.SkipLast(1), chunk => Assert.Equal(chunkSize, chunk.TokenCount));
+}
}
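For orientation, a minimal end-to-end usage sketch assembled from the APIs these tests exercise; the types come from the diff, while the `input.txt` path and top-level-statement form are made up for the example:

```csharp
// Minimal usage sketch, assuming the types shown in this PR:
// TokenTextChunker, ChunkSlice, ChunkingConfig, TokenizerDefaults, TextChunk.
var chunker = new TokenTextChunker();
var config = new ChunkingConfig
{
    Size = 40,     // token budget per chunk
    Overlap = 10,  // tokens shared between adjacent chunks
    EncodingModel = TokenizerDefaults.DefaultEncoding
};

var slices = new[] { new ChunkSlice("doc-1", File.ReadAllText("input.txt")) };
var chunks = chunker.Chunk(slices, config);

foreach (var chunk in chunks)
{
    Console.WriteLine($"{chunk.TokenCount} tokens from [{string.Join(", ", chunk.DocumentIds)}]");
}
```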