Skip to content

Commit

Permalink
put KVTensorWrapper in its own header (#3575)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3575

X-link: facebookresearch/FBGEMM#660

some consumers of KVTensorWrapper build cpu-only packages. this diff made the following changes to avoid linking against cuda libraries:
- put KVTensorWrapper in its own header file
- add a dummy cpu target for KVTensorWrapper

Reviewed By: q10, sryap

Differential Revision: D68060586

fbshipit-source-id: 3fb4ade32108d557d2e1d19b629449867f0f0e7b
  • Loading branch information
Yulu Jia authored and facebook-github-bot committed Jan 16, 2025
1 parent 9e9aa93 commit ded03b8
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 97 deletions.
70 changes: 70 additions & 0 deletions fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <ATen/Tensor.h> // @manual=//caffe2:ATen-core
#include <torch/custom_class.h>

namespace ssd {

class EmbeddingRocksDB;
class EmbeddingRocksDBWrapper;
class SnapshotHandle;

// @lint-ignore CLANGTIDY cppcoreguidelines-special-member-functions
struct EmbeddingSnapshotHandleWrapper : public torch::jit::CustomClassHolder {
explicit EmbeddingSnapshotHandleWrapper(
const SnapshotHandle* handle,
std::shared_ptr<EmbeddingRocksDB> db);

~EmbeddingSnapshotHandleWrapper();

const SnapshotHandle* handle;
std::shared_ptr<EmbeddingRocksDB> db;
};

class KVTensorWrapper : public torch::jit::CustomClassHolder {
public:
explicit KVTensorWrapper(
c10::intrusive_ptr<EmbeddingRocksDBWrapper> db,
std::vector<int64_t> shape,
int64_t dtype,
int64_t row_offset,
std::optional<c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>>
snapshot_handle);

at::Tensor narrow(int64_t dim, int64_t start, int64_t length);

void set_range(
int64_t dim,
const int64_t start,
const int64_t length,
const at::Tensor& weights);

c10::IntArrayRef size();

c10::ScalarType dtype();

std::string_view dtype_str();

c10::Device device();

std::string device_str();

std::string layout_str();

private:
std::shared_ptr<EmbeddingRocksDB> db_;
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> snapshot_handle_;
at::TensorOptions options_;
std::vector<int64_t> shape_;
int64_t row_offset_;
};

} // namespace ssd
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <torch/library.h>

#include "./kv_tensor_wrapper.h"
#include "common/base/Exception.h"

using namespace at;
using namespace ssd;

namespace ssd {
class EmbeddingRocksDB {};

// @lint-ignore CLANGTIDY facebook-hte-ShadowingClass
class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
private:
friend class KVTensorWrapper;
std::shared_ptr<EmbeddingRocksDB> impl_;
};

class SnapshotHandle {};

KVTensorWrapper::KVTensorWrapper(
c10::intrusive_ptr<EmbeddingRocksDBWrapper> db,
std::vector<int64_t> shape,
[[maybe_unused]] int64_t dtype,
int64_t row_offset,
[[maybe_unused]] std::optional<
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>> snapshot_handle)
// @lint-ignore CLANGTIDY clang-diagnostic-missing-noreturn
: db_(db->impl_), shape_(std::move(shape)), row_offset_(row_offset) {
FBEXCEPTION("Not implemented");
}

at::Tensor KVTensorWrapper::narrow(
[[maybe_unused]] int64_t dim,
[[maybe_unused]] int64_t start,
[[maybe_unused]] int64_t length) {
FBEXCEPTION("Not implemented");
return at::empty(c10::IntArrayRef({1, 1}), options_);
}

void KVTensorWrapper::set_range(
[[maybe_unused]] int64_t dim,
[[maybe_unused]] const int64_t start,
[[maybe_unused]] const int64_t length,
// @lint-ignore CLANGTIDY clang-diagnostic-missing-noreturn
[[maybe_unused]] const at::Tensor& weights) {
FBEXCEPTION("Not implemented");
}

c10::IntArrayRef KVTensorWrapper::size() {
FBEXCEPTION("Not implemented");
return shape_;
}

c10::ScalarType KVTensorWrapper::dtype() {
FBEXCEPTION("Not implemented");
return options_.dtype().toScalarType();
}

std::string_view KVTensorWrapper::dtype_str() {
FBEXCEPTION("Not implemented");
return scalarTypeToTypeMeta(dtype()).name();
}

c10::Device KVTensorWrapper::device() {
FBEXCEPTION("Not implemented");
return options_.device();
}

std::string KVTensorWrapper::device_str() {
FBEXCEPTION("Not implemented");
return device().str();
}

std::string KVTensorWrapper::layout_str() {
FBEXCEPTION("Not implemented");
std::ostringstream oss;
oss << options_.layout();
return oss.str();
}
} // namespace ssd
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,46 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
std::shared_ptr<ssd::EmbeddingRocksDB> impl_;
};

SnapshotHandle::SnapshotHandle(EmbeddingRocksDB* db) : db_(db) {
auto num_shards = db->num_shards();
CHECK_GT(num_shards, 0);
shard_snapshots_.reserve(num_shards);
for (auto shard = 0; shard < num_shards; ++shard) {
const auto* snapshot = db->dbs_[shard]->GetSnapshot();
CHECK(snapshot != nullptr)
<< "ERROR: create_snapshot fails to create a snapshot "
<< "for db shard " << shard << ". Please make sure that "
<< "inplace_update_support is set to false" << std::endl;
shard_snapshots_.push_back(snapshot);
}
}

SnapshotHandle::~SnapshotHandle() {
for (auto shard = 0; shard < db_->dbs_.size(); ++shard) {
snapshot_ptr_t snapshot = shard_snapshots_[shard];
CHECK(snapshot != nullptr) << "Unexpected nullptr for snapshot " << shard;
db_->dbs_[shard]->ReleaseSnapshot(snapshot);
}
}

void SnapshotHandle::release() {
db_->release_snapshot(this);
}

snapshot_ptr_t SnapshotHandle::get_snapshot_for_shard(size_t shard) const {
CHECK_LE(shard, shard_snapshots_.size());
return shard_snapshots_[shard];
}

EmbeddingSnapshotHandleWrapper::EmbeddingSnapshotHandleWrapper(
const SnapshotHandle* handle,
std::shared_ptr<EmbeddingRocksDB> db)
: handle(handle), db(std::move(db)) {}

EmbeddingSnapshotHandleWrapper::~EmbeddingSnapshotHandleWrapper() {
db->release_snapshot(handle);
}

KVTensorWrapper::KVTensorWrapper(
c10::intrusive_ptr<EmbeddingRocksDBWrapper> db,
std::vector<int64_t> shape,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#endif
#include "fbgemm_gpu/split_embeddings_cache/kv_db_cpp_utils.h"
#include "kv_db_table_batched_embeddings.h"
#include "kv_tensor_wrapper.h"
#include "torch/csrc/autograd/record_function_ops.h"

namespace ssd {
Expand Down Expand Up @@ -124,55 +125,29 @@ class Initializer {
std::unique_ptr<std::thread> producer_;
};

class EmbeddingRocksDB;
using snapshot_ptr_t = const rocksdb::Snapshot*;
// @lint-ignore CLANGTIDY cppcoreguidelines-special-member-functions
class SnapshotHandle {
public:
explicit SnapshotHandle(EmbeddingRocksDB* db);
~SnapshotHandle();
void release();
snapshot_ptr_t get_snapshot_for_shard(size_t shard) const;

private:
friend class EmbeddingRocksDB;

EmbeddingRocksDB* db_;
std::vector<snapshot_ptr_t> shard_snapshots_;
}; // class SnapshotHandle

/// @ingroup embedding-ssd
///
/// @brief An implementation of EmbeddingKVDB for RocksDB
///
class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
using snapshot_ptr_t = const rocksdb::Snapshot*;

public:
class SnapshotHandle {
public:
explicit SnapshotHandle(EmbeddingRocksDB* db) : db_(db) {
auto num_shards = db->num_shards();
CHECK_GT(num_shards, 0);
shard_snapshots_.reserve(num_shards);
for (auto shard = 0; shard < num_shards; ++shard) {
const auto* snapshot = db->dbs_[shard]->GetSnapshot();
CHECK(snapshot != nullptr)
<< "ERROR: create_snapshot fails to create a snapshot "
<< "for db shard " << shard << ". Please make sure that "
<< "inplace_update_support is set to false" << std::endl;
shard_snapshots_.push_back(snapshot);
}
}

~SnapshotHandle() {
for (auto shard = 0; shard < db_->dbs_.size(); ++shard) {
snapshot_ptr_t snapshot = shard_snapshots_[shard];
CHECK(snapshot != nullptr)
<< "Unexpected nullptr for snapshot " << shard;
db_->dbs_[shard]->ReleaseSnapshot(snapshot);
}
}

void release() {
db_->release_snapshot(this);
}

snapshot_ptr_t get_snapshot_for_shard(size_t shard) const {
CHECK_LE(shard, shard_snapshots_.size());
return shard_snapshots_[shard];
}

private:
friend class EmbeddingRocksDB;

EmbeddingRocksDB* db_;
std::vector<snapshot_ptr_t> shard_snapshots_;
};

explicit EmbeddingRocksDB(
std::string path,
int64_t num_shards,
Expand Down Expand Up @@ -934,6 +909,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
return folly::collect(futures);
}

friend class SnapshotHandle;

std::vector<std::unique_ptr<rocksdb::DB>> dbs_;
std::vector<std::unique_ptr<Initializer>> initializers_;
std::unique_ptr<folly::CPUThreadPoolExecutor> executor_;
Expand All @@ -960,58 +937,4 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
int64_t elem_size_;
}; // class EmbeddingRocksDB

class EmbeddingRocksDBWrapper;

struct EmbeddingSnapshotHandleWrapper : public torch::jit::CustomClassHolder {
explicit EmbeddingSnapshotHandleWrapper(
const EmbeddingRocksDB::SnapshotHandle* handle,
std::shared_ptr<EmbeddingRocksDB> db)
: handle(handle), db(std::move(db)) {}

~EmbeddingSnapshotHandleWrapper() {
db->release_snapshot(handle);
}

const EmbeddingRocksDB::SnapshotHandle* handle;
std::shared_ptr<EmbeddingRocksDB> db;
};

class KVTensorWrapper : public torch::jit::CustomClassHolder {
public:
explicit KVTensorWrapper(
c10::intrusive_ptr<EmbeddingRocksDBWrapper> db,
std::vector<int64_t> shape,
int64_t dtype,
int64_t row_offset,
std::optional<c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>>
snapshot_handle);

at::Tensor narrow(int64_t dim, int64_t start, int64_t length);

void set_range(
int64_t dim,
const int64_t start,
const int64_t length,
const at::Tensor& weights);

c10::IntArrayRef size();

c10::ScalarType dtype();

std::string_view dtype_str();

c10::Device device();

std::string device_str();

std::string layout_str();

private:
std::shared_ptr<EmbeddingRocksDB> db_;
c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> snapshot_handle_;
at::TensorOptions options_;
std::vector<int64_t> shape_;
int64_t row_offset_;
};

} // namespace ssd

0 comments on commit ded03b8

Please sign in to comment.