Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 22 additions & 82 deletions src/server/hset_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ extern "C" {
using namespace testing;
using namespace std;
using namespace util;
using namespace boost;
using namespace facade;

namespace dfly {
Expand Down Expand Up @@ -57,76 +58,13 @@ TEST_F(HSetFamilyTest, Basic) {
}

TEST_F(HSetFamilyTest, HSet) {
// Simulate HSET on mirror map
{
absl::flat_hash_map<string, string> mirror; // mirror

// Generate HSET commands and check how many new entries were added
absl::InsecureBitGen gen{};
while (mirror.size() < 600) {
vector<string> cmd = {"HSET", "hash"};
size_t new_values = 0;
for (int i = 0; i < 20; i++) {
string key = GetRandomHex(gen, 3);
string value = GetRandomHex(gen, 20, 10);
new_values += mirror.contains(key) ? 0 : 1;
mirror[key] = value;

cmd.emplace_back(key);
cmd.emplace_back(value);
}

EXPECT_THAT(Run(cmd), IntArg(new_values));
}

// Verify consistency
EXPECT_THAT(Run({"HLEN", "hash"}), IntArg(mirror.size()));
for (const auto& [key, value] : mirror)
EXPECT_EQ(Run({"HGET", "hash", key}), mirror[key]);
}

// HSet with same key twice
Run({"HSET", "hash", "key1", "value1", "key1", "value2"});
EXPECT_EQ(Run({"HGET", "hash", "key1"}), "value2");

// Wrong value cases
EXPECT_THAT(Run({"HSET", "key"}), ErrArg("wrong number of arguments"));
EXPECT_THAT(Run({"HSET", "key", "key"}), ErrArg("wrong number of arguments"));
EXPECT_THAT(Run({"HSET", "key", "key", "value", "key2"}), ErrArg("wrong number of arguments"));
}

TEST_F(HSetFamilyTest, HSetNX) {
// Should create new field
EXPECT_THAT(Run({"HSETNX", "hash", "key1", "value1"}), IntArg(1));
EXPECT_EQ(Run({"HGET", "hash", "key1"}), "value1");

// Should not overwrite
EXPECT_THAT(Run({"HSETNX", "hash", "key1", "value2"}), IntArg(0));
EXPECT_EQ(Run({"HGET", "hash", "key1"}), "value1");

// Wrong value cases
EXPECT_THAT(Run({"HSETNX", "key"}), ErrArg("wrong number of arguments"));
EXPECT_THAT(Run({"HSET", "key", "key"}), ErrArg("wrong number of arguments"));
}

// Listpack handles integers separately, so create a mix of different types
TEST_F(HSetFamilyTest, MixedTypes) {
absl::flat_hash_set<string> str_keys, int_keys;
for (int i = 0; i < 100; i++) {
auto key1 = absl::StrCat("s", i);
auto key2 = absl::StrCat("i", i);
Run({"HSET", "hash", key1, "VALUE", key2, "123456"});
str_keys.emplace(key1);
int_keys.emplace(key2);
}
string val(1024, 'b');

for (string_view key : str_keys)
EXPECT_EQ(Run({{"HGET", "hash", key}}), "VALUE");
EXPECT_EQ(1, CheckedInt({"hset", "large", "a", val}));
EXPECT_EQ(1, CheckedInt({"hlen", "large"}));
EXPECT_EQ(1024, CheckedInt({"hstrlen", "large", "a"}));

for (string_view key : int_keys) {
EXPECT_EQ(Run({{"HGET", "hash", key}}), "123456");
EXPECT_EQ(CheckedInt({"hincrby", "hash", key, "1"}), 123456 + 1);
}
EXPECT_EQ(1, CheckedInt({"hset", "small", "", "565323349817"}));
}

TEST_P(HestFamilyTestProtocolVersioned, Get) {
Expand Down Expand Up @@ -163,22 +101,24 @@ TEST_P(HestFamilyTestProtocolVersioned, Get) {
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "b", "2", "c", "3"));
}

TEST_F(HSetFamilyTest, HIncrBy) {
int total = 10;
// Check new field is created
EXPECT_EQ(CheckedInt({"hincrby", "key", "field", "10"}), 10);
EXPECT_EQ(Run({"hget", "key", "field"}), "10");
// Simulate multiple additions
for (int i = -100; i < 100; i += 7) {
total += i;
EXPECT_EQ(CheckedInt({"hincrby", "key", "field", to_string(i)}), total);
}
TEST_F(HSetFamilyTest, HSetNx) {
EXPECT_EQ(1, CheckedInt({"hsetnx", "key", "field", "val"}));
EXPECT_EQ(Run({"hget", "key", "field"}), "val");

EXPECT_EQ(0, CheckedInt({"hsetnx", "key", "field", "val2"}));
EXPECT_EQ(Run({"hget", "key", "field"}), "val");

EXPECT_EQ(1, CheckedInt({"hsetnx", "key", "field2", "val2"}));
EXPECT_EQ(Run({"hget", "key", "field2"}), "val2");

// check dict path
EXPECT_EQ(0, CheckedInt({"hsetnx", "key", "field2", string(512, 'a')}));
EXPECT_EQ(Run({"hget", "key", "field2"}), "val2");
}

// Overflow
Run({"hset", "key", "field2", to_string(numeric_limits<int64_t>::max() - 1)});
EXPECT_THAT(Run({"hincrby", "key", "field2", "2"}), ErrArg("would overflow"));
TEST_F(HSetFamilyTest, HIncr) {
EXPECT_EQ(10, CheckedInt({"hincrby", "key", "field", "10"}));

// Error case
Run({"hset", "key", "a", " 1"});
auto resp = Run({"hincrby", "key", "a", "10"});
EXPECT_THAT(resp, ErrArg("hash value is not an integer"));
Expand Down
7 changes: 1 addition & 6 deletions src/server/replica.cc
Original file line number Diff line number Diff line change
Expand Up @@ -693,16 +693,11 @@ error_code Replica::ConsumeRedisStream() {
};
RETURN_ON_ERR(exec_st_.SwitchErrorHandler(std::move(err_handler)));

CmdArgVec args_vector;
facade::CmdArgVec args_vector;

acks_fb_ = fb2::Fiber("redis_acks", &Replica::RedisStreamAcksFb, this);

while (true) {
// Yield if the fiber has been running for long.
if (base::CycleClock::ToUsec(ThisFiber::GetRunningTimeCycles()) > 1000) { // 1ms
ThisFiber::Yield();
}

auto response = ReadRespReply(&io_buf, /*copy_msg=*/false);
if (!response.has_value()) {
LOG_REPL_ERROR("Error in Redis Stream at phase "
Expand Down
5 changes: 2 additions & 3 deletions src/server/search/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@ if (NOT WITH_SEARCH)
return()
endif()

add_library(dfly_search_server aggregator.cc doc_accessors.cc doc_index.cc search_family.cc index_join.cc)
add_library(dfly_search_server aggregator.cc doc_accessors.cc doc_index.cc search_family.cc index_join.cc global_vector_index.cc global_vector_search.cc)
target_link_libraries(dfly_search_server dfly_transaction dragonfly_lib dfly_facade redis_lib jsonpath TRDP::jsoncons)


cxx_test(search_family_test dfly_test_lib LABELS DFLY)
cxx_test(aggregator_test dfly_test_lib LABELS DFLY)
cxx_test(index_join_test dfly_test_lib LABELS DFLY)

cxx_test(performance_test dfly_test_lib LABELS DFLY)

add_dependencies(check_dfly search_family_test aggregator_test index_join_test)
103 changes: 98 additions & 5 deletions src/server/search/doc_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "server/engine_shard_set.h"
#include "server/family_utils.h"
#include "server/search/doc_accessors.h"
#include "server/search/global_vector_index.h"
#include "server/server_state.h"

namespace dfly {
Expand Down Expand Up @@ -238,6 +239,11 @@ string_view ShardDocIndex::DocKeyIndex::Get(DocId id) const {
return keys_[id];
}

std::optional<ShardDocIndex::DocId> ShardDocIndex::DocKeyIndex::Find(string_view key) const {
auto it = ids_.find(key);
return it != ids_.end() ? std::make_optional(it->second) : std::nullopt;
}

size_t ShardDocIndex::DocKeyIndex::Size() const {
return ids_.size();
}
Expand Down Expand Up @@ -679,8 +685,11 @@ void ShardDocIndices::DropIndexCache(const dfly::ShardDocIndex& shard_doc_index)
}

void ShardDocIndices::RebuildAllIndices(const OpArgs& op_args) {
for (auto& [_, ptr] : indices_)
for (auto& [index_name, ptr] : indices_) {
ptr->Rebuild(op_args, &local_mr_);
// PoC: Also rebuild global vector indices
ptr->RebuildGlobalVectorIndices(index_name, op_args);
}
}

vector<string> ShardDocIndices::GetIndexNames() const {
Expand All @@ -693,17 +702,23 @@ vector<string> ShardDocIndices::GetIndexNames() const {

void ShardDocIndices::AddDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
DCHECK(IsIndexedKeyType(pv));
for (auto& [_, index] : indices_) {
if (index->Matches(key, pv.ObjType()))
for (auto& [index_name, index] : indices_) {
if (index->Matches(key, pv.ObjType())) {
index->AddDoc(key, db_cntx, pv);
// PoC: Also add to global vector index if document has vector fields
index->AddDocToGlobalVectorIndex(index_name, key, db_cntx, pv);
}
}
}

void ShardDocIndices::RemoveDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
DCHECK(IsIndexedKeyType(pv));
for (auto& [_, index] : indices_) {
if (index->Matches(key, pv.ObjType()))
for (auto& [index_name, index] : indices_) {
if (index->Matches(key, pv.ObjType())) {
// PoC: Remove from global vector index first (before local removal)
index->RemoveDocFromGlobalVectorIndex(index_name, key, db_cntx, pv);
index->RemoveDoc(key, db_cntx, pv);
}
}
}

Expand All @@ -719,4 +734,82 @@ SearchStats ShardDocIndices::GetStats() const {
return {GetUsedMemory(), indices_.size(), total_entries};
}

// PoC: Global vector index integration
void ShardDocIndex::AddDocToGlobalVectorIndex(std::string_view index_name, std::string_view key,
const DbContext& db_cntx, const PrimeValue& pv) {
if (!indices_)
return;

auto accessor = GetAccessor(db_cntx, pv);
auto local_id = key_index_.Find(key);
if (!local_id)
return;

GlobalDocId global_id{EngineShard::tlocal()->shard_id(), *local_id};

for (const auto& [field_ident, field_info] : base_->schema.fields) {
if (field_info.type == search::SchemaField::VECTOR &&
!(field_info.flags & search::SchemaField::NOINDEX)) {
if (auto vector_info = accessor->GetVector(field_ident); vector_info && vector_info->first) {
const auto& vparams =
std::get<search::SchemaField::VectorParams>(field_info.special_params);
auto global_index = GlobalVectorIndexRegistry::Instance().GetOrCreateVectorIndex(
index_name, field_info.short_name, vparams);
global_index->AddVector(global_id, key, vector_info->first.get());
}
}
}
}

void ShardDocIndex::RemoveDocFromGlobalVectorIndex(std::string_view index_name,
std::string_view key, const DbContext& db_cntx,
const PrimeValue& pv) {
if (!indices_)
return;

auto local_id = key_index_.Find(key);
if (!local_id)
return;

GlobalDocId global_id{EngineShard::tlocal()->shard_id(), *local_id};

for (const auto& [field_ident, field_info] : base_->schema.fields) {
if (field_info.type == search::SchemaField::VECTOR &&
!(field_info.flags & search::SchemaField::NOINDEX)) {
if (auto global_index = GlobalVectorIndexRegistry::Instance().GetVectorIndex(
index_name, field_info.short_name)) {
global_index->RemoveVector(global_id, key);
}
}
}
}

void ShardDocIndex::RebuildGlobalVectorIndices(std::string_view index_name, const OpArgs& op_args) {
if (!indices_)
return;

auto cb = [this, index_name](string_view key, const BaseAccessor& doc) {
auto local_id = key_index_.Find(key);
if (!local_id)
return;

GlobalDocId global_id{EngineShard::tlocal()->shard_id(), *local_id};

for (const auto& [field_ident, field_info] : base_->schema.fields) {
if (field_info.type == search::SchemaField::VECTOR &&
!(field_info.flags & search::SchemaField::NOINDEX)) {
if (auto vector_info = doc.GetVector(field_ident); vector_info && vector_info->first) {
const auto& vparams =
std::get<search::SchemaField::VectorParams>(field_info.special_params);
auto global_index = GlobalVectorIndexRegistry::Instance().GetOrCreateVectorIndex(
index_name, field_info.short_name, vparams);
global_index->AddVector(global_id, key, vector_info->first.get());
}
}
}
};

TraverseAllMatching(*base_, op_args, cb);
}

} // namespace dfly
16 changes: 13 additions & 3 deletions src/server/search/doc_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
namespace dfly {

struct BaseAccessor;
class GlobalVectorIndex; // PoC: Forward declaration for global vector index

using SearchDocData = absl::flat_hash_map<std::string /*field*/, search::SortableValue /*value*/>;
using Synonyms = search::Synonyms;
Expand Down Expand Up @@ -222,6 +223,7 @@ class ShardDocIndex {
std::optional<DocId> Remove(std::string_view key);

std::string_view Get(DocId id) const;
std::optional<DocId> Find(std::string_view key) const; // PoC: Find DocId by key
size_t Size() const;

// Get const reference to the internal ids map
Expand Down Expand Up @@ -287,13 +289,21 @@ class ShardDocIndex {
return key_index_;
}

private:
// Clears internal data. Traverses all matching documents and assigns ids.
void Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr);
// PoC: Global vector index support
void AddDocToGlobalVectorIndex(std::string_view index_name, std::string_view key,
const DbContext& db_cntx, const PrimeValue& pv);
void RemoveDocFromGlobalVectorIndex(std::string_view index_name, std::string_view key,
const DbContext& db_cntx, const PrimeValue& pv);
void RebuildGlobalVectorIndices(std::string_view index_name, const OpArgs& op_args);

// PoC: Public access to LoadEntry for global search coordinator
using LoadedEntry = std::pair<std::string_view, std::unique_ptr<BaseAccessor>>;
std::optional<LoadedEntry> LoadEntry(search::DocId id, const OpArgs& op_args) const;

private:
// Clears internal data. Traverses all matching documents and assigns ids.
void Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr);

// Behaviour identical to SortIndex::Sort for non-sortable fields that need to be fetched first
std::vector<search::SortableValue> KeepTopKSorted(std::vector<DocId>* ids, size_t limit,
const SearchParams::SortOption& sort,
Expand Down
Loading
Loading