Skip to content

Commit 0191ec6

Browse files
Add refine for RaBitQ (#1162)
* refine for rabitq Signed-off-by: Alexandr Guzhva <[email protected]> * fix formatting Signed-off-by: Alexandr Guzhva <[email protected]> * fix review Signed-off-by: Alexandr Guzhva <[email protected]> * enable IvfRaBitQ in Milvus Signed-off-by: Alexandr Guzhva <[email protected]> * change the order of RR and Refine Signed-off-by: Alexandr Guzhva <[email protected]> --------- Signed-off-by: Alexandr Guzhva <[email protected]>
1 parent d3d1305 commit 0191ec6

File tree

11 files changed

+492
-241
lines changed

11 files changed

+492
-241
lines changed

include/knowhere/index/index_table.h

+5
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ static std::set<std::pair<std::string, VecType>> legal_knowhere_index = {
5858
{IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_BFLOAT16},
5959
// {IndexEnum::INDEX_FAISS_IVFSQ_CC, VecType::VECTOR_INT8},
6060

61+
{IndexEnum::INDEX_FAISS_IVFRABITQ, VecType::VECTOR_FLOAT},
62+
{IndexEnum::INDEX_FAISS_IVFRABITQ, VecType::VECTOR_FLOAT16},
63+
{IndexEnum::INDEX_FAISS_IVFRABITQ, VecType::VECTOR_BFLOAT16},
64+
6165
// gpu index
6266
{IndexEnum::INDEX_GPU_BRUTEFORCE, VecType::VECTOR_FLOAT},
6367
{IndexEnum::INDEX_GPU_IVFFLAT, VecType::VECTOR_FLOAT},
@@ -108,6 +112,7 @@ static std::set<std::string> legal_support_mmap_knowhere_index = {
108112
IndexEnum::INDEX_FAISS_SCANN,
109113
IndexEnum::INDEX_FAISS_IVFSQ8,
110114
IndexEnum::INDEX_FAISS_IVFSQ_CC,
115+
IndexEnum::INDEX_FAISS_IVFRABITQ,
111116

112117
// hnsw
113118
IndexEnum::INDEX_HNSW,

src/index/hnsw/faiss_hnsw.cc

+16-205
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "index/hnsw/impl/IndexConditionalWrapper.h"
4141
#include "index/hnsw/impl/IndexHNSWWrapper.h"
4242
#include "index/hnsw/impl/IndexWrapperCosine.h"
43+
#include "index/refine/refine_utils.h"
4344
#include "io/memory_io.h"
4445
#include "knowhere/bitsetview_idselector.h"
4546
#include "knowhere/comp/index_param.h"
@@ -2034,205 +2035,6 @@ class BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback : public HNSWI
20342035
}
20352036
};
20362037

2037-
namespace {
2038-
2039-
// a supporting function
2040-
expected<faiss::ScalarQuantizer::QuantizerType>
2041-
get_sq_quantizer_type(const std::string& sq_type) {
2042-
std::map<std::string, faiss::ScalarQuantizer::QuantizerType> sq_types = {
2043-
{"sq6", faiss::ScalarQuantizer::QT_6bit},
2044-
{"sq8", faiss::ScalarQuantizer::QT_8bit},
2045-
{"fp16", faiss::ScalarQuantizer::QT_fp16},
2046-
{"bf16", faiss::ScalarQuantizer::QT_bf16},
2047-
{"int8", faiss::ScalarQuantizer::QT_8bit_direct_signed}};
2048-
2049-
// todo: tolower
2050-
auto sq_type_tolower = str_to_lower(sq_type);
2051-
auto itr = sq_types.find(sq_type_tolower);
2052-
if (itr == sq_types.cend()) {
2053-
return expected<faiss::ScalarQuantizer::QuantizerType>::Err(
2054-
Status::invalid_args, fmt::format("invalid scalar quantizer type ({})", sq_type_tolower));
2055-
}
2056-
2057-
return itr->second;
2058-
}
2059-
2060-
/*
2061-
// checks whether an index contains a refiner, suitable for a given data format
2062-
std::optional<bool> whether_refine_is_datatype(
2063-
const faiss::Index* index,
2064-
const DataFormatEnum data_format
2065-
) {
2066-
if (index == nullptr) {
2067-
return {};
2068-
}
2069-
2070-
const faiss::IndexRefine* const index_refine = dynamic_cast<const faiss::IndexRefine*>(index);
2071-
if (index_refine == nullptr) {
2072-
return false;
2073-
}
2074-
2075-
switch(data_format) {
2076-
case DataFormatEnum::fp32:
2077-
return (dynamic_cast<const faiss::IndexFlat*>(index_refine->refine_index) != nullptr);
2078-
case DataFormatEnum::fp16:
2079-
{
2080-
const auto* const index_sq = dynamic_cast<const
2081-
faiss::IndexScalarQuantizer*>(index_refine->refine_index); return (index_sq != nullptr && index_sq->sq.qtype ==
2082-
faiss::ScalarQuantizer::QT_fp16);
2083-
}
2084-
case DataFormatEnum::bf16:
2085-
{
2086-
const auto* const index_sq = dynamic_cast<const
2087-
faiss::IndexScalarQuantizer*>(index_refine->refine_index); return (index_sq != nullptr && index_sq->sq.qtype ==
2088-
faiss::ScalarQuantizer::QT_bf16);
2089-
}
2090-
default:
2091-
return {};
2092-
}
2093-
}
2094-
*/
2095-
2096-
expected<bool>
2097-
is_flat_refine(const std::optional<std::string>& refine_type) {
2098-
// grab a type of a refine index
2099-
if (!refine_type.has_value()) {
2100-
return true;
2101-
};
2102-
2103-
// todo: tolower
2104-
std::string refine_type_tolower = str_to_lower(refine_type.value());
2105-
if (refine_type_tolower == "fp32" || refine_type_tolower == "flat") {
2106-
return true;
2107-
};
2108-
2109-
// parse
2110-
auto refine_sq_type = get_sq_quantizer_type(refine_type_tolower);
2111-
if (!refine_sq_type.has_value()) {
2112-
LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value();
2113-
return expected<bool>::Err(Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value()));
2114-
}
2115-
2116-
return false;
2117-
}
2118-
2119-
bool
2120-
has_lossless_quant(const expected<faiss::ScalarQuantizer::QuantizerType>& quant_type, DataFormatEnum dataFormat) {
2121-
if (!quant_type.has_value()) {
2122-
return false;
2123-
}
2124-
2125-
auto quant = quant_type.value();
2126-
switch (dataFormat) {
2127-
case DataFormatEnum::fp32:
2128-
return false;
2129-
case DataFormatEnum::fp16:
2130-
return quant == faiss::ScalarQuantizer::QuantizerType::QT_fp16;
2131-
case DataFormatEnum::bf16:
2132-
return quant == faiss::ScalarQuantizer::QuantizerType::QT_bf16;
2133-
case DataFormatEnum::int8:
2134-
return quant == faiss::ScalarQuantizer::QuantizerType::QT_8bit_direct_signed;
2135-
default:
2136-
return false;
2137-
}
2138-
}
2139-
2140-
bool
2141-
has_lossless_refine_index(const FaissHnswConfig& hnsw_cfg, DataFormatEnum dataFormat) {
2142-
bool has_refine = hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value();
2143-
if (has_refine) {
2144-
expected<bool> flat_refine = is_flat_refine(hnsw_cfg.refine_type);
2145-
if (flat_refine.has_value() && flat_refine.value()) {
2146-
return true;
2147-
}
2148-
2149-
auto sq_refine_type = get_sq_quantizer_type(hnsw_cfg.refine_type.value());
2150-
return has_lossless_quant(sq_refine_type, dataFormat);
2151-
}
2152-
return false;
2153-
}
2154-
2155-
// pick a refine index
2156-
expected<std::unique_ptr<faiss::Index>>
2157-
pick_refine_index(const DataFormatEnum data_format, const std::optional<std::string>& refine_type,
2158-
std::unique_ptr<faiss::IndexHNSW>&& hnsw_index) {
2159-
// yes
2160-
2161-
// grab a type of a refine index
2162-
expected<bool> is_fp32_flat = is_flat_refine(refine_type);
2163-
if (!is_fp32_flat.has_value()) {
2164-
return expected<std::unique_ptr<faiss::Index>>::Err(Status::invalid_args, "");
2165-
}
2166-
2167-
const bool is_fp32_flat_v = is_fp32_flat.value();
2168-
2169-
// check input data_format
2170-
if (data_format == DataFormatEnum::fp16) {
2171-
// make sure that we're using fp16 refine
2172-
auto refine_sq_type = get_sq_quantizer_type(refine_type.value());
2173-
if (!(refine_sq_type.has_value() &&
2174-
(refine_sq_type.value() != faiss::ScalarQuantizer::QT_bf16 && !is_fp32_flat_v))) {
2175-
LOG_KNOWHERE_ERROR_ << "fp16 input data does not accept bf16 or fp32 as a refine index.";
2176-
return expected<std::unique_ptr<faiss::Index>>::Err(
2177-
Status::invalid_args, "fp16 input data does not accept bf16 or fp32 as a refine index.");
2178-
}
2179-
}
2180-
2181-
if (data_format == DataFormatEnum::bf16) {
2182-
// make sure that we're using bf16 refine
2183-
auto refine_sq_type = get_sq_quantizer_type(refine_type.value());
2184-
if (!(refine_sq_type.has_value() &&
2185-
(refine_sq_type.value() != faiss::ScalarQuantizer::QT_fp16 && !is_fp32_flat_v))) {
2186-
LOG_KNOWHERE_ERROR_ << "bf16 input data does not accept fp16 or fp32 as a refine index.";
2187-
return expected<std::unique_ptr<faiss::Index>>::Err(
2188-
Status::invalid_args, "bf16 input data does not accept fp16 or fp32 as a refine index.");
2189-
}
2190-
}
2191-
2192-
// build
2193-
std::unique_ptr<faiss::IndexHNSW> local_hnsw_index = std::move(hnsw_index);
2194-
2195-
// either build flat or sq
2196-
if (is_fp32_flat_v) {
2197-
// build IndexFlat as a refine
2198-
auto refine_index = std::make_unique<faiss::IndexRefineFlat>(local_hnsw_index.get());
2199-
2200-
// let refine_index to own everything
2201-
refine_index->own_fields = true;
2202-
local_hnsw_index.release();
2203-
2204-
// reassign
2205-
return refine_index;
2206-
} else {
2207-
// being IndexScalarQuantizer as a refine
2208-
auto refine_sq_type = get_sq_quantizer_type(refine_type.value());
2209-
2210-
// a redundant check
2211-
if (!refine_sq_type.has_value()) {
2212-
LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value();
2213-
return expected<std::unique_ptr<faiss::Index>>::Err(
2214-
Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value()));
2215-
}
2216-
2217-
// create an sq
2218-
auto sq_refine = std::make_unique<faiss::IndexScalarQuantizer>(
2219-
local_hnsw_index->storage->d, refine_sq_type.value(), local_hnsw_index->storage->metric_type);
2220-
2221-
auto refine_index = std::make_unique<faiss::IndexRefine>(local_hnsw_index.get(), sq_refine.get());
2222-
2223-
// let refine_index to own everything
2224-
refine_index->own_refine_index = true;
2225-
refine_index->own_fields = true;
2226-
local_hnsw_index.release();
2227-
sq_refine.release();
2228-
2229-
// reassign
2230-
return refine_index;
2231-
}
2232-
}
2233-
2234-
} // namespace
2235-
22362038
//
22372039
class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode {
22382040
public:
@@ -2300,7 +2102,10 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode {
23002102

23012103
if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) {
23022104
// yes
2303-
auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index));
2105+
const auto hnsw_d = hnsw_index->storage->d;
2106+
const auto hnsw_metric_type = hnsw_index->storage->metric_type;
2107+
auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index),
2108+
hnsw_d, hnsw_metric_type);
23042109
if (!final_index_cnd.has_value()) {
23052110
return Status::invalid_args;
23062111
}
@@ -2368,7 +2173,7 @@ class BaseFaissRegularIndexHNSWSQNodeTemplate : public BaseFaissRegularIndexHNSW
23682173
return true;
23692174
}
23702175

2371-
return has_lossless_refine_index(hnsw_sq_cfg, datatype_v<DataType>);
2176+
return has_lossless_refine_index(hnsw_sq_cfg.refine, hnsw_sq_cfg.refine_type, datatype_v<DataType>);
23722177
}
23732178
};
23742179

@@ -2449,7 +2254,10 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode {
24492254
std::unique_ptr<faiss::Index> final_index;
24502255
if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) {
24512256
// yes
2452-
auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index));
2257+
const auto hnsw_d = hnsw_index->storage->d;
2258+
const auto hnsw_metric_type = hnsw_index->storage->metric_type;
2259+
auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index),
2260+
hnsw_d, hnsw_metric_type);
24532261
if (!final_index_cnd.has_value()) {
24542262
return Status::invalid_args;
24552263
}
@@ -2640,7 +2448,7 @@ class BaseFaissRegularIndexHNSWPQNodeTemplate : public BaseFaissRegularIndexHNSW
26402448
static bool
26412449
StaticHasRawData(const knowhere::BaseConfig& config, const IndexVersion& version) {
26422450
auto hnsw_cfg = static_cast<const FaissHnswConfig&>(config);
2643-
return has_lossless_refine_index(hnsw_cfg, datatype_v<DataType>);
2451+
return has_lossless_refine_index(hnsw_cfg.refine, hnsw_cfg.refine_type, datatype_v<DataType>);
26442452
}
26452453
};
26462454

@@ -2728,7 +2536,10 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode {
27282536
std::unique_ptr<faiss::Index> final_index;
27292537
if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) {
27302538
// yes
2731-
auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index));
2539+
const auto hnsw_d = hnsw_index->storage->d;
2540+
const auto hnsw_metric_type = hnsw_index->storage->metric_type;
2541+
auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index),
2542+
hnsw_d, hnsw_metric_type);
27322543
if (!final_index_cnd.has_value()) {
27332544
return Status::invalid_args;
27342545
}
@@ -2920,7 +2731,7 @@ class BaseFaissRegularIndexHNSWPRQNodeTemplate : public BaseFaissRegularIndexHNS
29202731
static bool
29212732
StaticHasRawData(const knowhere::BaseConfig& config, const IndexVersion& version) {
29222733
auto hnsw_cfg = static_cast<const FaissHnswConfig&>(config);
2923-
return has_lossless_refine_index(hnsw_cfg, datatype_v<DataType>);
2734+
return has_lossless_refine_index(hnsw_cfg.refine, hnsw_cfg.refine_type, datatype_v<DataType>);
29242735
}
29252736
};
29262737

src/index/ivf/ivf.cc

+53-5
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
#include "faiss/IndexIVFPQFastScan.h"
2020
#include "faiss/IndexIVFRaBitQ.h"
2121
#include "faiss/IndexIVFScalarQuantizerCC.h"
22-
#include "faiss/IndexPreTransform.h"
2322
#include "faiss/IndexScaNN.h"
2423
#include "faiss/IndexScalarQuantizer.h"
2524
#include "faiss/VectorTransform.h"
@@ -670,9 +669,15 @@ IvfIndexNode<DataType, IndexType>::TrainInternal(const DataSetPtr dataset, std::
670669
if constexpr (std::is_same<IndexIVFRaBitQWrapper, IndexType>::value) {
671670
const IvfRaBitQConfig& ivf_rabitq_cfg = static_cast<const IvfRaBitQConfig&>(*cfg);
672671
auto nlist = MatchNlist(rows, ivf_rabitq_cfg.nlist.value());
673-
auto qb = ivf_rabitq_cfg.rbq_bits_query.value();
674672

675-
index = std::make_unique<IndexIVFRaBitQWrapper>(dim, nlist, qb, metric.value());
673+
DataFormatEnum data_format = DataType2EnumHelper<DataType>::value;
674+
675+
auto result = IndexIVFRaBitQWrapper::create(dim, nlist, ivf_rabitq_cfg, data_format, metric.value());
676+
if (!result.has_value()) {
677+
return result.error();
678+
}
679+
680+
index = std::move(result.value());
676681
index->train(rows, (const float*)data);
677682
}
678683
index_ = std::move(index);
@@ -835,13 +840,36 @@ IvfIndexNode<DataType, IndexType>::Search(const DataSetPtr dataset, std::unique_
835840

836841
const IvfRaBitQConfig& ivf_rabitq_cfg = static_cast<const IvfRaBitQConfig&>(*cfg);
837842

843+
// use refine?
844+
bool use_refine = false;
845+
846+
const bool whether_to_enable_refine = ivf_rabitq_cfg.refine_k.has_value();
847+
if (const auto wrapper_index = dynamic_cast<const IndexIVFRaBitQWrapper*>(index_.get());
848+
wrapper_index != nullptr) {
849+
const faiss::IndexRefine* refine_index = wrapper_index->get_refine_index();
850+
use_refine = (refine_index != nullptr);
851+
}
852+
838853
faiss::IVFRaBitQSearchParameters ivf_search_params;
839854
ivf_search_params.nprobe = nprobe;
840855
ivf_search_params.max_codes = 0;
841856
ivf_search_params.sel = id_selector;
842857
ivf_search_params.qb = ivf_rabitq_cfg.rbq_bits_query.value_or(0);
843858

844-
index_->search(1, cur_query, k, distances.get() + offset, ids.get() + offset, &ivf_search_params);
859+
if (use_refine && whether_to_enable_refine) {
860+
// yes, use refine
861+
faiss::IndexRefineSearchParameters refine_search_params;
862+
refine_search_params.sel = id_selector;
863+
refine_search_params.k_factor = ivf_rabitq_cfg.refine_k.value_or(1);
864+
refine_search_params.base_index_params = &ivf_search_params;
865+
866+
index_->search(1, cur_query, k, distances.get() + offset, ids.get() + offset,
867+
&refine_search_params);
868+
} else {
869+
// do not use refine
870+
index_->search(1, cur_query, k, distances.get() + offset, ids.get() + offset,
871+
&ivf_search_params);
872+
}
845873
} else {
846874
auto cur_query = (const float*)data + index * dim;
847875
if (is_cosine) {
@@ -964,7 +992,27 @@ IvfIndexNode<DataType, IndexType>::RangeSearch(const DataSetPtr dataset, std::un
964992
ivf_search_params.sel = id_selector;
965993
ivf_search_params.qb = ivf_rabitq_cfg.rbq_bits_query.value_or(0);
966994

967-
index_->range_search(1, cur_query, radius, &res, &ivf_search_params);
995+
// use refine?
996+
bool use_refine = false;
997+
998+
const bool whether_to_enable_refine = ivf_rabitq_cfg.refine_k.has_value();
999+
if (const auto wrapper_index = dynamic_cast<const IndexIVFRaBitQWrapper*>(index_.get());
1000+
wrapper_index != nullptr) {
1001+
const faiss::IndexRefine* refine_index = wrapper_index->get_refine_index();
1002+
use_refine = (refine_index != nullptr);
1003+
}
1004+
1005+
if (use_refine && whether_to_enable_refine) {
1006+
// yes, use refine
1007+
faiss::IndexRefineSearchParameters refine_search_params;
1008+
refine_search_params.sel = id_selector;
1009+
refine_search_params.k_factor = ivf_rabitq_cfg.refine_k.value_or(1);
1010+
refine_search_params.base_index_params = &ivf_search_params;
1011+
1012+
index_->range_search(1, cur_query, radius, &res, &refine_search_params);
1013+
} else {
1014+
index_->range_search(1, cur_query, radius, &res, &ivf_search_params);
1015+
}
9681016
} else {
9691017
auto cur_query = (const float*)xq + index * dim;
9701018
if (is_cosine) {

0 commit comments

Comments
 (0)