-
Notifications
You must be signed in to change notification settings - Fork 533
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: X-link: facebookresearch/FBGEMM#736 Bring ARM's PR: #3510 Differential Revision: D67396309
- Loading branch information
1 parent
79fcd5b
commit a9bc39a
Showing
5 changed files
with
822 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
# All rights reserved. | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
@@ -180,6 +181,19 @@ def get_fbgemm_inline_sve_srcs(msvc = False, buck = False): | |
}) | ||
return asm_srcs if not msvc else intrinsics_srcs | ||
|
||
# Returns the source list for the NEON utility code (src/UtilsNeon.cc). | ||
# msvc: consulted only when buck is False; True returns intrinsics_srcs, False returns asm_srcs. | ||
# buck: when True, returns a select() keyed on ovr_config compiler/cpu labels instead of a plain list. | ||
# NOTE(review): asm_srcs and intrinsics_srcs currently hold the same single file, so every | ||
# path resolves to the same sources -- confirm the asm/intrinsics split is intentional. | ||
def get_fbgemm_inline_neon_srcs(msvc = False, buck = False): | ||
intrinsics_srcs = ["src/UtilsNeon.cc"] | ||
|
||
# FP16 kernels contain inline assembly, and MSVC uses a different inline-assembly syntax; | ||
# that is the stated reason for keeping separate asm and intrinsics source lists. | ||
asm_srcs = ["src/UtilsNeon.cc"] | ||
if buck: | ||
return select({ | ||
"DEFAULT": asm_srcs, | ||
"ovr_config//compiler:cl": intrinsics_srcs, | ||
"ovr_config//cpu:arm64": intrinsics_srcs, | ||
}) | ||
return asm_srcs if not msvc else intrinsics_srcs | ||
|
||
def get_fbgemm_autovec_srcs(): | ||
return [ | ||
"src/EmbeddingSpMDMAutovec.cc", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
|
@@ -47,9 +48,9 @@ void transpose_simd( | |
return; | ||
} | ||
|
||
#if HAVE_SVE | ||
#ifdef __aarch64__ | ||
if constexpr (std::is_same<T, float>::value) { | ||
internal::transpose_sve<T>(M, N, src, ld_src, dst, ld_dst); | ||
internal::transpose_neon<T>(M, N, src, ld_src, dst, ld_dst); | ||
} else { | ||
transpose_ref<T>(M, N, src, ld_src, dst, ld_dst); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com> | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
|
@@ -64,9 +65,9 @@ void transpose_avx512( | |
|
||
#ifdef __aarch64__ | ||
/** | ||
* @brief Transpose a matrix using Intel AVX2. | ||
* @brief Transpose a matrix using SVE. | ||
* | ||
* This is called if the code is running on a CPU with Intel AVX2 support. | ||
* This is called if the code is running on a CPU with SVE support. | ||
*/ | ||
template <typename T> | ||
void transpose_sve( | ||
|
@@ -76,6 +77,20 @@ void transpose_sve( | |
int64_t ld_src, | ||
T* dst, | ||
int64_t ld_dst); | ||
|
||
/** | ||
* @brief Transpose a matrix using NEON. | ||
* | ||
* This is called if the code is running on a CPU with NEON support. | ||
* | ||
* The declaration sits inside the surrounding #ifdef __aarch64__ block, so it | ||
* is only visible on AArch64 builds. | ||
* NOTE(review): the transpose_simd call site selects this function at compile | ||
* time (#ifdef __aarch64__) and only when T is float; other element types go | ||
* to transpose_ref. Confirm whether any runtime CPU check exists elsewhere. | ||
* | ||
* @param M      first matrix dimension (presumably rows of src -- TODO confirm) | ||
* @param N      second matrix dimension (presumably columns of src -- TODO confirm) | ||
* @param src    input matrix, read-only | ||
* @param ld_src presumably the leading dimension (row stride in elements) of | ||
*               src -- TODO confirm | ||
* @param dst    output matrix | ||
* @param ld_dst presumably the leading dimension (row stride in elements) of | ||
*               dst -- TODO confirm | ||
*/ | ||
template <typename T> | ||
void transpose_neon( | ||
int64_t M, | ||
int64_t N, | ||
const T* src, | ||
int64_t ld_src, | ||
T* dst, | ||
int64_t ld_dst); | ||
#endif // __aarch64__ | ||
|
||
} // namespace internal | ||
|
Oops, something went wrong.