Skip to content

Commit ded03b8

Browse files
jiayulufacebook-github-bot
authored andcommitted
put KVTensorWrapper in its own header (#3575)
Summary: Pull Request resolved: #3575 X-link: facebookresearch/FBGEMM#660 some consumers of KVTensorWrapper build cpu-only packages. this diff made the following changes to avoid linking against cuda libraries: - put KVTensorWrapper in its own header file - add a dummy cpu target for KVTensorWrapper Reviewed By: q10, sryap Differential Revision: D68060586 fbshipit-source-id: 3fb4ade32108d557d2e1d19b629449867f0f0e7b
1 parent 9e9aa93 commit ded03b8

File tree

4 files changed

+222
-97
lines changed

4 files changed

+222
-97
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <ATen/Tensor.h> // @manual=//caffe2:ATen-core
12+
#include <torch/custom_class.h>
13+
14+
namespace ssd {
15+
16+
class EmbeddingRocksDB;
17+
class EmbeddingRocksDBWrapper;
18+
class SnapshotHandle;
19+
20+
// @lint-ignore CLANGTIDY cppcoreguidelines-special-member-functions
21+
struct EmbeddingSnapshotHandleWrapper : public torch::jit::CustomClassHolder {
22+
explicit EmbeddingSnapshotHandleWrapper(
23+
const SnapshotHandle* handle,
24+
std::shared_ptr<EmbeddingRocksDB> db);
25+
26+
~EmbeddingSnapshotHandleWrapper();
27+
28+
const SnapshotHandle* handle;
29+
std::shared_ptr<EmbeddingRocksDB> db;
30+
};
31+
32+
class KVTensorWrapper : public torch::jit::CustomClassHolder {
33+
public:
34+
explicit KVTensorWrapper(
35+
c10::intrusive_ptr<EmbeddingRocksDBWrapper> db,
36+
std::vector<int64_t> shape,
37+
int64_t dtype,
38+
int64_t row_offset,
39+
std::optional<c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>>
40+
snapshot_handle);
41+
42+
at::Tensor narrow(int64_t dim, int64_t start, int64_t length);
43+
44+
void set_range(
45+
int64_t dim,
46+
const int64_t start,
47+
const int64_t length,
48+
const at::Tensor& weights);
49+
50+
c10::IntArrayRef size();
51+
52+
c10::ScalarType dtype();
53+
54+
std::string_view dtype_str();
55+
56+
c10::Device device();
57+
58+
std::string device_str();
59+
60+
std::string layout_str();
61+
62+
private:
63+
std::shared_ptr<EmbeddingRocksDB> db_;
64+
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> snapshot_handle_;
65+
at::TensorOptions options_;
66+
std::vector<int64_t> shape_;
67+
int64_t row_offset_;
68+
};
69+
70+
} // namespace ssd
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <ATen/ATen.h>
10+
#include <ATen/core/op_registration/op_registration.h>
11+
#include <c10/core/ScalarTypeToTypeMeta.h>
12+
#include <torch/library.h>
13+
14+
#include "./kv_tensor_wrapper.h"
15+
#include "common/base/Exception.h"
16+
17+
using namespace at;
18+
using namespace ssd;
19+
20+
namespace ssd {
21+
class EmbeddingRocksDB {};
22+
23+
// @lint-ignore CLANGTIDY facebook-hte-ShadowingClass
24+
class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
25+
private:
26+
friend class KVTensorWrapper;
27+
std::shared_ptr<EmbeddingRocksDB> impl_;
28+
};
29+
30+
class SnapshotHandle {};
31+
32+
KVTensorWrapper::KVTensorWrapper(
33+
c10::intrusive_ptr<EmbeddingRocksDBWrapper> db,
34+
std::vector<int64_t> shape,
35+
[[maybe_unused]] int64_t dtype,
36+
int64_t row_offset,
37+
[[maybe_unused]] std::optional<
38+
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>> snapshot_handle)
39+
// @lint-ignore CLANGTIDY clang-diagnostic-missing-noreturn
40+
: db_(db->impl_), shape_(std::move(shape)), row_offset_(row_offset) {
41+
FBEXCEPTION("Not implemented");
42+
}
43+
44+
at::Tensor KVTensorWrapper::narrow(
45+
[[maybe_unused]] int64_t dim,
46+
[[maybe_unused]] int64_t start,
47+
[[maybe_unused]] int64_t length) {
48+
FBEXCEPTION("Not implemented");
49+
return at::empty(c10::IntArrayRef({1, 1}), options_);
50+
}
51+
52+
void KVTensorWrapper::set_range(
53+
[[maybe_unused]] int64_t dim,
54+
[[maybe_unused]] const int64_t start,
55+
[[maybe_unused]] const int64_t length,
56+
// @lint-ignore CLANGTIDY clang-diagnostic-missing-noreturn
57+
[[maybe_unused]] const at::Tensor& weights) {
58+
FBEXCEPTION("Not implemented");
59+
}
60+
61+
c10::IntArrayRef KVTensorWrapper::size() {
62+
FBEXCEPTION("Not implemented");
63+
return shape_;
64+
}
65+
66+
c10::ScalarType KVTensorWrapper::dtype() {
67+
FBEXCEPTION("Not implemented");
68+
return options_.dtype().toScalarType();
69+
}
70+
71+
std::string_view KVTensorWrapper::dtype_str() {
72+
FBEXCEPTION("Not implemented");
73+
return scalarTypeToTypeMeta(dtype()).name();
74+
}
75+
76+
c10::Device KVTensorWrapper::device() {
77+
FBEXCEPTION("Not implemented");
78+
return options_.device();
79+
}
80+
81+
std::string KVTensorWrapper::device_str() {
82+
FBEXCEPTION("Not implemented");
83+
return device().str();
84+
}
85+
86+
std::string KVTensorWrapper::layout_str() {
87+
FBEXCEPTION("Not implemented");
88+
std::ostringstream oss;
89+
oss << options_.layout();
90+
return oss.str();
91+
}
92+
} // namespace ssd

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,46 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
386386
std::shared_ptr<ssd::EmbeddingRocksDB> impl_;
387387
};
388388

389+
SnapshotHandle::SnapshotHandle(EmbeddingRocksDB* db) : db_(db) {
390+
auto num_shards = db->num_shards();
391+
CHECK_GT(num_shards, 0);
392+
shard_snapshots_.reserve(num_shards);
393+
for (auto shard = 0; shard < num_shards; ++shard) {
394+
const auto* snapshot = db->dbs_[shard]->GetSnapshot();
395+
CHECK(snapshot != nullptr)
396+
<< "ERROR: create_snapshot fails to create a snapshot "
397+
<< "for db shard " << shard << ". Please make sure that "
398+
<< "inplace_update_support is set to false" << std::endl;
399+
shard_snapshots_.push_back(snapshot);
400+
}
401+
}
402+
403+
SnapshotHandle::~SnapshotHandle() {
404+
for (auto shard = 0; shard < db_->dbs_.size(); ++shard) {
405+
snapshot_ptr_t snapshot = shard_snapshots_[shard];
406+
CHECK(snapshot != nullptr) << "Unexpected nullptr for snapshot " << shard;
407+
db_->dbs_[shard]->ReleaseSnapshot(snapshot);
408+
}
409+
}
410+
411+
void SnapshotHandle::release() {
412+
db_->release_snapshot(this);
413+
}
414+
415+
snapshot_ptr_t SnapshotHandle::get_snapshot_for_shard(size_t shard) const {
416+
CHECK_LE(shard, shard_snapshots_.size());
417+
return shard_snapshots_[shard];
418+
}
419+
420+
EmbeddingSnapshotHandleWrapper::EmbeddingSnapshotHandleWrapper(
421+
const SnapshotHandle* handle,
422+
std::shared_ptr<EmbeddingRocksDB> db)
423+
: handle(handle), db(std::move(db)) {}
424+
425+
EmbeddingSnapshotHandleWrapper::~EmbeddingSnapshotHandleWrapper() {
426+
db->release_snapshot(handle);
427+
}
428+
389429
KVTensorWrapper::KVTensorWrapper(
390430
c10::intrusive_ptr<EmbeddingRocksDBWrapper> db,
391431
std::vector<int64_t> shape,

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h

Lines changed: 20 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#endif
2626
#include "fbgemm_gpu/split_embeddings_cache/kv_db_cpp_utils.h"
2727
#include "kv_db_table_batched_embeddings.h"
28+
#include "kv_tensor_wrapper.h"
2829
#include "torch/csrc/autograd/record_function_ops.h"
2930

3031
namespace ssd {
@@ -124,55 +125,29 @@ class Initializer {
124125
std::unique_ptr<std::thread> producer_;
125126
};
126127

128+
class EmbeddingRocksDB;
129+
using snapshot_ptr_t = const rocksdb::Snapshot*;
130+
// @lint-ignore CLANGTIDY cppcoreguidelines-special-member-functions
131+
class SnapshotHandle {
132+
public:
133+
explicit SnapshotHandle(EmbeddingRocksDB* db);
134+
~SnapshotHandle();
135+
void release();
136+
snapshot_ptr_t get_snapshot_for_shard(size_t shard) const;
137+
138+
private:
139+
friend class EmbeddingRocksDB;
140+
141+
EmbeddingRocksDB* db_;
142+
std::vector<snapshot_ptr_t> shard_snapshots_;
143+
}; // class SnapshotHandle
144+
127145
/// @ingroup embedding-ssd
128146
///
129147
/// @brief An implementation of EmbeddingKVDB for RocksDB
130148
///
131149
class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
132-
using snapshot_ptr_t = const rocksdb::Snapshot*;
133-
134150
public:
135-
class SnapshotHandle {
136-
public:
137-
explicit SnapshotHandle(EmbeddingRocksDB* db) : db_(db) {
138-
auto num_shards = db->num_shards();
139-
CHECK_GT(num_shards, 0);
140-
shard_snapshots_.reserve(num_shards);
141-
for (auto shard = 0; shard < num_shards; ++shard) {
142-
const auto* snapshot = db->dbs_[shard]->GetSnapshot();
143-
CHECK(snapshot != nullptr)
144-
<< "ERROR: create_snapshot fails to create a snapshot "
145-
<< "for db shard " << shard << ". Please make sure that "
146-
<< "inplace_update_support is set to false" << std::endl;
147-
shard_snapshots_.push_back(snapshot);
148-
}
149-
}
150-
151-
~SnapshotHandle() {
152-
for (auto shard = 0; shard < db_->dbs_.size(); ++shard) {
153-
snapshot_ptr_t snapshot = shard_snapshots_[shard];
154-
CHECK(snapshot != nullptr)
155-
<< "Unexpected nullptr for snapshot " << shard;
156-
db_->dbs_[shard]->ReleaseSnapshot(snapshot);
157-
}
158-
}
159-
160-
void release() {
161-
db_->release_snapshot(this);
162-
}
163-
164-
snapshot_ptr_t get_snapshot_for_shard(size_t shard) const {
165-
CHECK_LE(shard, shard_snapshots_.size());
166-
return shard_snapshots_[shard];
167-
}
168-
169-
private:
170-
friend class EmbeddingRocksDB;
171-
172-
EmbeddingRocksDB* db_;
173-
std::vector<snapshot_ptr_t> shard_snapshots_;
174-
};
175-
176151
explicit EmbeddingRocksDB(
177152
std::string path,
178153
int64_t num_shards,
@@ -934,6 +909,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
934909
return folly::collect(futures);
935910
}
936911

912+
friend class SnapshotHandle;
913+
937914
std::vector<std::unique_ptr<rocksdb::DB>> dbs_;
938915
std::vector<std::unique_ptr<Initializer>> initializers_;
939916
std::unique_ptr<folly::CPUThreadPoolExecutor> executor_;
@@ -960,58 +937,4 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
960937
int64_t elem_size_;
961938
}; // class EmbeddingRocksDB
962939

963-
class EmbeddingRocksDBWrapper;
964-
965-
struct EmbeddingSnapshotHandleWrapper : public torch::jit::CustomClassHolder {
966-
explicit EmbeddingSnapshotHandleWrapper(
967-
const EmbeddingRocksDB::SnapshotHandle* handle,
968-
std::shared_ptr<EmbeddingRocksDB> db)
969-
: handle(handle), db(std::move(db)) {}
970-
971-
~EmbeddingSnapshotHandleWrapper() {
972-
db->release_snapshot(handle);
973-
}
974-
975-
const EmbeddingRocksDB::SnapshotHandle* handle;
976-
std::shared_ptr<EmbeddingRocksDB> db;
977-
};
978-
979-
class KVTensorWrapper : public torch::jit::CustomClassHolder {
980-
public:
981-
explicit KVTensorWrapper(
982-
c10::intrusive_ptr<EmbeddingRocksDBWrapper> db,
983-
std::vector<int64_t> shape,
984-
int64_t dtype,
985-
int64_t row_offset,
986-
std::optional<c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>>
987-
snapshot_handle);
988-
989-
at::Tensor narrow(int64_t dim, int64_t start, int64_t length);
990-
991-
void set_range(
992-
int64_t dim,
993-
const int64_t start,
994-
const int64_t length,
995-
const at::Tensor& weights);
996-
997-
c10::IntArrayRef size();
998-
999-
c10::ScalarType dtype();
1000-
1001-
std::string_view dtype_str();
1002-
1003-
c10::Device device();
1004-
1005-
std::string device_str();
1006-
1007-
std::string layout_str();
1008-
1009-
private:
1010-
std::shared_ptr<EmbeddingRocksDB> db_;
1011-
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> snapshot_handle_;
1012-
at::TensorOptions options_;
1013-
std::vector<int64_t> shape_;
1014-
int64_t row_offset_;
1015-
};
1016-
1017940
} // namespace ssd

0 commit comments

Comments
 (0)