diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h new file mode 100644 index 0000000000..c25d041217 --- /dev/null +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include // @manual=//caffe2:ATen-core +#include + +namespace ssd { + +class EmbeddingRocksDB; +class EmbeddingRocksDBWrapper; +class SnapshotHandle; + +// @lint-ignore CLANGTIDY cppcoreguidelines-special-member-functions +struct EmbeddingSnapshotHandleWrapper : public torch::jit::CustomClassHolder { + explicit EmbeddingSnapshotHandleWrapper( + const SnapshotHandle* handle, + std::shared_ptr db); + + ~EmbeddingSnapshotHandleWrapper(); + + const SnapshotHandle* handle; + std::shared_ptr db; +}; + +class KVTensorWrapper : public torch::jit::CustomClassHolder { + public: + explicit KVTensorWrapper( + c10::intrusive_ptr db, + std::vector shape, + int64_t dtype, + int64_t row_offset, + std::optional> + snapshot_handle); + + at::Tensor narrow(int64_t dim, int64_t start, int64_t length); + + void set_range( + int64_t dim, + const int64_t start, + const int64_t length, + const at::Tensor& weights); + + c10::IntArrayRef size(); + + c10::ScalarType dtype(); + + std::string_view dtype_str(); + + c10::Device device(); + + std::string device_str(); + + std::string layout_str(); + + private: + std::shared_ptr db_; + c10::intrusive_ptr snapshot_handle_; + at::TensorOptions options_; + std::vector shape_; + int64_t row_offset_; +}; + +} // namespace ssd diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp new file mode 100644 index 0000000000..4ae614f4de --- /dev/null +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper_cpu.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include "./kv_tensor_wrapper.h" +#include "common/base/Exception.h" + +using namespace at; +using namespace ssd; + +namespace ssd { +class EmbeddingRocksDB {}; + +// @lint-ignore CLANGTIDY facebook-hte-ShadowingClass +class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { + private: + friend class KVTensorWrapper; + std::shared_ptr impl_; +}; + +class SnapshotHandle {}; + +KVTensorWrapper::KVTensorWrapper( + c10::intrusive_ptr db, + std::vector shape, + [[maybe_unused]] int64_t dtype, + int64_t row_offset, + [[maybe_unused]] std::optional< + c10::intrusive_ptr> snapshot_handle) + // @lint-ignore CLANGTIDY clang-diagnostic-missing-noreturn + : db_(db->impl_), shape_(std::move(shape)), row_offset_(row_offset) { + FBEXCEPTION("Not implemented"); +} + +at::Tensor KVTensorWrapper::narrow( + [[maybe_unused]] int64_t dim, + [[maybe_unused]] int64_t start, + [[maybe_unused]] int64_t length) { + FBEXCEPTION("Not implemented"); + return at::empty(c10::IntArrayRef({1, 1}), options_); +} + +void KVTensorWrapper::set_range( + [[maybe_unused]] int64_t dim, + [[maybe_unused]] const int64_t start, + [[maybe_unused]] const int64_t length, + // @lint-ignore CLANGTIDY clang-diagnostic-missing-noreturn + [[maybe_unused]] const at::Tensor& weights) { + FBEXCEPTION("Not implemented"); +} + +c10::IntArrayRef KVTensorWrapper::size() { + FBEXCEPTION("Not implemented"); + return shape_; +} + +c10::ScalarType KVTensorWrapper::dtype() { + FBEXCEPTION("Not implemented"); + return options_.dtype().toScalarType(); +} + +std::string_view KVTensorWrapper::dtype_str() { + FBEXCEPTION("Not implemented"); + return scalarTypeToTypeMeta(dtype()).name(); +} + +c10::Device KVTensorWrapper::device() { + FBEXCEPTION("Not implemented"); + return options_.device(); +} + +std::string KVTensorWrapper::device_str() { + FBEXCEPTION("Not implemented"); + return device().str(); +} + +std::string KVTensorWrapper::layout_str() { + FBEXCEPTION("Not implemented"); + std::ostringstream oss; + oss << options_.layout(); + return oss.str(); +} +} // namespace ssd diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp index c45380a9e6..45d4d2e34d 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp @@ -386,6 +386,46 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { std::shared_ptr impl_; }; +SnapshotHandle::SnapshotHandle(EmbeddingRocksDB* db) : db_(db) { + auto num_shards = db->num_shards(); + CHECK_GT(num_shards, 0); + shard_snapshots_.reserve(num_shards); + for (auto shard = 0; shard < num_shards; ++shard) { + const auto* snapshot = db->dbs_[shard]->GetSnapshot(); + CHECK(snapshot != nullptr) + << "ERROR: create_snapshot fails to create a snapshot " + << "for db shard " << shard << ". Please make sure that " + << "inplace_update_support is set to false" << std::endl; + shard_snapshots_.push_back(snapshot); + } +} + +SnapshotHandle::~SnapshotHandle() { + for (auto shard = 0; shard < db_->dbs_.size(); ++shard) { + snapshot_ptr_t snapshot = shard_snapshots_[shard]; + CHECK(snapshot != nullptr) << "Unexpected nullptr for snapshot " << shard; + db_->dbs_[shard]->ReleaseSnapshot(snapshot); + } +} + +void SnapshotHandle::release() { + db_->release_snapshot(this); +} + +snapshot_ptr_t SnapshotHandle::get_snapshot_for_shard(size_t shard) const { + CHECK_LE(shard, shard_snapshots_.size()); + return shard_snapshots_[shard]; +} + +EmbeddingSnapshotHandleWrapper::EmbeddingSnapshotHandleWrapper( + const SnapshotHandle* handle, + std::shared_ptr db) + : handle(handle), db(std::move(db)) {} + +EmbeddingSnapshotHandleWrapper::~EmbeddingSnapshotHandleWrapper() { + db->release_snapshot(handle); +} + KVTensorWrapper::KVTensorWrapper( c10::intrusive_ptr db, std::vector shape, diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 5bb7358cf1..a84ac16981 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -25,6 +25,7 @@ #endif #include "fbgemm_gpu/split_embeddings_cache/kv_db_cpp_utils.h" #include "kv_db_table_batched_embeddings.h" +#include "kv_tensor_wrapper.h" #include "torch/csrc/autograd/record_function_ops.h" namespace ssd { @@ -124,55 +125,29 @@ class Initializer { std::unique_ptr producer_; }; +class EmbeddingRocksDB; +using snapshot_ptr_t = const rocksdb::Snapshot*; +// @lint-ignore CLANGTIDY cppcoreguidelines-special-member-functions +class SnapshotHandle { + public: + explicit SnapshotHandle(EmbeddingRocksDB* db); + ~SnapshotHandle(); + void release(); + snapshot_ptr_t get_snapshot_for_shard(size_t shard) const; + + private: + friend class EmbeddingRocksDB; + + EmbeddingRocksDB* db_; + std::vector shard_snapshots_; +}; // class SnapshotHandle + /// @ingroup embedding-ssd /// /// @brief An implementation of EmbeddingKVDB for RocksDB /// class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { - using snapshot_ptr_t = const rocksdb::Snapshot*; - public: - class SnapshotHandle { - public: - explicit SnapshotHandle(EmbeddingRocksDB* db) : db_(db) { - auto num_shards = db->num_shards(); - CHECK_GT(num_shards, 0); - shard_snapshots_.reserve(num_shards); - for (auto shard = 0; shard < num_shards; ++shard) { - const auto* snapshot = db->dbs_[shard]->GetSnapshot(); - CHECK(snapshot != nullptr) - << "ERROR: create_snapshot fails to create a snapshot " - << "for db shard " << shard << ". Please make sure that " - << "inplace_update_support is set to false" << std::endl; - shard_snapshots_.push_back(snapshot); - } - } - - ~SnapshotHandle() { - for (auto shard = 0; shard < db_->dbs_.size(); ++shard) { - snapshot_ptr_t snapshot = shard_snapshots_[shard]; - CHECK(snapshot != nullptr) - << "Unexpected nullptr for snapshot " << shard; - db_->dbs_[shard]->ReleaseSnapshot(snapshot); - } - } - - void release() { - db_->release_snapshot(this); - } - - snapshot_ptr_t get_snapshot_for_shard(size_t shard) const { - CHECK_LE(shard, shard_snapshots_.size()); - return shard_snapshots_[shard]; - } - - private: - friend class EmbeddingRocksDB; - - EmbeddingRocksDB* db_; - std::vector shard_snapshots_; - }; - explicit EmbeddingRocksDB( std::string path, int64_t num_shards, @@ -934,6 +909,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { return folly::collect(futures); } + friend class SnapshotHandle; + std::vector> dbs_; std::vector> initializers_; std::unique_ptr executor_; @@ -960,58 +937,4 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { int64_t elem_size_; }; // class EmbeddingRocksDB -class EmbeddingRocksDBWrapper; - -struct EmbeddingSnapshotHandleWrapper : public torch::jit::CustomClassHolder { - explicit EmbeddingSnapshotHandleWrapper( - const EmbeddingRocksDB::SnapshotHandle* handle, - std::shared_ptr db) - : handle(handle), db(std::move(db)) {} - - ~EmbeddingSnapshotHandleWrapper() { - db->release_snapshot(handle); - } - - const EmbeddingRocksDB::SnapshotHandle* handle; - std::shared_ptr db; -}; - -class KVTensorWrapper : public torch::jit::CustomClassHolder { - public: - explicit KVTensorWrapper( - c10::intrusive_ptr db, - std::vector shape, - int64_t dtype, - int64_t row_offset, - std::optional> - snapshot_handle); - - at::Tensor narrow(int64_t dim, int64_t start, int64_t length); - - void set_range( - int64_t dim, - const int64_t start, - const int64_t length, - const at::Tensor& weights); - - c10::IntArrayRef size(); - - c10::ScalarType dtype(); - - std::string_view dtype_str(); - - c10::Device device(); - - std::string device_str(); - - std::string layout_str(); - - private: - std::shared_ptr db_; - c10::intrusive_ptr snapshot_handle_; - at::TensorOptions options_; - std::vector shape_; - int64_t row_offset_; -}; - } // namespace ssd