Skip to content

[CK_TILE] Support for elementwise kernel #2246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 31 commits into
base: develop
Choose a base branch
from

Conversation

amd-yashagar
Copy link

Proposed changes

Implementation for the elementwise kernel that supports arbitrary number of input tensors for arbitrary dimensions.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers with understanding the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

@AviralGoelAMD
Copy link
Collaborator

Hey team, I made a few suggestions in another PR (feel free to keep/discard).
#2248

@amd-yashagar amd-yashagar force-pushed the LWPCK-3170-elementwise-kernel-general branch from e4271dd to 02bfa62 Compare May 28, 2025 08:27
Copy link
Collaborator

@aosewski aosewski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good start! Please add unit-tests (functional) here: https://github.com/ROCm/composable_kernel/tree/develop/test/ck_tile

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please change the example number? To 19 ? (At this moment last is 18) I don't understand why sb started with thirty sth...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in commit 4145fbf.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add copyright header here and in other files.

Copy link
Contributor

@SamiAario-AMD SamiAario-AMD May 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in commit bcccb72.

# Elementwise example targets 2D inputs
set(TARGET_NAME_2D_INPUT tile_example_elementwise)
add_executable(${TARGET_NAME_2D_INPUT} elementwise_example.cpp)
rocm_install(TARGETS ${TARGET_NAME_2D_INPUT} COMPONENT examples)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we need this at the moment. @illsilin Could you please comment on this?

Copy link
Contributor

@SamiAario-AMD SamiAario-AMD Jun 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rocm_install lines removed from our CMakeLists.txt in commit 96d1a6b. I went ahead with this since none of the other examples have a corresponding line.

int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");

assert(stride >= N);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The assert function is a no-op (disabled) in a release build mode. If you want to validate it, throw an error if has incorrect value.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in commit 05bbba7.

// Copy data from host tensors to device buffers.
x_buf_a.ToDevice(x_host_a.data());
x_buf_b.ToDevice(x_host_b.data());
y_buf.ToDevice(y_host.data());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'd rather want to set this to zero (if it's needed, or just for sanity). Right now you copy from uninitialized (or in best scenario default initialized) memory region.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in commit 64a7d3f where we no longer copy y_buf to the device. Please advise if this is not what you meant.


auto y = y_tile(tile_idx);

apply_operation(ElementWiseOperation{}, y, x_values);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although this approach seem to be nice generic impl for associative operations - we do not always run such simple operations. Therefore we'd rather pass all inputs to the elementwise functor as arguments.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is now addressed in commit 8c5d714.

Comment on lines 136 to 138
// Move windows to next block
move_tile_windows(x_windows);
move_tile_window(y_window, {S::kBlockM});
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not needed if each of your workgroups is processing just single tile.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed in commit 876b057.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the place for unary elementwise ops: https://github.com/ROCm/composable_kernel/blob/develop/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
You may want to add eg binary_elementwise_operation.hpp

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in commit fc1ff7f.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As per our discussion on Teams, we will keep using the existing ElementWiseShape because it is slightly more appropriate for our implementation.

Comment on lines 7 to 8
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are not used here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed in commit abd0d35.

@SamiAario-AMD SamiAario-AMD force-pushed the LWPCK-3170-elementwise-kernel-general branch from 559af12 to 5f16d7e Compare May 30, 2025 11:50
…eration with the exact number of arguments in xs
@SamiAario-AMD SamiAario-AMD force-pushed the LWPCK-3170-elementwise-kernel-general branch from 5f16d7e to 8c5d714 Compare May 30, 2025 12:26
@SamiAario-AMD SamiAario-AMD force-pushed the LWPCK-3170-elementwise-kernel-general branch from 2964594 to 05bbba7 Compare May 30, 2025 13:19
@SamiAario-AMD SamiAario-AMD force-pushed the LWPCK-3170-elementwise-kernel-general branch from c1f4b26 to b52400c Compare June 2, 2025 12:01
@SamiAario-AMD SamiAario-AMD force-pushed the LWPCK-3170-elementwise-kernel-general branch from eb11196 to fc1ff7f Compare June 3, 2025 06:55
Damien Lejeune and others added 3 commits June 3, 2025 14:33
- Refer to AMD terminology except when suggesting NVIDIA alternatives in parentheses
- ElementWiseTraits was renamed to ElementWiseShape
- Adopt suggestions made by Copilot when prompted to check for factual or typographical errors
Copy link
Collaborator

@aosewski aosewski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general looks good. I got mostly cosmetic changes.

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 99eac8e.

@@ -0,0 +1,142 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see this fixed.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename to elementwise_kernel.hpp

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 99eac8e.

auto make_tile_windows = [&](const auto& tensors, const auto& iM_) {
return generate_tuple(
[&](auto idx) {
auto tensor_view = make_naive_tensor_view<address_space_enum::global>(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could configure here cache coherence as well. The data won't be reused so we could set up this.
https://github.com/ROCm/composable_kernel/blob/develop/include/ck_tile/core/tensor/tensor_view.hpp#L433

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I am not sure anymore what you meant by this: #2246 (comment)

We switched this call to make_naive_tensor_view to use the default template parameters based on that note. In any case, now we have two calls to make_naive_tensor_view in our kernel with different template parameters, because I neglected to change the other one on line 67. The two calls should still use the same parameters, am I right?

Please confirm if you want the call on line 42 to be switched to:

                    auto tensor_view = make_naive_tensor_view<address_space_enum::global,
                                                              memory_operation_enum::set,
                                                              amd_buffer_coherence_enum::slc>(

@@ -0,0 +1,29 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please fix the copyright year.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 99eac8e.

@@ -0,0 +1,31 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copyright year.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 99eac8e.

@@ -0,0 +1,24 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copyright year

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 99eac8e.

Comment on lines +267 to +274
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const & { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const & { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() & { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) & { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() && { TP_COM_(); return impl::getv<I>(std::move(*this)); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) && { TP_COM_(); return std::move(*this).template get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const && { TP_COM_(); return impl::getv<I>(std::move(*this)); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const &&{ TP_COM_(); return std::move(*this).template get<I>(); }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would that impact performance or register usage?

@aosewski aosewski marked this pull request as ready for review June 4, 2025 11:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants