Expand tensor accessor checks #3559

Open · wants to merge 1 commit into base: main
fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h (163 changes: 135 additions & 28 deletions)
@@ -39,6 +39,10 @@ struct RestrictPtrTraits {
};
#endif

////////////////////////////////////////////////////////////////////////////////
// *TensorAccessor (For CPU Use)
////////////////////////////////////////////////////////////////////////////////

// TensorAccessorBase and TensorAccessor are used for both CPU and CUDA tensors.
// For CUDA tensors it is used in device code (only). This means that we
// restrict ourselves to functions and types available there (e.g.
@@ -183,16 +187,22 @@ class TensorAccessor<T, 1, PtrTraits, index_t>
strides,
ptr_name,
func_name) {}

C10_HOST_DEVICE T& operator[](index_t i) {
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
return this->at(this->strides_[0] * i);
}

C10_HOST_DEVICE const T& operator[](index_t i) const {
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
return this->at(this->strides_[0] * i);
}
};
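
Illustrative sketch (not part of this diff): how a 1-D accessor's operator[] resolves to a strided element access on the host. The tensor, values, and function name are made up, and the stock ATen accessor is used for brevity; the fbgemm_gpu variant above indexes the same way but can carry memcheck metadata.

#include <ATen/ATen.h>

void cpu_accessor_sketch() {
  at::Tensor t = at::arange(10, at::kFloat);  // contiguous, so strides()[0] == 1
  auto acc = t.accessor<float, 1>();          // plain at::TensorAccessor<float, 1>
  float x = acc[3];                           // reads data_ptr[strides_[0] * 3]
  acc[3] = x + 1.0f;                          // non-const operator[] writes through the same path
}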

////////////////////////////////////////////////////////////////////////////////
// *PackedTensorAccessor (For CUDA Use)
////////////////////////////////////////////////////////////////////////////////

// GenericPackedTensorAccessorBase and GenericPackedTensorAccessor are used
// for CUDA `Tensor`s on the host. In contrast to `TensorAccessor`s, they
// copy the strides and sizes on instantiation (on the host) in order to
@@ -487,6 +497,39 @@ inline at::ScalarType scalar_type_for() {
return at::ScalarType::Undefined;
}

////////////////////////////////////////////////////////////////////////////////
// Tensor Checks with Descriptive Error Messages
////////////////////////////////////////////////////////////////////////////////

template <size_t N>
inline void check_tensor_dim(
const at::TensorBase& tensor
#ifdef FBGEMM_GPU_MEMCHECK
,
const char* const func_name,
const char* const tensor_name
#endif
) {
TORCH_CHECK(
tensor.dim() == N,
#ifdef FBGEMM_GPU_MEMCHECK
"[ ",
func_name,
" ]: ",
#endif
"Expected tensor ",
#ifdef FBGEMM_GPU_MEMCHECK
"'",
tensor_name,
"' ",
#endif
"to have ",
N,
"dims, but found ",
tensor.dim(),
" instead!");
}
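
Illustrative sketch (not part of this diff): a direct call to the new dimension check. With FBGEMM_GPU_MEMCHECK defined, a failure would read along the lines of "[ my_op ]: Expected tensor 'weights' to have 2 dims, but found 3 instead!"; the function and tensor names here are made up.

void dim_check_sketch(const at::Tensor& weights) {
#ifdef FBGEMM_GPU_MEMCHECK
  fbgemm_gpu::check_tensor_dim<2>(weights, "my_op", "weights");
  // The companion dtype check below follows the same calling pattern:
  // fbgemm_gpu::check_scalar_type<float>(weights, "my_op", "weights");
#else
  fbgemm_gpu::check_tensor_dim<2>(weights);
#endif
}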

template <typename T>
inline void check_scalar_type(
const at::TensorBase& tensor
Expand Down Expand Up @@ -528,34 +571,105 @@ namespace pta = fbgemm_gpu;
namespace pta = at;
#endif

template <
typename T,
size_t N,
template <typename U> class PtrTraits = at::DefaultPtrTraits,
typename index_t = int64_t>
inline pta::TensorAccessor<T, N, PtrTraits, index_t> make_tensor_accessor(
#ifdef FBGEMM_GPU_MEMCHECK
const at::Tensor& tensor,
const char* const tensor_name,
const char* const func_name) {
#else
const at::Tensor& tensor) {
#endif

static_assert(
N > 0,
"accessor is used for indexing tensor, for scalars use *data_ptr<T>()");

fbgemm_gpu::check_tensor_dim<N>(
tensor
#ifdef FBGEMM_GPU_MEMCHECK
,
func_name,
tensor_name
#endif
);

fbgemm_gpu::check_scalar_type<T>(
tensor
#ifdef FBGEMM_GPU_MEMCHECK
,
func_name,
tensor_name
#endif
);

#ifdef FBGEMM_GPU_MEMCHECK
return fbgemm_gpu::TensorAccessor<T, N, PtrTraits, index_t>(
static_cast<typename PtrTraits<T>::PtrType>(tensor.data_ptr<T>()),
tensor.sizes().data(),
tensor.strides().data(),
tensor_name,
func_name);
#else
return tensor.accessor<T, N>();
#endif
}
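
Illustrative sketch (not part of this diff): a direct call to the new builder, which the MAKE_TA_WITH_NAME macro further down simply wraps; "grad" and "my_op" are made-up names.

void make_ta_sketch(const at::Tensor& grad) {
#ifdef FBGEMM_GPU_MEMCHECK
  auto grad_acc = make_tensor_accessor<float, 2>(grad, "grad", "my_op");
#else
  auto grad_acc = make_tensor_accessor<float, 2>(grad);
#endif
  // grad_acc[i][j] then indexes like a plain at::TensorAccessor<float, 2>,
  // with the dim and dtype checks above already performed.
}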

////////////////////////////////////////////////////////////////////////////////
// *TensorAccessor Builder Functions
////////////////////////////////////////////////////////////////////////////////

template <
typename T,
size_t N,
template <typename U> class PtrTraits = at::DefaultPtrTraits,
typename index_t = int64_t>
fbgemm_gpu::GenericPackedTensorAccessor<T, N, PtrTraits, index_t>
inline pta::GenericPackedTensorAccessor<T, N, PtrTraits, index_t>
make_generic_packed_tensor_accessor(
#ifdef FBGEMM_GPU_MEMCHECK
const at::Tensor& tensor,
const char* const ptr_name,
const char* const tensor_name,
const char* const func_name) {
#else
const at::Tensor& tensor) {
#endif
static_assert(
N > 0,
"accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
TORCH_CHECK(
tensor.dim() == N,
"TensorAccessor expected ",
N,
" dims but tensor has ",
tensor.dim());

fbgemm_gpu::check_tensor_dim<N>(
tensor
#ifdef FBGEMM_GPU_MEMCHECK
,
func_name,
tensor_name
#endif
);

fbgemm_gpu::check_scalar_type<T>(
tensor
#ifdef FBGEMM_GPU_MEMCHECK
,
func_name,
tensor_name
#endif
);

#ifdef FBGEMM_GPU_MEMCHECK
return fbgemm_gpu::GenericPackedTensorAccessor<T, N, PtrTraits, index_t>(
static_cast<typename PtrTraits<T>::PtrType>(tensor.data_ptr<T>()),
tensor.sizes().data(),
tensor.strides().data(),
ptr_name,
tensor_name,
func_name);
}
#else
return tensor.generic_packed_accessor<T, N, PtrTraits, index_t>();
#endif
}
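
Illustrative sketch (not part of this diff): calling the generic builder directly with an explicit index type, following the signature shown above; the string arguments exist only under FBGEMM_GPU_MEMCHECK, and all names are made up.

void generic_pta_sketch(const at::Tensor& weights) {
#ifdef FBGEMM_GPU_MEMCHECK
  auto pta = make_generic_packed_tensor_accessor<float, 2, at::DefaultPtrTraits, int64_t>(
      weights, "weights", "weights", "my_op");  // ptr_name, tensor_name, func_name
#else
  auto pta = make_generic_packed_tensor_accessor<float, 2, at::DefaultPtrTraits, int64_t>(
      weights);
#endif
  // In practice the make_packed_tensor_accessor32/64 wrappers below are used instead.
}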

template <
typename T,
@@ -575,15 +689,6 @@ pta::PackedTensorAccessor32<T, N, PtrTraits> make_packed_tensor_accessor32(
static_cast<int64_t>(std::numeric_limits<int32_t>::max()),
"numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64");

fbgemm_gpu::check_scalar_type<T>(
tensor
#ifdef FBGEMM_GPU_MEMCHECK
,
func_name,
ptr_name
#endif
);

#ifdef FBGEMM_GPU_MEMCHECK
return make_generic_packed_tensor_accessor<T, N, PtrTraits, int32_t>(
tensor, ptr_name, func_name);
@@ -605,15 +710,6 @@ pta::PackedTensorAccessor64<T, N, PtrTraits> make_packed_tensor_accessor64(
const at::Tensor& tensor) {
#endif

fbgemm_gpu::check_scalar_type<T>(
tensor
#ifdef FBGEMM_GPU_MEMCHECK
,
func_name,
ptr_name
#endif
);

#ifdef FBGEMM_GPU_MEMCHECK
return make_generic_packed_tensor_accessor<T, N, PtrTraits, int64_t>(
tensor, ptr_name, func_name);
@@ -622,7 +718,14 @@ pta::PackedTensorAccessor64<T, N, PtrTraits> make_packed_tensor_accessor64(
#endif
}
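
Illustrative sketch (not part of this diff): a direct (non-macro) call to the 64-bit builder; the string arguments exist only in FBGEMM_GPU_MEMCHECK builds, and "lengths"/"my_op" are made-up names.

void pta64_sketch(const at::Tensor& lengths) {
#ifdef FBGEMM_GPU_MEMCHECK
  auto lengths_pta = make_packed_tensor_accessor64<int64_t, 1, at::DefaultPtrTraits>(
      lengths, "lengths", "my_op");
#else
  auto lengths_pta = make_packed_tensor_accessor64<int64_t, 1, at::DefaultPtrTraits>(lengths);
#endif
  // make_packed_tensor_accessor32 adds the numel < INT32_MAX check shown above.
}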

////////////////////////////////////////////////////////////////////////////////
// *TensorAccessor Builder Macros
////////////////////////////////////////////////////////////////////////////////

#ifdef FBGEMM_GPU_MEMCHECK
#define MAKE_TA_WITH_NAME(FUNC_NAME, TENSOR, T, N) \
make_tensor_accessor<T, N>(TENSOR, #TENSOR, FUNC_NAME)

#define MAKE_PACKED_TENSOR_ACCESSOR_BASE( \
FUNC_NAME, TENSOR, T, N, PTR_TRAITS, INDEX_NBITS) \
make_packed_tensor_accessor##INDEX_NBITS<T, N, PTR_TRAITS>( \
@@ -634,7 +737,11 @@ pta::PackedTensorAccessor64<T, N, PtrTraits> make_packed_tensor_accessor64(
at::acc_type<T, true>, \
N, \
PTR_TRAITS>(TENSOR, #TENSOR, FUNC_NAME)

#else
#define MAKE_TA_WITH_NAME(FUNC_NAME, TENSOR, T, N) \
make_tensor_accessor<T, N>(TENSOR)

#define MAKE_PACKED_TENSOR_ACCESSOR_BASE( \
FUNC_NAME, TENSOR, T, N, PTR_TRAITS, INDEX_NBITS) \
make_packed_tensor_accessor##INDEX_NBITS<T, N, PTR_TRAITS>(TENSOR)
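
Illustrative sketch (not part of this diff): how the two visible macros would typically be expanded at a call site in a .cu translation unit, shown side by side for illustration; the function name, tensors, and dtypes are made up.

void launch_sketch(const at::Tensor& indices, const at::Tensor& weights) {
  const auto func_name = "example_kernel";
  // Plain accessor with the new dim and dtype checks applied (commonly used in CPU ops):
  auto indices_ta = MAKE_TA_WITH_NAME(func_name, indices, int64_t, 1);
  // Packed accessor with 32-bit indexing (commonly used for CUDA kernel arguments):
  auto weights_pta = MAKE_PACKED_TENSOR_ACCESSOR_BASE(
      func_name, weights, float, 2, at::RestrictPtrTraits, 32);
  // indices_ta and weights_pta would then be passed by value to a kernel.
}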