@@ -867,6 +867,98 @@ static aclTensor* aclnn_values(ggml_backend_cann_context& ctx, void* buffer,
    return acl_tensor;
}

+ /**
+  * @brief Fills a tensor with a scalar value.
+  *
+  * This function fills the destination tensor `acl_dst` with the scalar value
+  * `scalar`.
+  *
+  * @param ctx The context for the CANN backend operations.
+  * @param scalar The scalar value used to fill the tensor.
+  * @param acl_dst The destination tensor to be filled with the scalar value.
+  */
+ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
+                               aclTensor* acl_dst) {
+     auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
+     GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar);
+     ggml_cann_release_resources(ctx, acl_scalar);
+ }
+
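For context, here is a minimal usage sketch of the helper above. It is not part of this patch; the buffer name, element count, and shape are illustrative assumptions, and `ctx` stands for a live backend context:

    // Fill a hypothetical 1-D f32 tensor of 512 elements with 1.0f.
    int64_t ne[1] = { 512 };
    size_t nb[1] = { sizeof(float) };
    void* buf = nullptr;  // illustrative device buffer, freed elsewhere
    ACL_CHECK(aclrtMalloc(&buf, 512 * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
    aclTensor* t = ggml_cann_create_tensor(buf, ACL_FLOAT, sizeof(float), ne, nb, 1);
    aclnn_fill_scalar(ctx, 1.0f, t);
    ggml_cann_release_resources(ctx, t);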
+ /**
+  * @brief Get or expand cached float32 tensors filled with scalar values.
+  *
+  * This function manages a cache of float32 tensors (zero-filled and
+  * one-filled). If the cache does not exist yet, it is initialized with a
+  * zero tensor and a one tensor. If the requested tensor size exceeds the
+  * current cache capacity, the cache is expanded accordingly. The function
+  * then returns an aclTensor created from the cached memory (either
+  * zero-filled or one-filled), depending on the input `value`.
+  *
+  * @param ctx The CANN backend context that manages cache memory.
+  * @param ne The tensor shape array (number of elements in each dimension).
+  * @param nb The stride size for each dimension.
+  * @param dims The number of tensor dimensions.
+  * @param value The scalar value (only 0 or 1 is supported), used to select
+  *              either the zero-cache tensor or the one-cache tensor.
+  * @return An aclTensor pointer corresponding to the cached tensor.
+  */
+ static aclTensor* get_f32_cache_acl_tensor(ggml_backend_cann_context& ctx,
+                                            int64_t* ne, size_t* nb,
+                                            int64_t dims, int64_t value) {
+     // init cache
+     if (ctx.f32_zero_cache == nullptr) {
+         // zero-cache pool init
+         size_t size = ctx.f32_cache_element * sizeof(float);
+         ACL_CHECK(aclrtMalloc(&ctx.f32_zero_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
+         ACL_CHECK(aclrtMemsetAsync(ctx.f32_zero_cache, size, 0, size, ctx.stream()));
+
+         // one-cache pool init
+         int64_t pool_ne[1] = { ctx.f32_cache_element };
+         size_t pool_nb[1] = { sizeof(float) };
+         ACL_CHECK(aclrtMalloc(&ctx.f32_one_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
+         aclTensor* acl_one = ggml_cann_create_tensor(
+             ctx.f32_one_cache, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
+         aclnn_fill_scalar(ctx, 1, acl_one);
+         ggml_cann_release_resources(ctx, acl_one);
+     }
+
+     // cache expansion
+     int64_t n_element = 1;
+     for (int i = 0; i < dims; i++) {
+         n_element = n_element * ne[i];
+     }
+     if (ctx.f32_cache_element < n_element) {
+         // free old mem
+         aclrtFree(ctx.f32_zero_cache);
+         aclrtFree(ctx.f32_one_cache);
+         // init zero cache
+         ctx.f32_cache_element = n_element;
+         size_t size = n_element * sizeof(float);
+         ACL_CHECK(aclrtMalloc(&ctx.f32_zero_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
+         ACL_CHECK(aclrtMemsetAsync(ctx.f32_zero_cache, size, 0, size, ctx.stream()));
+
+         // one-cache pool init
+         int64_t pool_ne[1] = { n_element };
+         size_t pool_nb[1] = { sizeof(float) };
+         ACL_CHECK(aclrtMalloc(&ctx.f32_one_cache, size, ACL_MEM_MALLOC_HUGE_FIRST));
+         aclTensor* acl_one = ggml_cann_create_tensor(
+             ctx.f32_one_cache, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
+         aclnn_fill_scalar(ctx, 1, acl_one);
+         ggml_cann_release_resources(ctx, acl_one);
+     }
+
+     void* cache;
+     if (value == 0) {
+         cache = ctx.f32_zero_cache;
+     } else {
+         cache = ctx.f32_one_cache;
+     }
+
+     return ggml_cann_create_tensor(cache, ACL_FLOAT, sizeof(float), ne, nb, dims);
+ }
+
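To make the contract concrete, a hypothetical call (the shape values are invented; the stride construction mirrors the one used in `ggml_cann_rms_norm` below):

    // Request a zero-filled 4-D f32 view of 64 x 4 x 2 x 1 = 512 elements;
    // on first use the cache grows to at least 512 floats and is then reused.
    int64_t ne[GGML_MAX_DIMS] = { 64, 4, 2, 1 };
    size_t nb[GGML_MAX_DIMS];
    nb[0] = sizeof(float);
    for (int i = 1; i < GGML_MAX_DIMS; i++) {
        nb[i] = nb[i - 1] * ne[i - 1];
    }
    aclTensor* zeros = get_f32_cache_acl_tensor(ctx, ne, nb, GGML_MAX_DIMS, 0);

Since every caller views the same device buffers, the returned tensors should presumably be treated as read-only.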

void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];
@@ -875,20 +967,23 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
-     size_t one_tensor_n_bytes = src->ne[0] * ggml_element_size(src);
-     ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);

-     aclTensor* acl_gamma = aclnn_values(
-         ctx, one_tensor_allocator.get(), one_tensor_n_bytes, src->ne, 1,
-         ggml_cann_type_mapping(src->type), ggml_element_size(src));
-
-     size_t zero_tensor_n_bytes =
-         src->ne[1] * src->ne[2] * src->ne[3] * ggml_element_size(src);
-     ggml_cann_pool_alloc zero_tensor_allocator(ctx.pool(), zero_tensor_n_bytes);
-     aclTensor* acl_rstd =
-         aclnn_zero(ctx, zero_tensor_allocator.get(), zero_tensor_n_bytes,
-                    src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
-                    ggml_element_size(src));
+     // build gamma: a one-filled f32 view over the cache
+     size_t acl_gamma_nb[GGML_MAX_DIMS];
+     acl_gamma_nb[0] = sizeof(float);
+     for (int i = 1; i < GGML_MAX_DIMS; i++) {
+         acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
+     }
+     aclTensor* acl_gamma = get_f32_cache_acl_tensor(ctx, src->ne, acl_gamma_nb, 1, 1);
+
+     // build rstd: a zero-filled f32 view over the cache
+     size_t acl_rstd_nb[GGML_MAX_DIMS];
+     acl_rstd_nb[0] = sizeof(float);
+     for (int i = 1; i < GGML_MAX_DIMS; i++) {
+         acl_rstd_nb[i] = acl_rstd_nb[i - 1] * src->ne[i - 1];
+     }
+     aclTensor* acl_rstd = get_f32_cache_acl_tensor(ctx, src->ne, acl_rstd_nb, GGML_MAX_DIMS, 0);
+
    GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd);
    ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd);
}
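To make the stride construction above concrete, a worked example with an invented shape (each stride is the previous stride times the previous dimension, i.e. ggml's contiguous f32 layout):

    // Hypothetical src->ne = { 4096, 8, 1, 1 }, f32:
    // acl_gamma_nb[0] = sizeof(float) = 4
    // acl_gamma_nb[1] = 4 * 4096      = 16384
    // acl_gamma_nb[2] = 16384 * 8     = 131072
    // acl_gamma_nb[3] = 131072 * 1    = 131072
    // acl_gamma is then a 1-D one-filled view of the first 4096 cached floats,
    // and acl_rstd a 4-D zero-filled view built the same way.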
@@ -1277,23 +1372,6 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
        tmp_permute_tensor, tmp_mul_tensor, acl_dst);
}

- /**
-  * @brief Fills a tensor with a scalar value.
-  *
-  * This function fills the destination tensor `acl_dst` with the scalar value
-  * `scalar`.
-  *
-  * @param ctx The context for the CANN backend operations.
-  * @param scalar The scalar value used to fill the tensor.
-  * @param acl_dst The destination tensor to be filled with the scalar value.
-  */
- static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
-                               aclTensor* acl_dst) {
-     auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
-     GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar);
-     ggml_cann_release_resources(ctx, acl_scalar);
- }
-
/**
 * @brief Raises each element of a tensor to the power of the corresponding
 * element in another tensor.