@@ -886,12 +886,12 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
886
886
887
887
/* *
888
888
* @brief Get or expand cached float32 tensors filled with scalar values.
889
- *
890
- * This function manages a cache of float32 tensors (zero-filled and one-filled).
891
- * If the cache does not exist, it will initialize the cache with a zero tensor
892
- * and a one tensor. If the requested tensor size exceeds the current cache
893
- * capacity, the cache will be expanded accordingly. The function then returns
894
- * an aclTensor created from the cached memory (either zero-filled or one-filled),
889
+ *
890
+ * This function manages a cache of float32 tensors (zero-filled and one-filled).
891
+ * If the cache does not exist, it will initialize the cache with a zero tensor
892
+ * and a one tensor. If the requested tensor size exceeds the current cache
893
+ * capacity, the cache will be expanded accordingly. The function then returns
894
+ * an aclTensor created from the cached memory (either zero-filled or one-filled),
895
895
* depending on the input `value`.
896
896
*
897
897
* @param ctx The CANN backend context that manages cache memory.
@@ -902,7 +902,7 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
902
902
* to return the zero-cache tensor or the one-cache tensor.
903
903
* @return An aclTensor pointer corresponding to the cached tensor.
904
904
*/
905
- static aclTensor* get_f32_cache_acl_tensor (ggml_backend_cann_context& ctx,
905
+ static aclTensor* get_f32_cache_acl_tensor (ggml_backend_cann_context& ctx,
906
906
int64_t * ne, size_t * nb,
907
907
int64_t dims, int64_t value) {
908
908
// init cache
@@ -911,10 +911,10 @@ static aclTensor* get_f32_cache_acl_tensor(ggml_backend_cann_context& ctx,
911
911
size_t size = ctx.f32_cache_element * sizeof (float );
912
912
ACL_CHECK (aclrtMalloc (&ctx.f32_zero_cache , size, ACL_MEM_MALLOC_HUGE_FIRST));
913
913
ACL_CHECK (aclrtMemsetAsync (ctx.f32_zero_cache , size, 0 , size, ctx.stream ()));
914
-
914
+
915
915
// one-cache pool init
916
- int64_t pool_ne[1 ] = { ctx.f32_cache_element };
917
- size_t pool_nb[1 ] = { sizeof (float ) };
916
+ int64_t pool_ne[1 ] = { ctx.f32_cache_element };
917
+ size_t pool_nb[1 ] = { sizeof (float ) };
918
918
ACL_CHECK (aclrtMalloc (&ctx.f32_one_cache , size, ACL_MEM_MALLOC_HUGE_FIRST));
919
919
aclTensor* acl_one = ggml_cann_create_tensor (
920
920
ctx.f32_one_cache , ACL_FLOAT, sizeof (float ), pool_ne, pool_nb,
@@ -937,25 +937,25 @@ static aclTensor* get_f32_cache_acl_tensor(ggml_backend_cann_context& ctx,
937
937
size_t size = n_element * sizeof (float );
938
938
ACL_CHECK (aclrtMalloc (&ctx.f32_zero_cache , size, ACL_MEM_MALLOC_HUGE_FIRST));
939
939
ACL_CHECK (aclrtMemsetAsync (ctx.f32_zero_cache , size, 0 , size, ctx.stream ()));
940
-
940
+
941
941
// one-cache pool init
942
- int64_t pool_ne[1 ] = { n_element };
943
- size_t pool_nb[1 ] = { sizeof (float ) };
942
+ int64_t pool_ne[1 ] = { n_element };
943
+ size_t pool_nb[1 ] = { sizeof (float ) };
944
944
ACL_CHECK (aclrtMalloc (&ctx.f32_one_cache , size, ACL_MEM_MALLOC_HUGE_FIRST));
945
945
aclTensor* acl_one = ggml_cann_create_tensor (
946
946
ctx.f32_one_cache , ACL_FLOAT, sizeof (float ), pool_ne, pool_nb,
947
947
1 );
948
948
aclnn_fill_scalar (ctx, 1 , acl_one);
949
949
ggml_cann_release_resources (ctx, acl_one);
950
950
}
951
-
951
+
952
952
void * cache;
953
953
if (value == 0 ) {
954
954
cache = ctx.f32_zero_cache ;
955
955
} else {
956
956
cache = ctx.f32_one_cache ;
957
957
}
958
-
958
+
959
959
return ggml_cann_create_tensor (cache, ACL_FLOAT, sizeof (float ), ne, nb, dims);
960
960
}
961
961
@@ -983,7 +983,7 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
983
983
acl_rstd_nb[i] = acl_rstd_nb[i - 1 ] * src->ne [i - 1 ];
984
984
}
985
985
aclTensor* acl_rstd = get_f32_cache_acl_tensor (ctx, src->ne , acl_rstd_nb, GGML_MAX_DIMS, 0 );
986
-
986
+
987
987
GGML_CANN_CALL_ACLNN_OP (ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd);
988
988
ggml_cann_release_resources (ctx, acl_src, acl_dst, acl_gamma, acl_rstd);
989
989
}
0 commit comments