Skip to content

Commit edcc6c7

Browse files
authored
Minimize rebuilds of RAFT code due to knowhere changes (#264)
* Minimize rebuilds of RAFT code due to knowhere changes Signed-off-by: William Hicks <[email protected]> * Revert accidental reversion of config defaults Signed-off-by: William Hicks <[email protected]> --------- Signed-off-by: William Hicks <[email protected]>
1 parent 48986a8 commit edcc6c7

14 files changed

+428
-385
lines changed

.pre-commit-config.yaml

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#
1717

1818
default_language_version:
19-
python: python3.8
19+
python: python3
2020
exclude: '^thirdparty'
2121
fail_fast: True
2222
repos:
@@ -32,4 +32,5 @@ repos:
3232
rev: v1.3.5
3333
hooks:
3434
- id: clang-format
35+
types_or: [c, c++, cuda]
3536
args: [-style=file]

CMakeLists.txt

+4-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ include(cmake/utils/utils.cmake)
2222

2323
knowhere_option(WITH_RAFT "Build with RAFT indexes" OFF)
2424
if (WITH_RAFT)
25-
set(CMAKE_CUDA_ARCHITECTURES RAPIDS)
25+
if("${CMAKE_CUDA_ARCHITECTURES}" STREQUAL "")
26+
set(CMAKE_CUDA_ARCHITECTURES RAPIDS)
27+
endif()
2628
include(cmake/libs/librapids.cmake)
2729
project(knowhere CXX C CUDA)
2830
include(cmake/libs/libraft.cmake)
@@ -139,7 +141,7 @@ list(APPEND KNOWHERE_LINKER_LIBS Folly::folly)
139141
add_library(knowhere SHARED ${KNOWHERE_SRCS})
140142
add_dependencies(knowhere ${KNOWHERE_LINKER_LIBS})
141143
if(WITH_RAFT)
142-
list(APPEND KNOWHERE_LINKER_LIBS raft::raft CUDA::cublas CUDA::cusparse CUDA::cusolver)
144+
list(APPEND KNOWHERE_LINKER_LIBS raft::raft raft::compiled_static CUDA::cublas CUDA::cusparse CUDA::cusolver)
143145
endif()
144146
target_link_libraries(knowhere PUBLIC ${KNOWHERE_LINKER_LIBS})
145147
target_include_directories(

cmake/libs/libraft.cmake

+5-3
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
# the License.
1515

1616
add_definitions(-DKNOWHERE_WITH_RAFT)
17+
add_definitions(-DRAFT_EXPLICIT_INSTANTIATE_ONLY)
1718
set(RAFT_VERSION "${RAPIDS_VERSION}")
18-
set(RAFT_FORK "rapidsai")
19-
set(RAFT_PINNED_TAG "branch-23.12")
19+
set(RAFT_FORK "wphicks")
20+
set(RAFT_PINNED_TAG "knowhere-2.4")
2021

2122

2223
rapids_find_package(CUDAToolkit REQUIRED
@@ -38,7 +39,7 @@ function(find_and_configure_raft)
3839
GLOBAL_TARGETS
3940
raft::raft
4041
COMPONENTS
41-
${RAFT_COMPONENTS}
42+
compiled_static
4243
CPM_ARGS
4344
GIT_REPOSITORY
4445
https://github.com/${PKG_FORK}/raft.git
@@ -47,6 +48,7 @@ function(find_and_configure_raft)
4748
SOURCE_SUBDIR
4849
cpp
4950
OPTIONS
51+
"RAFT_COMPILE_LIBRARY ON"
5052
"BUILD_TESTS OFF"
5153
"BUILD_BENCH OFF"
5254
"RAFT_USE_FAISS_STATIC OFF") # Turn this on to build FAISS into your binary

src/common/raft/integration/cagra_index.cu

+10-1
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,17 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17-
#include "common/raft/proto/raft_index_kind.hpp"
17+
#undef RAFT_EXPLICIT_INSTANTIATE_ONLY
1818
#include "common/raft/integration/raft_knowhere_index.cuh"
19+
#include "common/raft/proto/filtered_search_instantiation.cuh"
20+
#include "common/raft/proto/raft_index_kind.hpp"
21+
22+
RAFT_FILTERED_SEARCH_EXTERN(cagra, raft_knowhere::raft_data_t<raft_proto::raft_index_kind::cagra>,
23+
raft_knowhere::raft_indexing_t<raft_proto::raft_index_kind::cagra>,
24+
raft_knowhere::raft_input_indexing_t<raft_proto::raft_index_kind::cagra>,
25+
raft_knowhere::raft_data_t<raft_proto::raft_index_kind::cagra>,
26+
raft_knowhere::knowhere_bitset_data_type, raft_knowhere::knowhere_bitset_indexing_type)
27+
1928
namespace raft_knowhere {
2029
template struct raft_knowhere_index<raft_proto::raft_index_kind::cagra>;
2130
} // namespace raft_knowhere
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/**
2+
* SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
#undef RAFT_EXPLICIT_INSTANTIATE_ONLY
18+
#include "common/raft/integration/type_mappers.hpp"
19+
#include "common/raft/proto/filtered_search_instantiation.cuh"
20+
#include "common/raft/proto/raft_index_kind.hpp"
21+
22+
RAFT_FILTERED_SEARCH_INSTANTIATION(cagra, raft_knowhere::raft_data_t<raft_proto::raft_index_kind::cagra>,
23+
raft_knowhere::raft_indexing_t<raft_proto::raft_index_kind::cagra>,
24+
raft_knowhere::raft_input_indexing_t<raft_proto::raft_index_kind::cagra>,
25+
raft_knowhere::raft_data_t<raft_proto::raft_index_kind::cagra>,
26+
raft_knowhere::knowhere_bitset_data_type,
27+
raft_knowhere::knowhere_bitset_indexing_type)

src/common/raft/integration/ivf_flat_index.cu

+10-1
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,17 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17-
#include "common/raft/proto/raft_index_kind.hpp"
17+
#undef RAFT_EXPLICIT_INSTANTIATE_ONLY
1818
#include "common/raft/integration/raft_knowhere_index.cuh"
19+
#include "common/raft/proto/filtered_search_instantiation.cuh"
20+
#include "common/raft/proto/raft_index_kind.hpp"
21+
22+
RAFT_FILTERED_SEARCH_EXTERN(ivf_flat, raft_knowhere::raft_data_t<raft_proto::raft_index_kind::ivf_flat>,
23+
raft_knowhere::raft_indexing_t<raft_proto::raft_index_kind::ivf_flat>,
24+
raft_knowhere::raft_input_indexing_t<raft_proto::raft_index_kind::ivf_flat>,
25+
raft_knowhere::raft_data_t<raft_proto::raft_index_kind::ivf_flat>,
26+
raft_knowhere::knowhere_bitset_data_type, raft_knowhere::knowhere_bitset_indexing_type)
27+
1928
namespace raft_knowhere {
2029
template struct raft_knowhere_index<raft_proto::raft_index_kind::ivf_flat>;
2130
} // namespace raft_knowhere
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/**
2+
* SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
#undef RAFT_EXPLICIT_INSTANTIATE_ONLY
18+
#include "common/raft/integration/type_mappers.hpp"
19+
#include "common/raft/proto/filtered_search_instantiation.cuh"
20+
#include "common/raft/proto/raft_index_kind.hpp"
21+
22+
RAFT_FILTERED_SEARCH_INSTANTIATION(ivf_flat, raft_knowhere::raft_data_t<raft_proto::raft_index_kind::ivf_flat>,
23+
raft_knowhere::raft_indexing_t<raft_proto::raft_index_kind::ivf_flat>,
24+
raft_knowhere::raft_input_indexing_t<raft_proto::raft_index_kind::ivf_flat>,
25+
raft_knowhere::raft_data_t<raft_proto::raft_index_kind::ivf_flat>,
26+
raft_knowhere::knowhere_bitset_data_type,
27+
raft_knowhere::knowhere_bitset_indexing_type)

src/common/raft/integration/ivf_pq_index.cu

+10-1
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,17 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17-
#include "common/raft/proto/raft_index_kind.hpp"
17+
#undef RAFT_EXPLICIT_INSTANTIATE_ONLY
1818
#include "common/raft/integration/raft_knowhere_index.cuh"
19+
#include "common/raft/proto/filtered_search_instantiation.cuh"
20+
#include "common/raft/proto/raft_index_kind.hpp"
21+
22+
RAFT_FILTERED_SEARCH_EXTERN(ivf_pq, raft_knowhere::raft_data_t<raft_proto::raft_index_kind::ivf_pq>,
23+
raft_knowhere::raft_indexing_t<raft_proto::raft_index_kind::ivf_pq>,
24+
raft_knowhere::raft_input_indexing_t<raft_proto::raft_index_kind::ivf_pq>,
25+
raft_knowhere::raft_data_t<raft_proto::raft_index_kind::ivf_pq>,
26+
raft_knowhere::knowhere_bitset_data_type, raft_knowhere::knowhere_bitset_indexing_type)
27+
1928
namespace raft_knowhere {
2029
template struct raft_knowhere_index<raft_proto::raft_index_kind::ivf_pq>;
2130
} // namespace raft_knowhere
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/**
2+
* SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
#undef RAFT_EXPLICIT_INSTANTIATE_ONLY
18+
#include "common/raft/integration/type_mappers.hpp"
19+
#include "common/raft/proto/filtered_search_instantiation.cuh"
20+
#include "common/raft/proto/raft_index_kind.hpp"
21+
22+
RAFT_FILTERED_SEARCH_INSTANTIATION(ivf_pq, raft_knowhere::raft_data_t<raft_proto::raft_index_kind::ivf_pq>,
23+
raft_knowhere::raft_indexing_t<raft_proto::raft_index_kind::ivf_pq>,
24+
raft_knowhere::raft_input_indexing_t<raft_proto::raft_index_kind::ivf_pq>,
25+
raft_knowhere::raft_data_t<raft_proto::raft_index_kind::ivf_pq>,
26+
raft_knowhere::knowhere_bitset_data_type,
27+
raft_knowhere::knowhere_bitset_indexing_type)

src/common/raft/integration/raft_knowhere_index.cuh

+1-4
Original file line numberDiff line numberDiff line change
@@ -324,14 +324,11 @@ struct raft_knowhere_index<IndexKind>::impl {
324324
using data_type = raft_data_t<index_kind>;
325325
using indexing_type = raft_indexing_t<index_kind>;
326326
using input_indexing_type = raft_input_indexing_t<index_kind>;
327+
using raft_index_type = raft_index_t<index_kind>;
327328

328329
impl() {
329330
}
330331

331-
private:
332-
using raft_index_type = raft_index_t<index_kind>;
333-
334-
public:
335332
auto
336333
is_trained() const {
337334
return index_.has_value();

src/common/raft/integration/raft_knowhere_index.hpp

+1-42
Original file line numberDiff line numberDiff line change
@@ -18,52 +18,11 @@
1818
#include <cstdint>
1919

2020
#include "common/raft/integration/raft_knowhere_config.hpp"
21+
#include "common/raft/integration/type_mappers.hpp"
2122
#include "common/raft/proto/raft_index_kind.hpp"
2223

2324
namespace raft_knowhere {
2425

25-
using knowhere_data_type = float;
26-
using knowhere_indexing_type = std::int64_t;
27-
using knowhere_bitset_data_type = std::uint8_t;
28-
using knowhere_bitset_indexing_type = std::uint32_t;
29-
30-
namespace detail {
31-
32-
template <bool B, raft_proto::raft_index_kind IndexKind>
33-
struct raft_io_type_mapper : std::false_type {};
34-
35-
template <>
36-
struct raft_io_type_mapper<true, raft_proto::raft_index_kind::ivf_flat> : std::true_type {
37-
using data_type = float;
38-
using indexing_type = std::int64_t;
39-
using input_indexing_type = std::int64_t;
40-
};
41-
42-
template <>
43-
struct raft_io_type_mapper<true, raft_proto::raft_index_kind::ivf_pq> : std::true_type {
44-
using data_type = float;
45-
using indexing_type = std::int64_t;
46-
using input_indexing_type = std::int64_t;
47-
};
48-
49-
template <>
50-
struct raft_io_type_mapper<true, raft_proto::raft_index_kind::cagra> : std::true_type {
51-
using data_type = float;
52-
using indexing_type = std::uint32_t;
53-
using input_indexing_type = std::int64_t;
54-
};
55-
56-
} // namespace detail
57-
58-
template <raft_proto::raft_index_kind IndexKind>
59-
using raft_data_t = typename detail::raft_io_type_mapper<true, IndexKind>::data_type;
60-
61-
template <raft_proto::raft_index_kind IndexKind>
62-
using raft_indexing_t = typename detail::raft_io_type_mapper<true, IndexKind>::indexing_type;
63-
64-
template <raft_proto::raft_index_kind IndexKind>
65-
using raft_input_indexing_t = typename detail::raft_io_type_mapper<true, IndexKind>::input_indexing_type;
66-
6726
template <raft_proto::raft_index_kind IndexKind>
6827
struct raft_knowhere_index {
6928
auto static constexpr index_kind = IndexKind;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/**
2+
* SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
#pragma once
18+
#include <cstdint>
19+
#include <type_traits>
20+
21+
#include "common/raft/proto/raft_index_kind.hpp"
22+
23+
namespace raft_knowhere {
24+
25+
using knowhere_data_type = float;
26+
using knowhere_indexing_type = std::int64_t;
27+
using knowhere_bitset_data_type = std::uint8_t;
28+
using knowhere_bitset_indexing_type = std::uint32_t;
29+
30+
namespace detail {
31+
32+
template <bool B, raft_proto::raft_index_kind IndexKind>
33+
struct raft_io_type_mapper : std::false_type {};
34+
35+
template <>
36+
struct raft_io_type_mapper<true, raft_proto::raft_index_kind::ivf_flat> : std::true_type {
37+
using data_type = float;
38+
using indexing_type = std::int64_t;
39+
using input_indexing_type = std::int64_t;
40+
};
41+
42+
template <>
43+
struct raft_io_type_mapper<true, raft_proto::raft_index_kind::ivf_pq> : std::true_type {
44+
using data_type = float;
45+
using indexing_type = std::int64_t;
46+
using input_indexing_type = std::uint32_t;
47+
};
48+
49+
template <>
50+
struct raft_io_type_mapper<true, raft_proto::raft_index_kind::cagra> : std::true_type {
51+
using data_type = float;
52+
using indexing_type = std::uint32_t;
53+
using input_indexing_type = std::int64_t;
54+
};
55+
56+
} // namespace detail
57+
58+
template <raft_proto::raft_index_kind IndexKind>
59+
using raft_data_t = typename detail::raft_io_type_mapper<true, IndexKind>::data_type;
60+
61+
template <raft_proto::raft_index_kind IndexKind>
62+
using raft_indexing_t = typename detail::raft_io_type_mapper<true, IndexKind>::indexing_type;
63+
64+
template <raft_proto::raft_index_kind IndexKind>
65+
using raft_input_indexing_t = typename detail::raft_io_type_mapper<true, IndexKind>::input_indexing_type;
66+
67+
} // namespace raft_knowhere
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/**
2+
* SPDX-FileCopyrightText: Copyright (c) 2023,NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
#pragma once
18+
#include <raft/core/device_mdspan.hpp>
19+
#include <raft/core/resources.hpp>
20+
#include <raft/neighbors/cagra.cuh>
21+
#include <raft/neighbors/ivf_flat.cuh>
22+
#include <raft/neighbors/ivf_pq.cuh>
23+
#include <raft/neighbors/sample_filter.cuh>
24+
25+
#include "common/raft/proto/raft_index_kind.hpp"
26+
27+
namespace raft_proto {
28+
namespace detail {
29+
template <raft_proto::raft_index_kind K, typename T, typename IdxT>
30+
using index_instantiation = std::conditional_t<
31+
K == raft_proto::raft_index_kind::ivf_flat, raft::neighbors::ivf_flat::index<T, IdxT>,
32+
std::conditional_t<
33+
K == raft_proto::raft_index_kind::ivf_pq, raft::neighbors::ivf_pq::index<IdxT>,
34+
std::conditional_t<K == raft_proto::raft_index_kind::cagra, raft::neighbors::cagra::index<T, IdxT>,
35+
raft::neighbors::ivf_flat::index<T, IdxT>>>>;
36+
} // namespace detail
37+
} // namespace raft_proto
38+
39+
#define RAFT_FILTERED_SEARCH_TEMPLATE(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT) \
40+
template void search_with_filtering<T, IdxT, raft::neighbors::filtering::bitset_filter<BitsetDataT, BitsetIdxT>>( \
41+
raft::resources const&, search_params const&, \
42+
raft_proto::detail::index_instantiation<raft_proto::raft_index_kind::index_type, T, IdxT> const&, \
43+
raft::device_matrix_view<const T, InpIdxT>, raft::device_matrix_view<IdxT, InpIdxT>, \
44+
raft::device_matrix_view<DistT, InpIdxT>, raft::neighbors::filtering::bitset_filter<BitsetDataT, BitsetIdxT>)
45+
46+
#define RAFT_FILTERED_SEARCH_INSTANTIATION(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT) \
47+
namespace raft::neighbors::index_type { \
48+
RAFT_FILTERED_SEARCH_TEMPLATE(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT); \
49+
}
50+
51+
#define RAFT_FILTERED_SEARCH_EXTERN(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT) \
52+
namespace raft::neighbors::index_type { \
53+
RAFT_FILTERED_SEARCH_TEMPLATE(index_type, T, IdxT, InpIdxT, DistT, BitsetDataT, BitsetIdxT); \
54+
}

0 commit comments

Comments
 (0)