
Commit b235381

[CANN] Optimize RMS_NORM using cache
Signed-off-by: noemotiovon <[email protected]>
1 parent 6424594 commit b235381

File tree

2 files changed: +111 -30 lines changed


ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 108 additions & 30 deletions
@@ -867,6 +867,98 @@ static aclTensor* aclnn_values(ggml_backend_cann_context& ctx, void* buffer,
     return acl_tensor;
 }
 
+/**
+ * @brief Fills a tensor with a scalar value.
+ *
+ * This function fills the destination tensor `acl_dst` with the scalar value
+ * `scalar`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param scalar The scalar value used to fill the tensor.
+ * @param acl_dst The destination tensor to be filled with the scalar value.
+ */
+static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
+                              aclTensor* acl_dst) {
+    auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar);
+    ggml_cann_release_resources(ctx, acl_scalar);
+}
+
+/**
+ * @brief Get or expand cached float32 tensors filled with scalar values.
+ *
+ * This function manages a cache of float32 tensors (zero-filled and one-filled).
+ * If the cache does not exist, it will initialize the cache with a zero tensor
+ * and a one tensor. If the requested tensor size exceeds the current cache
+ * capacity, the cache will be expanded accordingly. The function then returns
+ * an aclTensor created from the cached memory (either zero-filled or one-filled),
+ * depending on the input `value`.
+ *
+ * @param ctx The CANN backend context that manages cache memory.
+ * @param ne The tensor shape array (number of elements in each dimension).
+ * @param nb The stride size for each dimension.
+ * @param dims The number of tensor dimensions.
+ * @param value The scalar value (only supports 0 or 1) used to determine whether
+ *              to return the zero-cache tensor or the one-cache tensor.
+ * @return An aclTensor pointer corresponding to the cached tensor.
+ */
+static aclTensor* get_f32_cache_acl_tensor(ggml_backend_cann_context& ctx,
+                                           int64_t* ne, size_t* nb,
+                                           int64_t dims, int64_t value) {
+    // init cache
+    if (ctx.f32_zero_cache == nullptr) {
+        // zero-cache pool init
+        size_t size = ctx.f32_cache_element * sizeof(float);
+        ACL_CHECK(aclrtMalloc(&ctx.f32_zero_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMemsetAsync(ctx.f32_zero_cache, size, 0, size, ctx.stream()));
+
+        // one-cache pool init
+        int64_t pool_ne[1] = { ctx.f32_cache_element };
+        size_t pool_nb[1] = { sizeof(float) };
+        ACL_CHECK(aclrtMalloc(&ctx.f32_one_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
+        aclTensor* acl_one = ggml_cann_create_tensor(
+            ctx.f32_one_cache, ACL_FLOAT, sizeof(float), pool_ne, pool_nb,
+            1);
+        aclnn_fill_scalar(ctx, 1, acl_one);
+        ggml_cann_release_resources(ctx, acl_one);
+    }
+
+    // Cache expansion
+    int64_t n_element = 1;
+    for (int i = 0; i < dims; i++) {
+        n_element = n_element * ne[i];
+    }
+    if (ctx.f32_cache_element < n_element) {
+        // free old mem
+        aclrtFree(ctx.f32_zero_cache);
+        aclrtFree(ctx.f32_one_cache);
+        // init zero cache
+        ctx.f32_cache_element = n_element;
+        size_t size = n_element * sizeof(float);
+        ACL_CHECK(aclrtMalloc(&ctx.f32_zero_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMemsetAsync(ctx.f32_zero_cache, size, 0, size, ctx.stream()));
+
+        // one-cache pool init
+        int64_t pool_ne[1] = { n_element };
+        size_t pool_nb[1] = { sizeof(float) };
+        ACL_CHECK(aclrtMalloc(&ctx.f32_one_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
+        aclTensor* acl_one = ggml_cann_create_tensor(
+            ctx.f32_one_cache, ACL_FLOAT, sizeof(float), pool_ne, pool_nb,
+            1);
+        aclnn_fill_scalar(ctx, 1, acl_one);
+        ggml_cann_release_resources(ctx, acl_one);
+    }
+
+    void* cache;
+    if (value == 0) {
+        cache = ctx.f32_zero_cache;
+    } else {
+        cache = ctx.f32_one_cache;
+    }
+
+    return ggml_cann_create_tensor(cache, ACL_FLOAT, sizeof(float), ne, nb, dims);
+}
+
 void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src = dst->src[0];
 

@@ -875,20 +967,23 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
-    size_t one_tensor_n_bytes = src->ne[0] * ggml_element_size(src);
-    ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
 
-    aclTensor* acl_gamma = aclnn_values(
-        ctx, one_tensor_allocator.get(), one_tensor_n_bytes, src->ne, 1,
-        ggml_cann_type_mapping(src->type), ggml_element_size(src));
-
-    size_t zero_tensor_n_bytes =
-        src->ne[1] * src->ne[2] * src->ne[3] * ggml_element_size(src);
-    ggml_cann_pool_alloc zero_tensor_allocator(ctx.pool(), zero_tensor_n_bytes);
-    aclTensor* acl_rstd =
-        aclnn_zero(ctx, zero_tensor_allocator.get(), zero_tensor_n_bytes,
-                   src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
-                   ggml_element_size(src));
+    // build gamma, one...
+    size_t acl_gamma_nb[GGML_MAX_DIMS];
+    acl_gamma_nb[0] = sizeof(float);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
+    }
+    aclTensor* acl_gamma = get_f32_cache_acl_tensor(ctx, src->ne, acl_gamma_nb, 1, 1);
+
+    // build rstd, zero...
+    size_t acl_rstd_nb[GGML_MAX_DIMS];
+    acl_rstd_nb[0] = sizeof(float);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        acl_rstd_nb[i] = acl_rstd_nb[i - 1] * src->ne[i - 1];
+    }
+    aclTensor* acl_rstd = get_f32_cache_acl_tensor(ctx, src->ne, acl_rstd_nb, GGML_MAX_DIMS, 0);
+
     GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd);
     ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd);
 }
@@ -1277,23 +1372,6 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
                                      tmp_permute_tensor, tmp_mul_tensor, acl_dst);
 }
 
-/**
- * @brief Fills a tensor with a scalar value.
- *
- * This function fills the destination tensor `acl_dst` with the scalar value
- * `scalar`.
- *
- * @param ctx The context for the CANN backend operations.
- * @param scalar The scalar value used to fill the tensor.
- * @param acl_dst The destination tensor to be filled with the scalar value.
- */
-static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
-                              aclTensor* acl_dst) {
-    auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
-    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar);
-    ggml_cann_release_resources(ctx, acl_scalar);
-}
-
 /**
  * @brief Raises each element of a tensor to the power of the corresponding
  * element in another tensor.
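
The get_f32_cache_acl_tensor helper introduced above keeps two persistent device buffers in the backend context, one zero-filled and one one-filled, and grows them on demand instead of allocating and filling pool tensors on every RMS_NORM call. The following is a minimal host-side C++ sketch of that grow-only pattern, not the backend code itself: the type F32ConstCache and its members are hypothetical, and std::vector stands in for the aclrtMalloc'd device memory, the async memset, and the scalar fill.

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical host-side stand-in for the f32 zero/one cache kept in
// ggml_backend_cann_context (f32_zero_cache, f32_one_cache, f32_cache_element).
struct F32ConstCache {
    std::vector<float> zeros;               // stands in for ctx.f32_zero_cache
    std::vector<float> ones;                // stands in for ctx.f32_one_cache
    int64_t capacity_elems = 1024 * 1024;   // same default as f32_cache_element

    // Return a buffer of at least n_elems floats, all equal to value (0 or 1).
    // Both pools are built lazily and grown when a request exceeds the current
    // capacity, mirroring the "Cache expansion" branch in the diff above.
    const float* get(int64_t n_elems, int value) {
        if (zeros.empty() || n_elems > capacity_elems) {
            if (n_elems > capacity_elems) capacity_elems = n_elems;
            zeros.assign(capacity_elems, 0.0f);   // like aclrtMemsetAsync(..., 0, ...)
            ones.assign(capacity_elems, 1.0f);    // like aclnn_fill_scalar(ctx, 1, ...)
        }
        return value == 0 ? zeros.data() : ones.data();
    }
};

int main() {
    F32ConstCache cache;
    const float* gamma = cache.get(4096, 1);      // ones, like acl_gamma in RMS_NORM
    const float* rstd  = cache.get(8 * 4096, 0);  // zeros, like acl_rstd
    printf("gamma[0]=%g rstd[0]=%g\n", gamma[0], rstd[0]);
    return 0;
}

As in the commit, growth replaces both pools wholesale, so callers are expected to build a fresh tensor view from the cache on every use rather than holding on to earlier pointers.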

ggml/src/ggml-cann/common.h

Lines changed: 3 additions & 0 deletions
@@ -375,6 +375,9 @@ struct ggml_backend_cann_context {
     cann_task_queue task_queue;
     bool async_mode;
     bool support_set_rows;
+    void* f32_zero_cache = nullptr;
+    void* f32_one_cache = nullptr;
+    int64_t f32_cache_element = 1024 * 1024;
 
     aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
 
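The three fields added to ggml_backend_cann_context above give each context its own cache pointers plus a starting capacity of 1024 * 1024 floats per pool, i.e. 4 MiB for the zero cache and 4 MiB for the one cache before any on-demand growth. A quick back-of-the-envelope check of that footprint (plain C++; the variable names are illustrative only):

#include <cstdint>
#include <cstdio>

int main() {
    // Default capacity from common.h: f32_cache_element = 1024 * 1024 floats.
    const int64_t cache_elems    = 1024 * 1024;
    const int64_t bytes_per_pool = cache_elems * (int64_t) sizeof(float);  // 4 MiB
    // Two pools are kept per context: one zero-filled, one one-filled.
    printf("per pool: %lld bytes (%.1f MiB), both pools: %.1f MiB\n",
           (long long) bytes_per_pool,
           bytes_per_pool / (1024.0 * 1024.0),
           2.0 * bytes_per_pool / (1024.0 * 1024.0));
    return 0;
}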