|
40 | 40 | #include "index/hnsw/impl/IndexConditionalWrapper.h"
|
41 | 41 | #include "index/hnsw/impl/IndexHNSWWrapper.h"
|
42 | 42 | #include "index/hnsw/impl/IndexWrapperCosine.h"
|
| 43 | +#include "index/refine/refine_utils.h" |
43 | 44 | #include "io/memory_io.h"
|
44 | 45 | #include "knowhere/bitsetview_idselector.h"
|
45 | 46 | #include "knowhere/comp/index_param.h"
|
@@ -2034,205 +2035,6 @@ class BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback : public HNSWI
|
2034 | 2035 | }
|
2035 | 2036 | };
|
2036 | 2037 |
|
2037 |
| -namespace { |
2038 |
| - |
2039 |
| -// a supporting function |
2040 |
| -expected<faiss::ScalarQuantizer::QuantizerType> |
2041 |
| -get_sq_quantizer_type(const std::string& sq_type) { |
2042 |
| - std::map<std::string, faiss::ScalarQuantizer::QuantizerType> sq_types = { |
2043 |
| - {"sq6", faiss::ScalarQuantizer::QT_6bit}, |
2044 |
| - {"sq8", faiss::ScalarQuantizer::QT_8bit}, |
2045 |
| - {"fp16", faiss::ScalarQuantizer::QT_fp16}, |
2046 |
| - {"bf16", faiss::ScalarQuantizer::QT_bf16}, |
2047 |
| - {"int8", faiss::ScalarQuantizer::QT_8bit_direct_signed}}; |
2048 |
| - |
2049 |
| - // todo: tolower |
2050 |
| - auto sq_type_tolower = str_to_lower(sq_type); |
2051 |
| - auto itr = sq_types.find(sq_type_tolower); |
2052 |
| - if (itr == sq_types.cend()) { |
2053 |
| - return expected<faiss::ScalarQuantizer::QuantizerType>::Err( |
2054 |
| - Status::invalid_args, fmt::format("invalid scalar quantizer type ({})", sq_type_tolower)); |
2055 |
| - } |
2056 |
| - |
2057 |
| - return itr->second; |
2058 |
| -} |
2059 |
| - |
2060 |
| -/* |
2061 |
| -// checks whether an index contains a refiner, suitable for a given data format |
2062 |
| -std::optional<bool> whether_refine_is_datatype( |
2063 |
| - const faiss::Index* index, |
2064 |
| - const DataFormatEnum data_format |
2065 |
| -) { |
2066 |
| - if (index == nullptr) { |
2067 |
| - return {}; |
2068 |
| - } |
2069 |
| -
|
2070 |
| - const faiss::IndexRefine* const index_refine = dynamic_cast<const faiss::IndexRefine*>(index); |
2071 |
| - if (index_refine == nullptr) { |
2072 |
| - return false; |
2073 |
| - } |
2074 |
| -
|
2075 |
| - switch(data_format) { |
2076 |
| - case DataFormatEnum::fp32: |
2077 |
| - return (dynamic_cast<const faiss::IndexFlat*>(index_refine->refine_index) != nullptr); |
2078 |
| - case DataFormatEnum::fp16: |
2079 |
| - { |
2080 |
| - const auto* const index_sq = dynamic_cast<const |
2081 |
| -faiss::IndexScalarQuantizer*>(index_refine->refine_index); return (index_sq != nullptr && index_sq->sq.qtype == |
2082 |
| -faiss::ScalarQuantizer::QT_fp16); |
2083 |
| - } |
2084 |
| - case DataFormatEnum::bf16: |
2085 |
| - { |
2086 |
| - const auto* const index_sq = dynamic_cast<const |
2087 |
| -faiss::IndexScalarQuantizer*>(index_refine->refine_index); return (index_sq != nullptr && index_sq->sq.qtype == |
2088 |
| -faiss::ScalarQuantizer::QT_bf16); |
2089 |
| - } |
2090 |
| - default: |
2091 |
| - return {}; |
2092 |
| - } |
2093 |
| -} |
2094 |
| -*/ |
2095 |
| - |
2096 |
| -expected<bool> |
2097 |
| -is_flat_refine(const std::optional<std::string>& refine_type) { |
2098 |
| - // grab a type of a refine index |
2099 |
| - if (!refine_type.has_value()) { |
2100 |
| - return true; |
2101 |
| - }; |
2102 |
| - |
2103 |
| - // todo: tolower |
2104 |
| - std::string refine_type_tolower = str_to_lower(refine_type.value()); |
2105 |
| - if (refine_type_tolower == "fp32" || refine_type_tolower == "flat") { |
2106 |
| - return true; |
2107 |
| - }; |
2108 |
| - |
2109 |
| - // parse |
2110 |
| - auto refine_sq_type = get_sq_quantizer_type(refine_type_tolower); |
2111 |
| - if (!refine_sq_type.has_value()) { |
2112 |
| - LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value(); |
2113 |
| - return expected<bool>::Err(Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value())); |
2114 |
| - } |
2115 |
| - |
2116 |
| - return false; |
2117 |
| -} |
2118 |
| - |
2119 |
| -bool |
2120 |
| -has_lossless_quant(const expected<faiss::ScalarQuantizer::QuantizerType>& quant_type, DataFormatEnum dataFormat) { |
2121 |
| - if (!quant_type.has_value()) { |
2122 |
| - return false; |
2123 |
| - } |
2124 |
| - |
2125 |
| - auto quant = quant_type.value(); |
2126 |
| - switch (dataFormat) { |
2127 |
| - case DataFormatEnum::fp32: |
2128 |
| - return false; |
2129 |
| - case DataFormatEnum::fp16: |
2130 |
| - return quant == faiss::ScalarQuantizer::QuantizerType::QT_fp16; |
2131 |
| - case DataFormatEnum::bf16: |
2132 |
| - return quant == faiss::ScalarQuantizer::QuantizerType::QT_bf16; |
2133 |
| - case DataFormatEnum::int8: |
2134 |
| - return quant == faiss::ScalarQuantizer::QuantizerType::QT_8bit_direct_signed; |
2135 |
| - default: |
2136 |
| - return false; |
2137 |
| - } |
2138 |
| -} |
2139 |
| - |
2140 |
| -bool |
2141 |
| -has_lossless_refine_index(const FaissHnswConfig& hnsw_cfg, DataFormatEnum dataFormat) { |
2142 |
| - bool has_refine = hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value(); |
2143 |
| - if (has_refine) { |
2144 |
| - expected<bool> flat_refine = is_flat_refine(hnsw_cfg.refine_type); |
2145 |
| - if (flat_refine.has_value() && flat_refine.value()) { |
2146 |
| - return true; |
2147 |
| - } |
2148 |
| - |
2149 |
| - auto sq_refine_type = get_sq_quantizer_type(hnsw_cfg.refine_type.value()); |
2150 |
| - return has_lossless_quant(sq_refine_type, dataFormat); |
2151 |
| - } |
2152 |
| - return false; |
2153 |
| -} |
2154 |
| - |
2155 |
| -// pick a refine index |
2156 |
| -expected<std::unique_ptr<faiss::Index>> |
2157 |
| -pick_refine_index(const DataFormatEnum data_format, const std::optional<std::string>& refine_type, |
2158 |
| - std::unique_ptr<faiss::IndexHNSW>&& hnsw_index) { |
2159 |
| - // yes |
2160 |
| - |
2161 |
| - // grab a type of a refine index |
2162 |
| - expected<bool> is_fp32_flat = is_flat_refine(refine_type); |
2163 |
| - if (!is_fp32_flat.has_value()) { |
2164 |
| - return expected<std::unique_ptr<faiss::Index>>::Err(Status::invalid_args, ""); |
2165 |
| - } |
2166 |
| - |
2167 |
| - const bool is_fp32_flat_v = is_fp32_flat.value(); |
2168 |
| - |
2169 |
| - // check input data_format |
2170 |
| - if (data_format == DataFormatEnum::fp16) { |
2171 |
| - // make sure that we're using fp16 refine |
2172 |
| - auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); |
2173 |
| - if (!(refine_sq_type.has_value() && |
2174 |
| - (refine_sq_type.value() != faiss::ScalarQuantizer::QT_bf16 && !is_fp32_flat_v))) { |
2175 |
| - LOG_KNOWHERE_ERROR_ << "fp16 input data does not accept bf16 or fp32 as a refine index."; |
2176 |
| - return expected<std::unique_ptr<faiss::Index>>::Err( |
2177 |
| - Status::invalid_args, "fp16 input data does not accept bf16 or fp32 as a refine index."); |
2178 |
| - } |
2179 |
| - } |
2180 |
| - |
2181 |
| - if (data_format == DataFormatEnum::bf16) { |
2182 |
| - // make sure that we're using bf16 refine |
2183 |
| - auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); |
2184 |
| - if (!(refine_sq_type.has_value() && |
2185 |
| - (refine_sq_type.value() != faiss::ScalarQuantizer::QT_fp16 && !is_fp32_flat_v))) { |
2186 |
| - LOG_KNOWHERE_ERROR_ << "bf16 input data does not accept fp16 or fp32 as a refine index."; |
2187 |
| - return expected<std::unique_ptr<faiss::Index>>::Err( |
2188 |
| - Status::invalid_args, "bf16 input data does not accept fp16 or fp32 as a refine index."); |
2189 |
| - } |
2190 |
| - } |
2191 |
| - |
2192 |
| - // build |
2193 |
| - std::unique_ptr<faiss::IndexHNSW> local_hnsw_index = std::move(hnsw_index); |
2194 |
| - |
2195 |
| - // either build flat or sq |
2196 |
| - if (is_fp32_flat_v) { |
2197 |
| - // build IndexFlat as a refine |
2198 |
| - auto refine_index = std::make_unique<faiss::IndexRefineFlat>(local_hnsw_index.get()); |
2199 |
| - |
2200 |
| - // let refine_index to own everything |
2201 |
| - refine_index->own_fields = true; |
2202 |
| - local_hnsw_index.release(); |
2203 |
| - |
2204 |
| - // reassign |
2205 |
| - return refine_index; |
2206 |
| - } else { |
2207 |
| - // being IndexScalarQuantizer as a refine |
2208 |
| - auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); |
2209 |
| - |
2210 |
| - // a redundant check |
2211 |
| - if (!refine_sq_type.has_value()) { |
2212 |
| - LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value(); |
2213 |
| - return expected<std::unique_ptr<faiss::Index>>::Err( |
2214 |
| - Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value())); |
2215 |
| - } |
2216 |
| - |
2217 |
| - // create an sq |
2218 |
| - auto sq_refine = std::make_unique<faiss::IndexScalarQuantizer>( |
2219 |
| - local_hnsw_index->storage->d, refine_sq_type.value(), local_hnsw_index->storage->metric_type); |
2220 |
| - |
2221 |
| - auto refine_index = std::make_unique<faiss::IndexRefine>(local_hnsw_index.get(), sq_refine.get()); |
2222 |
| - |
2223 |
| - // let refine_index to own everything |
2224 |
| - refine_index->own_refine_index = true; |
2225 |
| - refine_index->own_fields = true; |
2226 |
| - local_hnsw_index.release(); |
2227 |
| - sq_refine.release(); |
2228 |
| - |
2229 |
| - // reassign |
2230 |
| - return refine_index; |
2231 |
| - } |
2232 |
| -} |
2233 |
| - |
2234 |
| -} // namespace |
2235 |
| - |
2236 | 2038 | //
|
2237 | 2039 | class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode {
|
2238 | 2040 | public:
|
@@ -2300,7 +2102,10 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode {
|
2300 | 2102 |
|
2301 | 2103 | if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) {
|
2302 | 2104 | // yes
|
2303 |
| - auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); |
| 2105 | + const auto hnsw_d = hnsw_index->storage->d; |
| 2106 | + const auto hnsw_metric_type = hnsw_index->storage->metric_type; |
| 2107 | + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index), |
| 2108 | + hnsw_d, hnsw_metric_type); |
2304 | 2109 | if (!final_index_cnd.has_value()) {
|
2305 | 2110 | return Status::invalid_args;
|
2306 | 2111 | }
|
@@ -2368,7 +2173,7 @@ class BaseFaissRegularIndexHNSWSQNodeTemplate : public BaseFaissRegularIndexHNSW
|
2368 | 2173 | return true;
|
2369 | 2174 | }
|
2370 | 2175 |
|
2371 |
| - return has_lossless_refine_index(hnsw_sq_cfg, datatype_v<DataType>); |
| 2176 | + return has_lossless_refine_index(hnsw_sq_cfg.refine, hnsw_sq_cfg.refine_type, datatype_v<DataType>); |
2372 | 2177 | }
|
2373 | 2178 | };
|
2374 | 2179 |
|
@@ -2449,7 +2254,10 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode {
|
2449 | 2254 | std::unique_ptr<faiss::Index> final_index;
|
2450 | 2255 | if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) {
|
2451 | 2256 | // yes
|
2452 |
| - auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); |
| 2257 | + const auto hnsw_d = hnsw_index->storage->d; |
| 2258 | + const auto hnsw_metric_type = hnsw_index->storage->metric_type; |
| 2259 | + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index), |
| 2260 | + hnsw_d, hnsw_metric_type); |
2453 | 2261 | if (!final_index_cnd.has_value()) {
|
2454 | 2262 | return Status::invalid_args;
|
2455 | 2263 | }
|
@@ -2640,7 +2448,7 @@ class BaseFaissRegularIndexHNSWPQNodeTemplate : public BaseFaissRegularIndexHNSW
|
2640 | 2448 | static bool
|
2641 | 2449 | StaticHasRawData(const knowhere::BaseConfig& config, const IndexVersion& version) {
|
2642 | 2450 | auto hnsw_cfg = static_cast<const FaissHnswConfig&>(config);
|
2643 |
| - return has_lossless_refine_index(hnsw_cfg, datatype_v<DataType>); |
| 2451 | + return has_lossless_refine_index(hnsw_cfg.refine, hnsw_cfg.refine_type, datatype_v<DataType>); |
2644 | 2452 | }
|
2645 | 2453 | };
|
2646 | 2454 |
|
@@ -2728,7 +2536,10 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode {
|
2728 | 2536 | std::unique_ptr<faiss::Index> final_index;
|
2729 | 2537 | if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) {
|
2730 | 2538 | // yes
|
2731 |
| - auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); |
| 2539 | + const auto hnsw_d = hnsw_index->storage->d; |
| 2540 | + const auto hnsw_metric_type = hnsw_index->storage->metric_type; |
| 2541 | + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index), |
| 2542 | + hnsw_d, hnsw_metric_type); |
2732 | 2543 | if (!final_index_cnd.has_value()) {
|
2733 | 2544 | return Status::invalid_args;
|
2734 | 2545 | }
|
@@ -2920,7 +2731,7 @@ class BaseFaissRegularIndexHNSWPRQNodeTemplate : public BaseFaissRegularIndexHNS
|
2920 | 2731 | static bool
|
2921 | 2732 | StaticHasRawData(const knowhere::BaseConfig& config, const IndexVersion& version) {
|
2922 | 2733 | auto hnsw_cfg = static_cast<const FaissHnswConfig&>(config);
|
2923 |
| - return has_lossless_refine_index(hnsw_cfg, datatype_v<DataType>); |
| 2734 | + return has_lossless_refine_index(hnsw_cfg.refine, hnsw_cfg.refine_type, datatype_v<DataType>); |
2924 | 2735 | }
|
2925 | 2736 | };
|
2926 | 2737 |
|
|
0 commit comments