Skip to content

[ROCm] Improve softmax performance. #1740

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 135 commits into
base: release/2.4
Choose a base branch
from

Conversation

doru1004
Copy link

This patch improves the performance of softmax for 2D tensors by:

  • using a softmax calculation which eliminates the increase of shared memory usage with the size of the tensor and relies on global memory accesses for the tensor data accesses while still using shared memory for the actual reduction step (the shared memory used for the reduction is constant and does not increase with tensor size).
  • for the final computation replacing the division by the sum with the multiplication of 1/sum. The 1/sum is computed as the last step of the warp reduction.
  • replace the use of the exp function with the __expf function.

The impact on numerical accuracy is within a 1e-5 for half precision and 1e-7 for full precision.

The impact on performance for MI300X is between 22% and 50% percentage improvement over current runtimes.

pruthvistony and others added 30 commits August 1, 2024 20:13
- Fortran package installation moved after gcc
- Update libtinfo search code in cmake1
- Install libstdc++.so
Reversed the condition as required
(cherry picked from commit 9848db1)
(cherry picked from commit ae01701)
…pired (#1399)

* Skip certificate check only for CentOS7 since certificate expired

* Naming
* Triton build conditionalized on ROCM_VERSION

(cherry picked from commit 1a7e1fa)

* Update pinned commit for rocm6.1 conditionalisation

---------

Co-authored-by: Pruthvi Madugundu <[email protected]>
…rsion (#1410)

* Include ROCm patch version in triton version

* Always include patch version

(cherry picked from commit 9692570)
…MENTS

since we plan to use builder repo to set it for older and newer branches.

Otherwise we end up with duplicate triton dependency specification eg.
PYTORCH_EXTRA_INSTALL_REQUIREMENTS='pytorch-triton-rocm==2.3.0+rocm6.1.0.4804a0dd4a | pytorch-triton-rocm==2.3.0+rocm6.1.0.4804a0dd4a'
… sync (#1455) (#1472)

* [SWDEV-469514] hipGraphExecDestroy requires an explicit sync

There is a new hip feature where they do not free hipGraph memory
as soon as hipGraphExecDestroy is called. This is to support async
work on the GPU. See this for more details:
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-user-objects

We noticed this issue when an allreduce op inside a hipGraph hung.
Essentially, ncclCommAbort was waiting for all GPU activity to finish.
However, since hipGraph memory was technically still in use, we had an
infinite hang. So, I added an extra hipDeviceSynchronize in CUDAGraph's
destructor to esure that memory is freed and got
test_allreduce_in_cudagraph UT to pass.

However, when I ran this on CUDA machine, I noticed that they did not
require this extra sync in order to successfully run the UT. It seems
that they were calling cudaGraphInstantiateWithFlags with
cudaGraphInstantiateFlagAutoFreeOnLaunch, which aggressively frees
memory after graph lauch. There is support for this API in our ROCm
stack, but we were missing cuda to hip mappings in PyTorch. So, I
brought them in and added the necesary conditions to call this API in
HIP case also.

* Update comments

* Use USE_ROCM in keeping with convention

* Use USE_ROCM to match convention

---------

Co-authored-by: Jithun Nair <[email protected]>
(cherry picked from commit e752b4f)
…) (#1510)

* cudagraph explicit sync only after capture_begin

* use 'capture_dev_=-1' as not initialized value

* use named constant instead of magic '-1' value

(cherry picked from commit eb433b9)
(cherry picked from commit 1feb1a8)

Co-authored-by: Jithun Nair <[email protected]>
…tree (#1494)

* Skip test__int_mm in 6.0

(cherry picked from commit bf4c478)

* [release/2.3] fix test_vmapvjpvjp and skip test_profiler_experimental_tree (#1460)

* fix test_vmapvjpvjp and skip test_profiler_experimental_tree (#1444)

(cherry picked from commit 7e96391)

* remove trailing spaces

---------

Co-authored-by: Ramana Cherukuri <[email protected]>
(cherry picked from commit 0766b9c)

* Reformat test_float8_basics for current rocm support (#1415)

(cherry picked from commit cb0e9ad)

---------

Co-authored-by: Jack Taylor <[email protected]>
Co-authored-by: Andres Lugo <[email protected]>
[release/2.4] Cherry-picks from release/2.3
[MPS][TYPE_PROMOTION] Fix Clamp (pytorch#130226)

Summary:
1. Fixed pytorch#130201 by adding type promotion.
2. Added proper tests.
3. Found torch's type promotion is different from numpy as follows:

```python
import torch
import numpy as np
np.clip(np.array([1], dtype=np.float32), np.array([1], dtype=np.int32), None).dtype  # dtype('float64')
torch.clamp(torch.tensor([1], dtype=torch.float32), torch.tensor([1], dtype=torch.int32)).dtype  # torch.float32
```

~Not sure the proper way to handle it, it causes numpy ref tests to fail.~
Reason here, so think I'm gonna xfail it:
https://github.com/pytorch/pytorch/blob/3c1cf03fde145bdbe1f5ffb81765d076c10b4c04/test/test_ops.py#L260-L264

Pull Request resolved: pytorch#130226
Approved by: https://github.com/malfet

(cherry picked from commit 99967e1)

Co-authored-by: Li-Huai (Allan) Lin <[email protected]>
[Doc] update guide install mkl-static from conda to pip (pytorch#130026)

<img width="619" alt="image" src="https://github.com/pytorch/pytorch/assets/8433590/4ac3ca68-57dc-42c7-ac7a-876dc377ebcf">

Conda intel channel is not avaliable now.
Use `pip` install instead of `conda`.

`Windows` and `Linux` are avaliable:
Binary list: https://pypi.org/project/mkl-static/#files

`MacOS` is avaliable for old version:
https://pypi.org/project/mkl-static/2021.3.0/#files

TODO:
1. cherry-pick to `release/2.4` branch, @atalman .
2. fix it also in `release/2.3` branch: pytorch#131853

Pull Request resolved: pytorch#130026
Approved by: https://github.com/jgong5, https://github.com/atalman

(cherry picked from commit 484852c)

Co-authored-by: Xu Han <[email protected]>
pytorch#133346)

fix for launching kernel invalid config error when calling embedding … (pytorch#130994)

…with large index

Fixes pytorch#130806
When an output size of 2147483648 (=131072*16384) is expected in the above issue, it throwed out the following error:
RuntimeError: HIP error: invalid configuration argument

What happened was that the second parameter passed to hipLaunchKernel was crazy {2147483648,1,1}.
Found two issues in the Indexing.cu:

1: ptrdiff_t was used but it is signed int,  outTotalSize >= 2147483648 can cause overflow when doing [this](https://github.com/pytorch/pytorch/blame/39493aa93419532957e6e5ee97cae842b53b8b59/aten/src/ATen/native/cuda/Indexing.cu#L1367):
2: On ROCm, std::min -> ::min did not work as expected when outTotalSize>=2147483648

As the result, 2147483648 was sent to hipLaunchKernel which the GPU does not support such a huge number since this number specifies the number of threads per block. The original code intended to set 128 threads per block, though this is debatable as the perf would not good for latest powerful GPUs (a TODO item to update for perf maybe?) , but at least it would not cause `invalid configuration argument` error.

[Test]
Run the same code snippet in the [issue](pytorch#130806), and print the output, its dim and numel(), which looks like below now:
```
output=tensor([[ 0.4044, -0.0244, -0.6865,  ..., -0.7800,  0.1175,  1.6726],
        [-1.0866, -0.1609,  0.3538,  ...,  1.9105,  0.7882,  1.1583],
        [-2.2079,  0.3736,  0.3610,  ..., -0.2658, -0.0459,  1.3077],
        ...,
        [ 0.8753, -0.7482, -0.1978,  ...,  0.9016,  1.1501, -0.5178],
        [-1.5845, -0.6277,  1.4520,  ...,  0.5733, -2.1198, -0.0915],
        [-0.6310, -1.0239, -0.1910,  ...,  0.4309,  0.1630,  0.3239]],
       device='cuda:0'), dim=2, numel=2147483648
```

Added a large tensor unit test too.
```
/pytorch# pytest test/nn/test_embedding.py -k test_large_tensors
================================================================================== test session starts ===================================================================================
platform linux -- Python 3.9.19, pytest-7.3.2, pluggy-1.4.0
rootdir: /dockerx/development/pytorch
configfile: pytest.ini
plugins: flakefinder-1.1.0, rerunfailures-14.0, xdist-3.3.1, xdoctest-1.1.0, cpp-2.3.0, hypothesis-5.35.1
collected 288 items / 287 deselected / 1 selected
Running 1 items in this shard

test/nn/test_embedding.py .                                                                                                                                                        [100%]

=========================================================================== 1 passed, 287 deselected in 3.16s ============================================================================
```
Pull Request resolved: pytorch#130994
Approved by: https://github.com/jeffdaily, https://github.com/xw285cornell

(cherry picked from commit 637ab85)

Co-authored-by: hongxyan <[email protected]>
…nd for scaled dot product function in ROCm (#1818)

Fixes [10211](ROCm/frameworks-internal#10211) 

Deselecting the efficient attention backend for scaled dot product
function for ROCm architecture when casual parameter is True.
Reference - ROCm/aotriton#25 
We are aware of this issue but the precise ETA has not been determined
at the moment due to various missing functionalities in AOTriton,
(varlen, MQA, causal variants, etc.)

Co-authored-by: root <[email protected]>
@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Jan 7, 2025

Jenkins build for 70a25e21a46ec6f8685e0b86d1a975f4aa248409 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

This PR adds a meta_registration for miopen_batch_norm in release/2.4 to
resolve this issue

```
NotImplementedError: aten::miopen_batch_norm: attempted to run this operator with Meta tensors, but there was no fake impl or Meta kernel registered.
```

cherry-picked from upstream
pytorch@4e4182d

Co-authored-by: Jack Taylor <[email protected]>
(cherry picked from commit 0ca1e5b)
@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Jan 13, 2025

Jenkins build for 70a25e21a46ec6f8685e0b86d1a975f4aa248409 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Jan 13, 2025

Jenkins build for 70a25e21a46ec6f8685e0b86d1a975f4aa248409 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Jan 14, 2025

Jenkins build for 70a25e21a46ec6f8685e0b86d1a975f4aa248409 commit finished as ABORTED
Links: Blue Ocean view / Build artifacts

amd-sriram and others added 2 commits January 15, 2025 13:07
Update related_commits to pick up apex version update
This PR enables:
* using MIOpen OCL_mix backend for bf16 batchnorm with fp32 weights
(using torch autocast). This was required and tested for customer
workload using NCHW (which is the only memory_layout enabled).
* logging for MIOpen batchnorm using `PYTORCH_MIOPEN_EXTRA_LOGGING` env
var.

TODO in separate PR: Need to implement PyTorch unit tests for this
bf16/fp16 inputs + fp32 weights case.

(cherry picked from commit abbfe77)
@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Jan 15, 2025

Jenkins build for 70a25e21a46ec6f8685e0b86d1a975f4aa248409 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

… (#1839)

Improve performance of reduce sum for 3D shapes.

Pull Request resolved: pytorch#143137
Approved by: https://github.com/jeffdaily, https://github.com/eqy

Co-authored-by: Doru Bercea <[email protected]>
@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Jan 16, 2025

Jenkins build for 65d155ce2b085af1a1e9605ec28faae2628cc8cf commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@doru1004 doru1004 force-pushed the improve-softmax-performance-2.4 branch from 70a25e2 to 1dfe9db Compare January 17, 2025 16:38
@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Jan 17, 2025

Jenkins build for 65d155ce2b085af1a1e9605ec28faae2628cc8cf commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Jan 17, 2025

Jenkins build for 65d155ce2b085af1a1e9605ec28faae2628cc8cf commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Jan 27, 2025

Jenkins build for 65d155ce2b085af1a1e9605ec28faae2628cc8cf commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Jan 30, 2025

Jenkins build for 65d155ce2b085af1a1e9605ec28faae2628cc8cf commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Feb 11, 2025

Jenkins build for 65d155ce2b085af1a1e9605ec28faae2628cc8cf commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Feb 13, 2025

Jenkins build for 65d155ce2b085af1a1e9605ec28faae2628cc8cf commit finished as ABORTED
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Feb 19, 2025

Jenkins build for 65d155ce2b085af1a1e9605ec28faae2628cc8cf commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Mar 14, 2025

Jenkins build for 65d155ce2b085af1a1e9605ec28faae2628cc8cf commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Mar 17, 2025

Jenkins build for 65d155ce2b085af1a1e9605ec28faae2628cc8cf commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Mar 17, 2025

Jenkins build for 65d155ce2b085af1a1e9605ec28faae2628cc8cf commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Apr 15, 2025

Jenkins build for 65d155ce2b085af1a1e9605ec28faae2628cc8cf commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Apr 15, 2025

Jenkins build for 65d155ce2b085af1a1e9605ec28faae2628cc8cf commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.