-
Notifications
You must be signed in to change notification settings - Fork 188
[CK_TILE] Support for elementwise kernel #2246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
Hey team, I made a few suggestions in another PR (feel free to keep/discard). |
e4271dd
to
02bfa62
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good start! Please add unit-tests (functional) here: https://github.com/ROCm/composable_kernel/tree/develop/test/ck_tile
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please change the example number? To 19 ? (At this moment last is 18) I don't understand why sb started with thirty sth...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in commit 4145fbf.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add copyright header here and in other files.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in commit bcccb72.
# Elementwise example targets 2D inputs | ||
set(TARGET_NAME_2D_INPUT tile_example_elementwise) | ||
add_executable(${TARGET_NAME_2D_INPUT} elementwise_example.cpp) | ||
rocm_install(TARGETS ${TARGET_NAME_2D_INPUT} COMPONENT examples) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure we need this at the moment. @illsilin Could you please comment on this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rocm_install
lines removed from our CMakeLists.txt
in commit 96d1a6b. I went ahead with this since none of the other examples have a corresponding line.
int warmup = arg_parser.get_int("warmup"); | ||
int repeat = arg_parser.get_int("repeat"); | ||
|
||
assert(stride >= N); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The assert
function is a no-op (disabled) in a release build mode. If you want to validate it, throw an error if has incorrect value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in commit 05bbba7.
// Copy data from host tensors to device buffers. | ||
x_buf_a.ToDevice(x_host_a.data()); | ||
x_buf_b.ToDevice(x_host_b.data()); | ||
y_buf.ToDevice(y_host.data()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You'd rather want to set this to zero (if it's needed, or just for sanity). Right now you copy from uninitialized (or in best scenario default initialized) memory region.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Addressed in commit 64a7d3f where we no longer copy y_buf
to the device. Please advise if this is not what you meant.
|
||
auto y = y_tile(tile_idx); | ||
|
||
apply_operation(ElementWiseOperation{}, y, x_values); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Although this approach seem to be nice generic impl for associative operations - we do not always run such simple operations. Therefore we'd rather pass all inputs to the elementwise functor as arguments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is now addressed in commit 8c5d714.
// Move windows to next block | ||
move_tile_windows(x_windows); | ||
move_tile_window(y_window, {S::kBlockM}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not needed if each of your workgroups is processing just single tile.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed in commit 876b057.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the place for unary elementwise ops: https://github.com/ROCm/composable_kernel/blob/develop/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
You may want to add eg binary_elementwise_operation.hpp
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added in commit fc1ff7f.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You might want to reuse this one: https://github.com/ROCm/composable_kernel/blob/develop/include/ck_tile/ops/common/generic_2d_block_shape.hpp
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As per our discussion on Teams, we will keep using the existing ElementWiseShape
because it is slightly more appropriate for our implementation.
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp" | ||
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are not used here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed in commit abd0d35.
Co-authored-by: Sami Aario <[email protected]> Co-authored-by: Mohsen Saffari <[email protected]> Co-authored-by: yashagar <[email protected]>
559af12
to
5f16d7e
Compare
…eration with the exact number of arguments in xs
5f16d7e
to
8c5d714
Compare
2964594
to
05bbba7
Compare
…rid size also for the unary example
… the other examples
c1f4b26
to
b52400c
Compare
…nce in a call to make_naive_tensor_view
…entwise_operation.hpp
eb11196
to
fc1ff7f
Compare
…ing the transpose reference
- Refer to AMD terminology except when suggesting NVIDIA alternatives in parentheses - ElementWiseTraits was renamed to ElementWiseShape - Adopt suggestions made by Copilot when prompted to check for factual or typographical errors
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general looks good. I got mostly cosmetic changes.
@@ -0,0 +1,33 @@ | |||
// SPDX-License-Identifier: MIT | |||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. | |
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in 99eac8e.
@@ -0,0 +1,142 @@ | |||
// SPDX-License-Identifier: MIT | |||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see this fixed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please rename to elementwise_kernel.hpp
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in 99eac8e.
auto make_tile_windows = [&](const auto& tensors, const auto& iM_) { | ||
return generate_tuple( | ||
[&](auto idx) { | ||
auto tensor_view = make_naive_tensor_view<address_space_enum::global>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we could configure here cache coherence as well. The data won't be reused so we could set up this.
https://github.com/ROCm/composable_kernel/blob/develop/include/ck_tile/core/tensor/tensor_view.hpp#L433
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now I am not sure anymore what you meant by this: #2246 (comment)
We switched this call to make_naive_tensor_view
to use the default template parameters based on that note. In any case, now we have two calls to make_naive_tensor_view
in our kernel with different template parameters, because I neglected to change the other one on line 67. The two calls should still use the same parameters, am I right?
Please confirm if you want the call on line 42 to be switched to:
auto tensor_view = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::set,
amd_buffer_coherence_enum::slc>(
@@ -0,0 +1,29 @@ | |||
// SPDX-License-Identifier: MIT | |||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please fix the copyright year.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in 99eac8e.
@@ -0,0 +1,31 @@ | |||
// SPDX-License-Identifier: MIT | |||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copyright year.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in 99eac8e.
@@ -0,0 +1,24 @@ | |||
// SPDX-License-Identifier: MIT | |||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copyright year
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in 99eac8e.
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const & { TP_COM_(); return impl::getv<I>(*this); } | ||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const & { TP_COM_(); return get<I>(); } | ||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() & { TP_COM_(); return impl::getv<I>(*this); } | ||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) & { TP_COM_(); return get<I>(); } | ||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() && { TP_COM_(); return impl::getv<I>(std::move(*this)); } | ||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) && { TP_COM_(); return std::move(*this).template get<I>(); } | ||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const && { TP_COM_(); return impl::getv<I>(std::move(*this)); } | ||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const &&{ TP_COM_(); return std::move(*this).template get<I>(); } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would that impact performance or register usage?
Proposed changes
Implementation for the elementwise kernel that supports arbitrary number of input tensors for arbitrary dimensions.
Checklist
Please put an
x
into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.clang-format
on all changed filesDiscussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered