diff --git a/fbgemm_gpu/cmake/TbeTraining.cmake b/fbgemm_gpu/cmake/TbeTraining.cmake index b02c7b582a..866b481e67 100644 --- a/fbgemm_gpu/cmake/TbeTraining.cmake +++ b/fbgemm_gpu/cmake/TbeTraining.cmake @@ -42,14 +42,6 @@ handle_genfiles(gen_py_files_training) handle_genfiles(gen_py_files_defused_optim) -################################################################################ -# FBGEMM_GPU Generated HIP-Specific Sources -################################################################################ - -get_tbe_sources_list(gen_hip_files_training) -handle_genfiles_rocm(gen_hip_files_training) - - ################################################################################ # TBE C++ Training Targets ################################################################################ @@ -160,8 +152,6 @@ gpu_cpp_library( ${gen_cpu_files_training} GPU_SRCS ${gen_gpu_files_training} - HIP_SPECIFIC_SRCS - ${gen_hip_files_training} GPU_FLAGS ${TORCH_CUDA_OPTIONS} DEPS diff --git a/fbgemm_gpu/cmake/tbe_sources.py b/fbgemm_gpu/cmake/tbe_sources.py index 2c0504c9ad..99d6ff302e 100644 --- a/fbgemm_gpu/cmake/tbe_sources.py +++ b/fbgemm_gpu/cmake/tbe_sources.py @@ -473,25 +473,6 @@ ] ) -gen_hip_files_training = [ - "gen_embedding_backward_split_{}{}_device_kernel_hip.hip".format( - "weighted" if weighted else "unweighted", - "_nobag" if nobag else "", - ) - for nobag in [ - True, - False, - ] - for weighted in ( - [ - True, - False, - ] - if not nobag - else [False] - ) -] - ################################################################################ # Python Training Code ################################################################################ diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index c977148578..ac60a8dad8 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -310,27 +310,6 @@ def generate_backward_indices() -> None: ssd=ssd, ) - @staticmethod - def generate_rocm_backward_split(**kwargs: Any) -> None: - # Generate backward device kernels based on weighted (True/False), VBE - # (True/False), no bag (True/False) - template_filepath = ( - "training/backward/rocm/embedding_backward_split_device_kernel_template.hip" - ) - - BackwardSplitGenerator.render_backward_templates( - template_filepath, - "", - "{}gen_embedding_backward_{}_device_kernel_hip.hip", - { - "has_gpu_support": True, - "has_vbe_support": False, - "has_ssd_support": False, - "dense": False, - "gen_once": False, - }, - ) - @staticmethod def generate_python_sources( all_optimizers: List[str], ssd_optimizers: List[str] @@ -390,7 +369,6 @@ def generate() -> None: BackwardSplitGenerator.generate_backward_split( ssd_tensors=ssd_tensors, **optimizer ) - BackwardSplitGenerator.generate_rocm_backward_split() # Generate common device kernels for backwards BackwardSplitGenerator.generate_backward_device() diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp index 626838e930..ee608e83e0 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp @@ -171,8 +171,7 @@ Tensor split_embedding_codegen_lookup_dense_function( Tensor>& /* vbe_B_offsets_rank_per_feature = std::nullopt */, c10::SymInt /* max_B = -1 */, c10::SymInt /* max_B_feature_rank = -1 
*/, - c10::SymInt /* vbe_output_size = -1 */, - bool /* mixed_D = true */) { + c10::SymInt /* vbe_output_size = -1 */) { return SplitLookupFunction_Dense_Op::apply( host_weights, weights_offsets, @@ -191,7 +190,7 @@ Tensor split_embedding_codegen_lookup_dense_function( // Deprecated for fb namespace! Please use fbgemm namespace instead! TORCH_LIBRARY_FRAGMENT(fb, m) { m.def( - "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=True) -> Tensor"); + "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1) -> Tensor"); DISPATCH_TO_CPU( "dense_embedding_codegen_lookup_function", split_embedding_codegen_lookup_dense_function); @@ -199,7 +198,7 @@ TORCH_LIBRARY_FRAGMENT(fb, m) { TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( - "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1, bool mixed_D=True) -> Tensor"); + "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, SymInt total_D, SymInt max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? 
vbe_B_offsets_rank_per_feature=None, SymInt max_B=-1, SymInt max_B_feature_rank=-1, SymInt vbe_output_size=-1) -> Tensor"); DISPATCH_TO_CPU( "dense_embedding_codegen_lookup_function", split_embedding_codegen_lookup_dense_function); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 3f76798d28..211f42309e 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -152,7 +152,6 @@ enum SSDTensor { {%- else %} D_offsets, max_D, - mixed_D, {%- endif %} {# /* if nobag */ #} hash_size_cumsum, total_hash_size_bits, @@ -225,7 +224,6 @@ enum SSDTensor { Variable(), // D_offsets Variable(), // total_D Variable(), // max_D - Variable(), // mixed_D {%- endif %} Variable(), // hash_size_cumsum Variable(), //total_hash_size_bits @@ -306,7 +304,6 @@ enum SSDTensor { D_offsets, total_D, max_D, - mixed_D, {%- endif %} hash_size_cumsum, total_hash_size_bits, @@ -487,7 +484,6 @@ Tensor {%- else %} const Tensor& D_offsets, const c10::SymInt max_D, - const bool mixed_D, {%- endif %} const Tensor& hash_size_cumsum, const int64_t total_hash_size_bits, @@ -570,7 +566,6 @@ class {{ autograd_func }} : const Tensor& D_offsets, const c10::SymInt total_D, const c10::SymInt max_D, - const bool mixed_D, {%- else %} const c10::SymInt D, {%- endif %} @@ -767,7 +762,6 @@ class {{ autograd_func }} : {%- if not nobag %} ctx->saved_data["max_D"] = max_D; - ctx->saved_data["mixed_D"] = mixed_D; ctx->saved_data["pooling_mode"] = pooling_mode; {%- else %} ctx->saved_data["D"] = D; @@ -883,7 +877,6 @@ class {{ autograd_func }} : {%- if not nobag %} auto max_D = ctx->saved_data["max_D"].toSymInt(); - const auto mixed_D = ctx->saved_data["mixed_D"].toBool(); auto pooling_mode = ctx->saved_data["pooling_mode"].toInt(); {%- else %} auto D = ctx->saved_data["D"].toSymInt(); @@ -1079,11 +1072,10 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function( {%- if ssd %} const std::optional& ssd_tensors = std::nullopt, {%- endif %} - const double gwd_lower_bound = 0, + const double gwd_lower_bound = 0 {%- else %} - const c10::SymInt vbe_output_size = -1, + const c10::SymInt vbe_output_size = -1 {%- endif %} - const bool mixed_D = true ) { // TODO: refactor into macro {%- if has_gpu_support %} @@ -1199,8 +1191,7 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) { {%- if ssd %} " Tensor[]? 
ssd_tensors=None," {%- endif %} - " float gwd_lower_bound=0, " - " bool mixed_D=True" + " float gwd_lower_bound=0 " ") -> Tensor", {PT2_COMPLIANT_TAG}); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 1732239db2..bc27f15281 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -527,321 +527,5 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row #endif //////////////////////////////////////////////////////////////////////////////// -{%- endif %} - -{%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %} -#include -#include -#include "fbgemm_gpu/rocm/split_embeddings_common.h" -#include "gen_embedding_backward_split_{{ desc_suffix }}{{ ndesc }}_device_kernel_hip.hip" - -template < - typename emb_t, - typename grad_t, - typename cache_t, - typename index_t, - int32_t kFixedMaxVecsPerThread, - int32_t kThreadGroupSize, - bool kUseVecBlocking, - int32_t embedding_dim, - int32_t weight_decay_mode_v> -__global__ void -hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1( - const pta::PackedTensorAccessor64 grad_output, - {%- if optimizer != "none" %} - pta::PackedTensorAccessor64 dev_weights, - {%- if not dense %} - pta::PackedTensorAccessor64 uvm_weights, - pta::PackedTensorAccessor64 lxu_cache_weights, - const pta::PackedTensorAccessor32 weights_placements, - {%- endif %} - {%- endif %} - const pta::PackedTensorAccessor32 weights_offsets, - {%- if not nobag or is_index_select %} - const pta::PackedTensorAccessor32 D_offsets, - {%- else %} - int64_t D, - {%- endif %} - const pta::PackedTensorAccessor32 hash_size_cumsum, - const pta::PackedTensorAccessor32 sorted_linear_indices_run, - const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, - {%- if not nobag %} - const pta::PackedTensorAccessor32 sorted_infos, - {%- else %} - const pta::PackedTensorAccessor32 sorted_infos, - {%- endif %} - {%- if not dense %} - const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, - const bool use_uniq_cache_locations, - const pta::PackedTensorAccessor32 table_unique_indices_offsets, - {%- endif %} - {%- if weighted %} - const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, - {%- endif %} - const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, - int32_t max_segment_length_per_warp, - {%- if not dense and optimizer != "none" %} - bool stochastic_rounding, - at::PhiloxCudaState stochastic_rounding_philox_args, - {%- else %} - pta::PackedTensorAccessor64 grad_dev_weights, - {%- endif %} // if not dense and optimizer != "none" - {%- if not nobag and vbe %} - const pta::PackedTensorAccessor32 B_offsets, - const pta::PackedTensorAccessor32 row_output_offsets, - {%- endif %} - {%- if not nobag %} - const int32_t info_B_num_bits, - const uint32_t info_B_mask, - {%- endif %} - const int32_t max_D, - const int32_t max_vecs_per_thread, - {%- if is_index_select %} - const at::PackedTensorAccessor32 grad_offsets, - const bool permute_output_dim_0_1 - {%- else %} - {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} - {%- endif %} -) { - {%- if not nobag %} - int32_t T = D_offsets.size(0) - 1; - {%- else %} - 
int32_t T = weights_offsets.size(0); - {%- endif %} - - auto p_output_grad = grad_output.data(); - auto p_emb_table = dev_weights.data(); - auto p_hash_size_cumsum = hash_size_cumsum.data(); - auto p_sorted_linear_indices_run = sorted_linear_indices_run.data(); - auto p_sorted_linear_indices_cumulative_run_lengths = sorted_linear_indices_cumulative_run_lengths.data(); - auto p_sorted_linear_indices_num_runs = sorted_linear_indices_num_runs.data(); - auto p_sorted_infos = sorted_infos.data(); - {%- if weighted %} - auto p_indice_weights_sorted = sorted_indice_weights.data(); - {%- endif %} - auto emb_dim = embedding_dim; - constexpr int32_t segment_prefetch = 2; - constexpr int32_t segment_unroll = 8; - constexpr int32_t segment_split = 0; - auto batch = grad_output.size(0); - auto num_rows = dev_weights.size(0) / T / max_D; - {%- if weighted %} - constexpr bool is_weighted = true; - {%- else %} - constexpr bool is_weighted = false; - {%- endif %} - rocm::{{optimizer}}_kernel_arg_t opt_karg; - opt_karg.p_momentum = momentum1_dev.data(); - opt_karg.eps = eps; - opt_karg.learning_rate = learning_rate; - // weight_decay(_mode) is supplied as args.split_function_args_no_defaults - opt_karg.weight_decay_mode = weight_decay_mode_v; - opt_karg.weight_decay = weight_decay; - auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t { - assert(d >= 1 && d <= INT32_MAX); - uint8_t shift; - for(shift = 0; shift < 32; shift++) - if((1U << shift) >= d) - break; - - uint64_t one = 1; - uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; - assert(magic <= 0xffffffffUL); - - rocm::magic_div_u32_t result; - result.magic = magic; - result.shift = shift; - return result; - }(batch); - rocm::split_tbe_backward_hip_kernel_{{kdesc}}< - rocm::{{optimizer}}_optimizer_t, - rocm::{{optimizer}}_kernel_arg_t, - emb_t, - cache_t, - grad_t, - index_t, - BLOCK_SIZE, - embedding_dim, - segment_prefetch, - segment_unroll, - segment_split, - is_weighted>(p_output_grad, - p_emb_table, - p_hash_size_cumsum, - p_sorted_linear_indices_run, - p_sorted_linear_indices_cumulative_run_lengths, - p_sorted_linear_indices_num_runs, - {%- if not nobag %} - info_B_num_bits, - info_B_mask, - {%- endif %} - p_sorted_infos, - batch_mdiv, - max_segment_length_per_warp, - emb_dim, - batch, - num_rows, - T, - opt_karg - {%- if weighted %} - , p_indice_weights_sorted - {%- endif %}); -} - -{%- macro hip_template_instantiation( - emb_type, - grad_type, - cache_type, - index_type, - kFixedMaxVecsPerThread, - kThreadGroupSize, - kUseVecBlocking, - kEmbeddingDim, - kWeighDecayMode - ) -%} -template __global__ __launch_bounds__(kBackwardMaxThreads) void -hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1 -< {{ emb_type }}, - {{ grad_type }}, - {{ cache_type }}, - {{ index_type }}, - {{ kFixedMaxVecsPerThread }}, - {{ kThreadGroupSize }}, - {{ kUseVecBlocking }}, - {{ kEmbeddingDim }}, - {{ kWeighDecayMode }} -> ( - const pta::PackedTensorAccessor64<{{ grad_type }}, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> grad_output, - {%- if optimizer != "none" %} - pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights, - {%- if not dense %} - pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights, - pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights, - const pta::PackedTensorAccessor32 weights_placements, - {%- endif %} - {%- endif %} - const pta::PackedTensorAccessor32 
weights_offsets, - {%- if not nobag or is_index_select %} - const pta::PackedTensorAccessor32 D_offsets, - {%- else %} - int64_t D, - {%- endif %} - const pta::PackedTensorAccessor32 hash_size_cumsum, - const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run, - const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, - {%- if not nobag %} - const pta::PackedTensorAccessor32 sorted_infos, - {%- else %} - const pta::PackedTensorAccessor32 sorted_infos, - {%- endif %} - {%- if not dense %} - const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, - const bool use_uniq_cache_locations, - const pta::PackedTensorAccessor32 table_unique_indices_offsets, - {%- endif %} - {%- if weighted %} - const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, - {%- endif %} - const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, - int32_t max_segment_length_per_warp, - {%- if not dense and optimizer != "none" %} - bool stochastic_rounding, - at::PhiloxCudaState stochastic_rounding_philox_args, - {%- else %} - pta::PackedTensorAccessor64< {{ emb_type }}, 1, at::RestrictPtrTraits> grad_dev_weights, - {%- endif %} // if not dense and optimizer != "none" - {%- if not nobag and vbe %} - const pta::PackedTensorAccessor32 B_offsets, - const pta::PackedTensorAccessor32 row_output_offsets, - {%- endif %} - {%- if not nobag %} - const int32_t info_B_num_bits, - const uint32_t info_B_mask, - {%- endif %} - const int32_t max_D, - const int32_t max_vecs_per_thread, - {%- if is_index_select %} - const at::PackedTensorAccessor32 grad_offsets, - const bool permute_output_dim_0_1 - {%- else %} - {{ args.split_kernel_args_no_defaults | replace_pta_namespace() | join(",\n ") | replace("cache_t", cache_type) }} - {%- endif %} -); -{%- endmacro %} - -{%- macro hip_bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %} - {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %} - {%- for emb_type in ['float', 'at::Half'] %} - {%- for cache_type in ['float', 'at::Half'] %} - {%- for index_type in ['int32_t', 'int64_t'] %} - {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %} - {%- for kWeighDecayMode in [0, 1, 2] %} - {{ hip_template_instantiation( - emb_type, - grad_type, - cache_type, - index_type, - kFixedMaxVecsPerThread, - kThreadGroupSize, - kUseVecBlocking, - kEmbeddingDim, - kWeighDecayMode - ) - }} - {%- endfor %} - {%- endfor %} - {%- endfor %} - {%- endfor %} - {%- endfor %} - {%- endfor %} -{%- endmacro %} - -{%- macro hip_instantiate_templates(use_subwarp_shuffle) %} -{%- for (kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) - in get_max_vecs_template_configs( - items_per_warp, - fixed_max_vecs_per_thread["backward"], - use_subwarp_shuffle, - use_vec_blocking=True, - ) -%} - {{ - hip_bulk_template_instantiations( - kFixedMaxVecsPerThread, - kThreadGroupSize, - kUseVecBlocking, - ) - }} -{%- endfor %} -{%- endmacro %} - -//////////////////////////////////////////////////////////////////////////////// -#ifdef FBGEMM_USE_SUBWARP_SHUFFLE -//////////////////////////////////////////////////////////////////////////////// - -{#- /* - Explicitly instantiate kernels for the FBGEMM_USE_SUBWARP_SHUFFLE case - Please see get_max_vecs_template_configs in - codegen/embedding_common_code_generator.py for more details -*/ #} - -{{ hip_instantiate_templates(use_subwarp_shuffle=True) }} - -//////////////////////////////////////////////////////////////////////////////// 
-#else -//////////////////////////////////////////////////////////////////////////////// - -{#- /* - Explicitly instantiate kernels for the non-FBGEMM_USE_SUBWARP_SHUFFLE case - Please see get_max_vecs_template_configs in - codegen/embedding_common_code_generator.py for more details -*/ #} - -{{ hip_instantiate_templates(use_subwarp_shuffle=False) }} - -//////////////////////////////////////////////////////////////////////////////// -#endif -//////////////////////////////////////////////////////////////////////////////// {%- endif %} // clang-format on diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp index def21bd39d..6b3d5604d1 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp @@ -72,9 +72,6 @@ Tensor {{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc {%- else %} const c10::SymInt D, {%- endif %} - {%- if not nobag and not is_index_select %} - const bool mixed_D, - {%- endif %} const Tensor& hash_size_cumsum, const int64_t total_hash_size_bits, const Tensor& indices, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 17365822e8..736131ef4c 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -26,10 +26,6 @@ #include "fbgemm_gpu/split_embeddings_utils.cuh" #include "fbgemm_gpu/utils/ops_utils.h" -{%- if is_rocm %} -#include "fbgemm_gpu/rocm/cdna_guard.h" -{%- endif %} - using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -218,78 +214,6 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- endif %} ); -{%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - and not is_gwd_kernel and not vbe and not ssd %} -#include "fbgemm_gpu/rocm/split_embeddings_common.h" -template < - typename emb_t, - typename grad_t, - typename cache_t, - typename index_t, - int32_t kFixedMaxVecsPerThread, - int32_t kThreadGroupSize, - bool kUseVecBlocking, - int32_t embedding_dim, - int32_t weight_decay_mode_v> -__global__ void -hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1( - const pta::PackedTensorAccessor64 grad_output, - {%- if optimizer != "none" %} - pta::PackedTensorAccessor64 dev_weights, - {%- if not dense %} - pta::PackedTensorAccessor64 uvm_weights, - pta::PackedTensorAccessor64 lxu_cache_weights, - const pta::PackedTensorAccessor32 weights_placements, - {%- endif %} - {%- endif %} - const pta::PackedTensorAccessor32 weights_offsets, - {%- if not nobag or is_index_select %} - const pta::PackedTensorAccessor32 D_offsets, - {%- else %} - int64_t D, - {%- endif %} - const pta::PackedTensorAccessor32 hash_size_cumsum, - const pta::PackedTensorAccessor32 sorted_linear_indices_run, - const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, - {%- if not nobag %} - const pta::PackedTensorAccessor32 sorted_infos, - {%- else %} - const pta::PackedTensorAccessor32 sorted_infos, - {%- endif %} - {%- if not dense %} - const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, - const bool use_uniq_cache_locations, - const pta::PackedTensorAccessor32 
table_unique_indices_offsets, - {%- endif %} - {%- if weighted %} - const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, - {%- endif %} - const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, - int32_t max_segment_length_per_warp, - {%- if not dense and optimizer != "none" %} - bool stochastic_rounding, - at::PhiloxCudaState stochastic_rounding_philox_args, - {%- else %} - pta::PackedTensorAccessor64 grad_dev_weights, - {%- endif %} // if not dense and optimizer != "none" - {%- if not nobag and vbe %} - const pta::PackedTensorAccessor32 B_offsets, - const pta::PackedTensorAccessor32 row_output_offsets, - {%- endif %} - {%- if not nobag %} - const int32_t info_B_num_bits, - const uint32_t info_B_mask, - {%- endif %} - const int32_t max_D, - const int32_t max_vecs_per_thread, - {%- if is_index_select %} - const at::PackedTensorAccessor32 grad_offsets, - const bool permute_output_dim_0_1 - {%- else %} - {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} - {%- endif %} -); -{%- endif %} {% if is_index_select %} namespace index_select { {% else %} @@ -531,9 +455,6 @@ Tensor {{ embedding_cuda_op }}( {%- else %} const c10::SymInt D_, {%- endif %} - {%- if not nobag and not is_index_select %} - const bool mixed_D, - {%- endif %} const Tensor& hash_size_cumsum, const int64_t total_hash_size_bits, const Tensor& indices, @@ -861,17 +782,6 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} - {%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - and not is_gwd_kernel and not vbe and not ssd %} - {%- set hip_kernel = "hip_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( - ndesc, - optimizer, - wdesc, - vdesc, - ) - %} - {%- endif %} - AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_2", [&] { DISPATCH_EMB_GRAD_CACHE_TYPES( dev_weights.scalar_type(), @@ -1171,7 +1081,7 @@ Tensor {{ embedding_cuda_op }}( desc_suffix, ) %} - auto backward_warp_per_row_kernel = + const auto backward_warp_per_row_kernel = {{ warp_kernel }} (), segments_per_workgroup); - blockSize = dim3(256); - warp_per_row_smem_bytes = 0; - - backward_warp_per_row_kernel = - {{ hip_kernel }} - ; - } - {%- endfor %} - {%- endfor %} - } - {%- endif %} -#endif - - #ifdef FBGEMM_GPU_MEMCHECK const auto func_name4 = "{{ warp_kernel }}"; #endif backward_warp_per_row_kernel <<>>( grad_output_accessor, @@ -1363,9 +1235,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- else %} " SymInt D, " {%- endif %} - {%- if not nobag and not is_index_select %} - " bool mixed_D, " - {%- endif %} " Tensor hash_size_cumsum, " " int total_hash_size_bits, " " Tensor indices, " diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip deleted file mode 100644 index 2fcbba395e..0000000000 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ /dev/null @@ -1,462 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2016 - 2024 Advanced Micro Devices, Inc. All rights reserved. 
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - * - ******************************************************************************/ - -#include -#include - -#include "fbgemm_gpu/rocm/split_embeddings_common.h" - -namespace fbgemm_gpu::rocm { -template -struct rowwise_adagrad_optimizer_t -{ - __device__ rowwise_adagrad_optimizer_t(const rowwise_adagrad_kernel_arg_t& karg_) - : karg(karg_) - { - } - - template - __device__ void update(cache_t* acc, emb_t* weight, uint32_t row_index) - { - if constexpr(segment_split == 0) - { - cache_t * p_momentum = reinterpret_cast(karg.p_momentum); - cache_t momentum = p_momentum[row_index]; // should be s_load - // compute per row square sum - cache_t local_sum_squre = .0f; - if constexpr(weight_decay_mode == 1) - { -#pragma unroll - for(auto i = 0; i < thread_length; i++) - { - cache_t w = static_cast(weight[i]); - cache_t a = acc[i] + w * karg.weight_decay; - local_sum_squre += a * a; - } - } - else - { -#pragma unroll - for(auto i = 0; i < thread_length; i++) - { - cache_t a = acc[i]; - local_sum_squre += a * a; - } - } - - cache_t avg_square = - wave_reduce, cache_t, AMDGCN_WAVE_SIZE>(local_sum_squre) / - embedding_dim; - - cache_t momentum_new = momentum + avg_square; - - cache_t multiplier = karg.learning_rate / (sqrtf(momentum_new) + karg.eps); - cache_t correction; - - if constexpr(weight_decay_mode == 1) - { - correction = 1.0 - multiplier * karg.weight_decay; - } - else if constexpr(weight_decay_mode == 2) - { - correction = 1.0 - karg.learning_rate * karg.weight_decay; - } - else - { - correction = 1.0; - } - -// update new weight value -#pragma unroll - for(auto i = 0; i < thread_length; i++) - { - cache_t w = static_cast(weight[i]); - cache_t a = acc[i]; - w = correction * w - multiplier * a; - weight[i] = static_cast(w); - } - - p_momentum[row_index] = momentum_new; - } - } - - rowwise_adagrad_kernel_arg_t karg; -}; - -template -__device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( - const grad_t* p_output_grad, - emb_t* p_emb_table, - const int64_t* p_hash_size_cumsum, - const index_t* p_sorted_linear_indices_run, - const int32_t* p_sorted_linear_indices_cumulative_run_lengths, - const int32_t* p_sorted_linear_indices_num_runs, - {%- if not nobag %} - const int32_t info_B_num_bits, - const uint32_t info_B_mask, - {%- endif %} - {%- if not nobag %} - const int32_t* p_sorted_infos, - {%- else %} - const int64_t* p_sorted_infos, - {%- endif %} - magic_div_u32_t batch_mdiv, - uint32_t max_segment_length_per_warp, 
- uint32_t emb_dim, - uint32_t batch, - uint32_t num_rows, - uint32_t num_tables, - optimizer_karg_t opt_karg, - const float * p_sorted_indice_weights = nullptr) -{ - constexpr uint32_t dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - constexpr uint32_t waves_per_block = block_size / AMDGCN_WAVE_SIZE; - constexpr uint32_t length_mask = ~(segment_unroll - 1); - const uint32_t wave_id = __builtin_amdgcn_readfirstlane(threadIdx.x / AMDGCN_WAVE_SIZE); - const uint32_t lane_id = threadIdx.x % AMDGCN_WAVE_SIZE; - const uint32_t run_id = wave_id + blockIdx.x * waves_per_block; - - if(run_id >= p_sorted_linear_indices_num_runs[0]) - { - return; - } - - const auto linear_index = p_sorted_linear_indices_run[run_id]; - - const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id]; - const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1]; - - {%- if nobag %} - const auto info_0 = p_sorted_infos[segment_start]; - int32_t t_0 = info_0 % num_tables; - {%- else %} - const auto info_0 = reinterpret_cast(&p_sorted_infos[0])[segment_start]; - const auto t_0 = info_0 >> info_B_num_bits; - {%- endif %} - int64_t hash_size = p_hash_size_cumsum[t_0]; - - const int64_t emb_idx = linear_index - hash_size; - - p_emb_table += hash_size * emb_dim; - opt_karg.p_momentum = reinterpret_cast(reinterpret_cast(opt_karg.p_momentum) + hash_size); - - const int32_t segment_length = segment_end - segment_start; - - if(segment_length >= max_segment_length_per_warp) - return; - - const int32_t segment_length_mod = segment_length & length_mask; - - cache_t grad_acc[dword_per_row]; - int32_t infos[segment_unroll]; - grad_t grad_data[dword_per_row * segment_prefetch]; - emb_t emb_data[dword_per_row]; - float indice_weights[segment_unroll]; - - #pragma unroll - for(int i=0; i < dword_per_row; i++) - { - grad_acc[i] = .0f; - } - - int itr = 0; - if(segment_length_mod == 0) - goto L_tail_grad_acc; - - if constexpr (!weighted) { - #pragma unroll - for(int i = 0; i < segment_unroll; i++) - { - infos[i] = p_sorted_infos[segment_start + i]; - } - } else { - for(int i = 0; i < segment_unroll; i++) - { - infos[i] = p_sorted_infos[segment_start + i]; - indice_weights[i] = p_sorted_indice_weights[segment_start + i]; - } - } - - itr += segment_unroll; - p_sorted_infos += segment_unroll; - - if constexpr (weighted) { - p_sorted_indice_weights += segment_unroll; - } - - uint32_t bag_index; - uint32_t table_index; - - // LOOP - for(; itr < segment_length_mod; itr += segment_unroll) - { - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} - table_index = infos[0] >> info_B_num_bits; - bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} - table_index = infos[1] >> info_B_num_bits; - bag_index = infos[1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - if constexpr (!weighted){ - #pragma unroll - for(int j = 2; j < segment_unroll; j += 2) - { - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} - table_index = infos[j] >> 
info_B_num_bits; - bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} - table_index = infos[j + 1] >> info_B_num_bits; - bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - } - - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[0], lane_id); - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[dword_per_row], lane_id); - - #pragma unroll - for(int i = 0; i < segment_unroll; i++) - { - infos[i] = p_sorted_infos[segment_start + i]; - } - p_sorted_infos += segment_unroll; - - - } else { - #pragma unroll - for(int j = 2; j < segment_unroll; j += 2) - { - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} - table_index = infos[j] >> info_B_num_bits; - bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} - table_index = infos[j + 1] >> info_B_num_bits; - bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - } - - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[0], lane_id, indice_weights[segment_unroll-2]); - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[segment_unroll-1]); - - #pragma unroll - for(int i = 0; i < segment_unroll; i++) - { - infos[i] = p_sorted_infos[segment_start + i]; - indice_weights[i] = p_sorted_indice_weights[segment_start + i]; - } - p_sorted_infos += segment_unroll; - p_sorted_indice_weights += segment_unroll; - } - } - - // LAST - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} - table_index = infos[0] >> info_B_num_bits; - bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} - table_index = infos[1] >> info_B_num_bits; - bag_index = infos[1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - - if constexpr (!weighted) { - #pragma unroll - for(int j = 2; j < segment_unroll; j += 2) - { - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} - table_index = infos[j] >> info_B_num_bits; - bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( - 
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} - table_index = infos[j + 1] >> info_B_num_bits; - bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - } - - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[0], lane_id); - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[dword_per_row], lane_id); - } else { - #pragma unroll - for(int j = 2; j < segment_unroll; j += 2) - { - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} - table_index = infos[j] >> info_B_num_bits; - bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} - table_index = infos[j + 1] >> info_B_num_bits; - bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - } - - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[0], lane_id, indice_weights[segment_unroll-2]); - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[segment_unroll-1]); - } - -L_tail_grad_acc: - if(segment_length & (segment_unroll - 1)) - { - if constexpr (!weighted){ - // last, load one by one - do - { - infos[0] = p_sorted_infos[segment_start]; - p_sorted_infos++; - - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} - table_index = infos[0] >> info_B_num_bits; - bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[0], lane_id); - - itr++; - } while(itr < segment_length); - } else { - do - { - infos[0] = p_sorted_infos[segment_start]; - indice_weights[0] = p_sorted_indice_weights[segment_start]; - p_sorted_infos++; - p_sorted_indice_weights++; - - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} - table_index = infos[0] >> info_B_num_bits; - bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - accumulate_row_per_warp::run( - &grad_acc[0], &grad_data[0], lane_id, indice_weights[0]); - - itr++; - } while(itr < segment_length); - } - } - - // load the old emb weight data - load_row_per_warp::run( - &emb_data[0], emb_idx, p_emb_table, lane_id); - optimizer_t optimizer(opt_karg); - optimizer.template update(grad_acc, emb_data, emb_idx); - - store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); -} -} // namespace fbgemm_gpu::rocm diff 
--git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 65a21c99c5..0ec166892c 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -3677,7 +3677,6 @@ def __init__( torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32), ) assert self.D_offsets.numel() == T + 1 - # Required for VBE self.register_buffer( "feature_dims", diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 1d19e12a82..4de09dfed0 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -181,6 +181,7 @@ def __init__( "D_offsets", torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32), ) + assert self.D_offsets.numel() == T + 1 hash_size_cumsum = [0] + list(itertools.accumulate(rows)) if hash_size_cumsum[-1] == 0: diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h deleted file mode 100644 index b55fd72fce..0000000000 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h +++ /dev/null @@ -1,51 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2016 - 2022 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. 
- * - ******************************************************************************/ -#pragma once - -#include -#include -#include - -#define HIP_CHECK(c) \ - { \ - if (c != hipSuccess) { \ - printf("HIP Error : %s", hipGetErrorString(c)); \ - printf(" %s %d\n", __FILE__, __LINE__); \ - exit(c); \ - } \ - } - -namespace fbgemm_gpu::rocm { - -[[nodiscard]] inline bool is_supported_cdna() { - const std::set supported_archs{"gfx942", "gfx90a"}; - int device_id = 0; - HIP_CHECK(hipGetDevice(&device_id)); - hipDeviceProp_t dev_props; - HIP_CHECK(hipGetDeviceProperties(&dev_props, device_id)); - std::string gcn_arch = dev_props.gcnArchName; - gcn_arch = gcn_arch.substr(0, gcn_arch.find(":")); - return supported_archs.contains(gcn_arch); -} - -} // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h deleted file mode 100644 index b3a56c4b52..0000000000 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ /dev/null @@ -1,550 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2016 - 2022 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. 
- * - ******************************************************************************/ -#pragma once -#include -#include -#include - -/******************************************************************************/ -typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); -typedef float floatx2_t __attribute__((ext_vector_type(2))); -#define AMDGCN_BUFFER_RES_3 0x00027000 -#define AMDGCN_WAVE_SIZE 64 -#define THREADS_PER_ROW 64 -#define BLOCK_SIZE 256 - -namespace fbgemm_gpu::rocm { -template -union amdgcn_buffer_resource { - // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions - int32x4_t content; - struct { - T* address; - int32_t range; - int32_t config; - }; -}; - -template -__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) { - amdgcn_buffer_resource buffer_resource; - buffer_resource.address = const_cast(addr); - buffer_resource.range = 0xffffffff; - buffer_resource.config = AMDGCN_BUFFER_RES_3; // for gfx9 - - return buffer_resource.content; -} - -// buffer load fp32 -__device__ half llvm_amdgcn_raw_buffer_load_fp16( - int32x4_t srsrc, - int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); - -__device__ float llvm_amdgcn_raw_buffer_load_fp32( - int32x4_t srsrc, - int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); - -__device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( - int32x4_t srsrc, - int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); - -__device__ void llvm_amdgcn_raw_buffer_store_fp32( - float vdata, - int32x4_t rsrc, - int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); - -__device__ void llvm_amdgcn_raw_buffer_store_fp32x2( - floatx2_t vdata, - int32x4_t rsrc, - int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); - -/******************************************************************************/ - -template -struct load_row_per_warp { - static __device__ void run( - emb_t* emb_data, - index_t row_index, - const emb_t* p_emb_table, - int lane_id) {} -}; - -template -struct load_row_per_warp { - static constexpr int dword_per_row = - (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, - int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim); -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - if constexpr (embedding_dim == 160) { - if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { - emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); - } else { - emb_data[i] = 0.f; - } - } else { - emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); - } - } - } -}; - -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 64); - emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half), 0, 0); - } -}; - -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table 
+ row_index * 128); - *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); - } -}; - -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 192); - *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); - if ((lane_id + 128) % 192 < 160) { - emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half), 0, 0); - } else { - emb_data[2] = __float2half(0.0); - } - } -}; - -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 192); - *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); - emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half), 0, 0); - } -}; - -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 256); - *reinterpret_cast(&emb_data[0]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[2]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); - } -}; - -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 512); - *reinterpret_cast(&emb_data[0]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[2]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[4]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64 * 2) * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[6]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64 * 3) * sizeof(half2), 0, 0); - } -}; - -template < - typename emb_t, - int32_t embedding_dim, - typename output_t, - bool weighted> -struct accumulate_row_per_warp { - static constexpr int dword_per_row = - (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void - run(output_t* acc, emb_t* emb_data, int lane_id, float row_weight = 1.0) { - if constexpr (!weighted) { -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - acc[i] += static_cast(emb_data[i]); - } - } else { -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - acc[i] += static_cast((float)emb_data[i] * row_weight); - } - } - } -}; - -template -struct store_row_per_warp { - static constexpr int dword_per_row = - (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run(output_t* acc, output_t* p_output, int lane_id) { - if constexpr (embedding_dim == 160) { - for (int i = 0; i < dword_per_row; i++) { - if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { - p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; - } - } - } else { -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; - } - } - } -}; - -template <> -struct store_row_per_warp { - static 
__device__ void run(float* acc, float* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - } -}; - -template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - if ((lane_id + 128) % 192 < 160) { - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); - } - } -}; - -template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); - } -}; - -template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(&acc[2]), - out_res, - (lane_id + 64) * sizeof(floatx2_t), - 0, - 0); - } -}; - -// Helper function to pack fp16 and fp32 into int to further pass -// into mov_dpp and readfirstlane() -template - requires( - (sizeof(to_t) == 4 || sizeof(to_t) == 2) && - (sizeof(from_t) == 4 || sizeof(from_t) == 2)) -__device__ to_t pack(const from_t& v) { - to_t result = 0; - if constexpr (sizeof(to_t) == sizeof(from_t)) { - result = __builtin_bit_cast(to_t, v); - return result; - } - - memcpy(&result, &v, 2); - - return result; -} - -namespace reduce_op { -struct sum {}; -struct sub {}; -struct mul {}; -struct div {}; -} // namespace reduce_op - -template -struct reduce_op_sum_t { - __device__ data_t operator()(const data_t& a, const data_t& b) { - return a + b; - } -}; - -#define DPP_REDUCE(OP, TYPE) \ - __asm__ volatile( \ - "v_nop\n" \ - "v_nop\n" \ - "v_nop\n" \ - "v_nop\n" \ - "v_" #OP "_" #TYPE \ - "_dpp %0 %0 %0 quad_perm:[1,0,3,2]\n" \ - "v_nop\n" \ - "v_nop\n" \ - "v_" #OP "_" #TYPE \ - "_dpp %0 %0 %0 quad_perm:[2,3,0,1]\n" \ - "v_nop\n" \ - "v_nop\n" \ - "v_" #OP "_" #TYPE \ - "_dpp %0 %0 %0 row_shr:4\n" \ - "v_nop\n" \ - "v_nop\n" \ - "v_" #OP "_" #TYPE \ - "_dpp %0 %0 %0 row_shr:8\n" \ - "v_nop\n" \ - "v_nop\n" \ - "v_" #OP "_" #TYPE \ - "_dpp %0 %0 %0 row_bcast:15\n" \ - "v_nop\n" \ - "v_nop\n" \ - "v_" #OP "_" #TYPE \ - "_dpp %0 %0 %0 row_bcast:31\n" \ - "v_nop\n" \ - "v_nop\n" \ - : "=v"(result) \ - : "0"(result)) - -#define DPP_REDUCE_F16_F32(OP) \ - if constexpr (std::is_same_v) { \ - DPP_REDUCE(OP, f32); \ - } \ - \ - if constexpr (std::is_same_v) { \ - DPP_REDUCE(OP, f16); \ - } - -template -__device__ __forceinline__ void generic_dpp_reduction(data_t& result) { - constexpr int row_mask = 0xf; - constexpr int bank_mask = 0xf; - constexpr bool bound_ctrl = false; - - reduce_op_t reduce_op; - - if constexpr (wave_size > 1) { - result = reduce_op( - result, - pack(__builtin_amdgcn_mov_dpp( - pack(result), - 0xb1, - row_mask, - bank_mask, - bound_ctrl))); // quad_perm:[1,0,3,2] - } - if constexpr (wave_size > 2) { - result = reduce_op( - result, - 
pack(__builtin_amdgcn_mov_dpp( - pack(result), - 0x4e, - row_mask, - bank_mask, - bound_ctrl))); // quad_perm:[2,3,0,1] - } - if constexpr (wave_size > 4) { - result = reduce_op( - result, - pack(__builtin_amdgcn_mov_dpp( - pack(result), - 0x114, - row_mask, - bank_mask, - bound_ctrl))); // row_shr:4 - } - if constexpr (wave_size > 8) { - result = reduce_op( - result, - pack(__builtin_amdgcn_mov_dpp( - pack(result), - 0x118, - row_mask, - bank_mask, - bound_ctrl))); // row_shr:8 - } - if constexpr (wave_size > 16) { - result = reduce_op( - result, - pack(__builtin_amdgcn_mov_dpp( - pack(result), - 0x142, - row_mask, - bank_mask, - bound_ctrl))); // row_bcast:15 - } - if constexpr (wave_size > 32) { - result = reduce_op( - result, - pack(__builtin_amdgcn_mov_dpp( - pack(result), - 0x143, - row_mask, - bank_mask, - bound_ctrl))); // row_bcast:31 - } -} - -// Use corresponding assebly instruction for dpp reduction in case -// of trivial operation with an option to use custom operation -template -__device__ __forceinline__ void dpp_reduction(data_t& result) { -#if defined(__gfx942__) || defined(__gfx90a__) - if constexpr (std::is_same_v) { - DPP_REDUCE_F16_F32(add); - return; - } else if constexpr (std::is_same_v) { - DPP_REDUCE_F16_F32(sub); - return; - } else if constexpr (std::is_same_v) { - DPP_REDUCE_F16_F32(mul); - return; - } else if constexpr (std::is_same_v) { - DPP_REDUCE_F16_F32(div); - return; - } else { - generic_dpp_reduction(result); - } -#endif -} - -template -__device__ inline data_t wave_reduce(const data_t& thread_data) { - data_t result = thread_data; - - // now the reduced value is in the last lane of wave - dpp_reduction(result); - return pack( - __builtin_amdgcn_readlane(pack(result), wave_size - 1)); -} - -struct rowwise_adagrad_kernel_arg_t { - void* p_momentum; - float eps; - float learning_rate; - float weight_decay; - int64_t weight_decay_mode; -}; - -typedef struct { - uint32_t magic; - uint32_t shift; // actually 8 bit is enough -} magic_div_u32_t; - -static inline magic_div_u32_t magic_div_u32_gen(uint32_t d) { - assert(d >= 1 && d <= INT32_MAX); - uint8_t shift; - for (shift = 0; shift < 32; shift++) - if ((1U << shift) >= d) - break; - - uint64_t one = 1; - uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; - assert(magic <= 0xffffffffUL); - - magic_div_u32_t result; - result.magic = magic; - result.shift = shift; - return result; -} - -// numer / denom = quotient, reminder -__device__ inline uint32_t magic_div_u32_run( - const magic_div_u32_t& mdiv, - const uint32_t& n) { - uint32_t tmp = __umulhi(n, mdiv.magic); - return (tmp + n) >> mdiv.shift; -} - -__device__ inline void magic_div_u32_run_with_mod( - const magic_div_u32_t& mdiv, - const uint32_t& n, - const uint32_t d, - uint32_t& quo, - uint32_t& rem) { - quo = magic_div_u32_run(mdiv, n); - rem = n - quo * d; -} -} // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/test/tbe/cache/cache_common.py b/fbgemm_gpu/test/tbe/cache/cache_common.py index 48b1df66ed..f744186693 100644 --- a/fbgemm_gpu/test/tbe/cache/cache_common.py +++ b/fbgemm_gpu/test/tbe/cache/cache_common.py @@ -33,12 +33,11 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_unavailable, optests, running_on_github, running_on_rocm + from test_utils import gpu_unavailable, optests, running_on_rocm else: from fbgemm_gpu.test.test_utils import ( # noqa: F401 gpu_unavailable, # noqa: F401 optests, # noqa: F401 - running_on_github, # noqa: F401 running_on_rocm, # noqa: F401 ) diff --git 
diff --git a/fbgemm_gpu/test/tbe/cache/cache_test.py b/fbgemm_gpu/test/tbe/cache/cache_test.py
index a19579bd9a..6250c529aa 100644
--- a/fbgemm_gpu/test/tbe/cache/cache_test.py
+++ b/fbgemm_gpu/test/tbe/cache/cache_test.py
@@ -43,7 +43,6 @@
     generate_cache_tbes,
     gpu_unavailable,
     optests,
-    running_on_github,
     running_on_rocm,
     TestingStatsReporter,
     TestingStatsReporterConfig,
@@ -78,7 +77,6 @@ def _compute_grad_output_shape(
 
     @optests.dontGenerateOpCheckTests("Serial OOM")
     @unittest.skipIf(*gpu_unavailable)
-    @unittest.skipIf(*running_on_github)
     @unittest.skipIf(*running_on_rocm)
     @given(
         T=st.integers(min_value=1, max_value=5),
@@ -452,7 +450,6 @@ def assert_event_not_exist(event_name: str) -> None:
 
     @optests.dontGenerateOpCheckTests("Serial OOM")
     @unittest.skipIf(*gpu_unavailable)
-    @unittest.skipIf(*running_on_github)
     @unittest.skipIf(*running_on_rocm)
     @given(
         T=st.integers(min_value=1, max_value=5),
@@ -481,7 +478,6 @@ def test_cache_prefetch_pipeline(
 
     @optests.dontGenerateOpCheckTests("Serial OOM")
     @unittest.skipIf(*gpu_unavailable)
-    @unittest.skipIf(*running_on_github)
     @unittest.skipIf(*running_on_rocm)
     @given(
         T=st.integers(min_value=1, max_value=5),
@@ -511,7 +507,6 @@ def test_cache_prefetch_pipeline_stream_1(
 
     @optests.dontGenerateOpCheckTests("Serial OOM")
     @unittest.skipIf(*gpu_unavailable)
-    @unittest.skipIf(*running_on_github)
     @unittest.skipIf(*running_on_rocm)
     @given(
         T=st.integers(min_value=1, max_value=5),
@@ -593,7 +588,6 @@ def test_get_prefetch_passes(
         self.assertTrue(torch.equal(torch.full_like(output_tensor, 1), output_tensor))
 
     @unittest.skipIf(*gpu_unavailable)
-    @unittest.skipIf(*running_on_github)
     @given(
         L=st.integers(min_value=0, max_value=16),
         H=st.integers(min_value=512, max_value=1024),
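Note on the decorator removals in this file: helpers such as gpu_unavailable and running_on_rocm are (condition, reason) tuples, so @unittest.skipIf(*helper) star-unpacks into skipIf's two arguments; deleting the running_on_github line drops that one gate while the remaining gates keep working unchanged. A minimal sketch of the convention, where this gpu_unavailable definition is an illustrative stand-in rather than the real one from test_utils:

import unittest

import torch

# Stand-in for the real helper: a (condition, reason) pair shaped to
# star-unpack into unittest.skipIf.
gpu_unavailable = (
    not torch.cuda.is_available(),
    "CUDA/ROCm device is unavailable",
)

class ExampleTest(unittest.TestCase):
    @unittest.skipIf(*gpu_unavailable)
    def test_needs_gpu(self) -> None:
        self.assertTrue(torch.cuda.is_available())

if __name__ == "__main__":
    unittest.main()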
diff --git a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
index d5bdf4eb66..bbfe2ab326 100644
--- a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
+++ b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
@@ -59,7 +59,6 @@
         additional_decorators,
         gpu_unavailable,
         optests,
-        skipIfNotRocm,
         TEST_WITH_ROCM,
         use_cpu_strategy,
     )
@@ -68,7 +67,6 @@
         additional_decorators,
         gpu_unavailable,
         optests,
-        skipIfNotRocm,
         TEST_WITH_ROCM,
         use_cpu_strategy,
     )
@@ -1174,80 +1172,6 @@ def test_backward_optimizers_adagrad(  # noqa C901
             weight_decay_mode,
         )
 
-    @given(
-        T=st.integers(min_value=1, max_value=5),
-        D=st.sampled_from([16, 32, 40, 48, 64]),
-        B=st.integers(min_value=1, max_value=128),
-        log_E=st.integers(min_value=3, max_value=5),
-        L=st.integers(min_value=2, max_value=20),
-        weighted=st.booleans(),
-        mixed=st.just(False),
-        mixed_B=st.just(False),
-        optimizer=st.sampled_from(
-            [
-                OptimType.EXACT_ROWWISE_ADAGRAD,
-            ]
-        ),
-        long_segments=st.booleans(),
-        pooling_mode=st.sampled_from(
-            [
-                PoolingMode.SUM,
-            ]
-        ),
-        use_cpu=st.just(False),
-        weight_decay_mode=st.sampled_from(
-            [
-                WeightDecayMode.NONE,
-                WeightDecayMode.L2,
-                WeightDecayMode.DECOUPLE,
-            ]
-        ),
-    )
-    @settings(
-        verbosity=VERBOSITY,
-        max_examples=MAX_EXAMPLES_LONG_RUNNING,
-        deadline=None,
-        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
-    )
-    @unittest.skipIf(*gpu_unavailable)
-    @skipIfNotRocm("Test only evaluates ROCm optimized kernels")
-    def test_new_bwd_kernel(  # noqa C901
-        self,
-        T: int,
-        D: int,
-        B: int,
-        log_E: int,
-        L: int,
-        weighted: bool,
-        mixed: bool,
-        mixed_B: bool,
-        optimizer: OptimType,
-        long_segments: bool,
-        pooling_mode: PoolingMode,
-        use_cpu: bool,
-        weight_decay_mode: WeightDecayMode,
-    ) -> None:
-        if (
-            pooling_mode == PoolingMode.NONE
-            or optimizer != OptimType.EXACT_ROWWISE_ADAGRAD
-        ):
-            mixed_B = False
-        self.execute_backward_optimizers_(
-            T,
-            D,
-            B,
-            log_E,
-            L,
-            weighted,
-            mixed,
-            mixed_B,
-            optimizer,
-            long_segments,
-            pooling_mode,
-            use_cpu,
-            weight_decay_mode,
-        )
-
     @given(
         T=st.integers(min_value=1, max_value=5),
         D=st.integers(min_value=2, max_value=256),
diff --git a/fbgemm_gpu/test/test_utils.py b/fbgemm_gpu/test/test_utils.py
index 522c822e57..6772b1e5e5 100644
--- a/fbgemm_gpu/test/test_utils.py
+++ b/fbgemm_gpu/test/test_utils.py
@@ -260,26 +260,6 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
     return decorator
 
 
-# pyre-fixme[3]: Return annotation cannot be `Any`.
-def skipIfNotRocm(
-    reason: str = "Test currently works only on the ROCm stack",
-) -> Any:
-    # pyre-fixme[3]: Return annotation cannot be `Any`.
-    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-    def decorator(fn: Callable) -> Any:
-        @wraps(fn)
-        # pyre-fixme[3]: Return annotation cannot be `Any`.
-        def wrapper(*args: Any, **kwargs: Any) -> Any:
-            if TEST_WITH_ROCM:
-                fn(*args, **kwargs)
-            else:
-                raise unittest.SkipTest(reason)
-
-        return wrapper
-
-    return decorator
-
-
 # pyre-fixme[3]: Return annotation cannot be `Any`.
 def skipIfRocmLessThan(min_version: int) -> Any:
     # pyre-fixme[3]: Return annotation cannot be `Any`.
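With skipIfNotRocm removed, a test that must still run only on the ROCm stack can fall back to the stock unittest form using the TEST_WITH_ROCM flag that test_utils exports (visible in the imports above). A sketch, with a hypothetical test body:

import unittest

from fbgemm_gpu.test.test_utils import TEST_WITH_ROCM

class RocmOnlyTest(unittest.TestCase):
    # Same gating the removed decorator provided, expressed directly
    # with unittest.skipIf.
    @unittest.skipIf(not TEST_WITH_ROCM, "Test runs only on the ROCm stack")
    def test_rocm_only_path(self) -> None:  # hypothetical test
        self.assertTrue(TEST_WITH_ROCM)

if __name__ == "__main__":
    unittest.main()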