
ROCM EP convolution fails due to missing #19566

Open
dmnieto opened this issue Feb 19, 2024 · 12 comments
Labels
ep:ROCm (questions/issues related to ROCm execution provider), stale (issues that have not been addressed in a while; categorized by a bot)

Comments

@dmnieto

dmnieto commented Feb 19, 2024

Describe the issue

"MIOpen Error: No invoker was registered for convolution forward." happens when trying to use any model for inference with convolution codes. This is because the caching system in the Update() call will check for previous used algorithms. But the ROCM API (contrary to CUDA) requires the algo search call as described here:
https://rocmdocs.amd.com/projects/MIOpen/en/latest/MIOpen_Porting_Guide.html
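
For illustration, the MIOpen calling pattern the porting guide describes looks roughly like the sketch below. This is a minimal sketch, not the onnxruntime code: RunConvForward is a hypothetical helper, and handle, descriptor and device-buffer setup is assumed to be done by the caller.

    // Minimal sketch of the MIOpen requirement described above: an explicit
    // miopenFindConvolutionForwardAlgorithm call must supply the algorithm
    // that miopenConvolutionForward then uses.
    #include <miopen/miopen.h>

    miopenStatus_t RunConvForward(miopenHandle_t handle,
                                  miopenTensorDescriptor_t x_desc, const void* x,
                                  miopenTensorDescriptor_t w_desc, const void* w,
                                  miopenConvolutionDescriptor_t conv_desc,
                                  miopenTensorDescriptor_t y_desc, void* y,
                                  void* workspace, size_t workspace_bytes) {
      // 1) Search for a usable algorithm for this exact (x, w, conv) problem.
      int returned_algo_count = 0;
      miopenConvAlgoPerf_t perf{};
      miopenStatus_t status = miopenFindConvolutionForwardAlgorithm(
          handle, x_desc, x, w_desc, w, conv_desc, y_desc, y,
          /*requestAlgoCount=*/1, &returned_algo_count, &perf,
          workspace, workspace_bytes, /*exhaustiveSearch=*/false);
      if (status != miopenStatusSuccess) return status;
      if (returned_algo_count == 0) return miopenStatusUnknownError;

      // 2) Only then run the forward convolution with the algorithm found above.
      const float alpha = 1.0f, beta = 0.0f;
      return miopenConvolutionForward(handle, &alpha, x_desc, x, w_desc, w,
                                      conv_desc, perf.fwd_algo, &beta, y_desc, y,
                                      workspace, workspace_bytes);
    }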

To reproduce

This has been reproduced with the latest version of immich, on all the models it uses.

Urgency

Relatively urgent, as the ROCm EP is broken.

Platform

Linux

OS Version

Ubuntu 22.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.17

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Other / Unknown

Execution Provider Library Version

ROCm - any version

@skottmckay added the ep:ROCm label on Feb 20, 2024
@cloudhan
Contributor

cloudhan commented Feb 20, 2024

@PeixuanZuo Any ideas? I don't think we observed this conv issue in SD.

@jeffdaily Could you please contact the MIOpen devs about this? I think we introduced the manual cache to work around some older version of MIOpen, but it seems we need to remove the manual caching logic again?

@dmnieto
Author

dmnieto commented Feb 20, 2024 via email

@jeffdaily
Contributor

I have reached out to our MIOpen team.

@dmnieto
Author

dmnieto commented Feb 20, 2024 via email

@jeffdaily
Contributor

Can you provide more details? A Docker container, with reproduction steps?

@dmnieto
Author

dmnieto commented Feb 20, 2024 via email

@github-actions

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

@github-actions bot added the stale label on Mar 22, 2024
@tianleiwu
Contributor

tianleiwu commented Jul 11, 2024

The root cause is that the cache key is the x shape alone (not a combination of the x shape and the w shape), while the w shape is not constant. With multiple threads:

  • Thread 1 uses w shape w1 and x shape x1: it runs the algo search, then adds a cache entry <x1, algo1> for w1.
  • Thread 2 uses w shape w2 and x shape x1: it clears the cache because the w shape changed. Right after this, thread 1 inserts <x1, algo1> before thread 2 looks up the cache. Thread 2 then looks up the cache with key x1, finds the <x1, algo1> entry inserted by thread 1, and chooses algo1. But algo1 cannot be applied to w2, so thread 2 raises a runtime error.

The solution is to use the x shape plus the w shape as the cache key, never clear the cache, and add a mutex guarding the cache so that only one thread can look it up or update it at a time.

The current cache is per node, so the Conv algo search has to run multiple times when a model has multiple Conv nodes. An ideal solution would be a global cache (like the PyTorch code), keyed on all conv parameters (including the device id), which would avoid duplicated algo searches.
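
A minimal sketch of that kind of cache, assuming it only needs to map shapes to a forward algorithm (ConvAlgoCache and AlgoKey are hypothetical names, not the onnxruntime implementation):

    // Illustrative sketch of a shape-keyed algorithm cache guarded by a mutex,
    // as proposed above. The key combines x shape and w shape, entries are
    // never cleared, and all access is serialized by the mutex.
    #include <cstdint>
    #include <map>
    #include <mutex>
    #include <optional>
    #include <tuple>
    #include <vector>

    #include <miopen/miopen.h>

    struct AlgoKey {
      std::vector<int64_t> x_shape;  // input shape
      std::vector<int64_t> w_shape;  // weight shape, now part of the key
      bool operator<(const AlgoKey& other) const {
        return std::tie(x_shape, w_shape) < std::tie(other.x_shape, other.w_shape);
      }
    };

    class ConvAlgoCache {
     public:
      // Returns the cached algorithm for this (x shape, w shape) pair, if any.
      std::optional<miopenConvFwdAlgorithm_t> Find(const AlgoKey& key) {
        std::lock_guard<std::mutex> lock(mutex_);
        auto it = cache_.find(key);
        if (it == cache_.end()) return std::nullopt;
        return it->second;
      }

      // Stores the result of an algo search. Entries are never removed, so a
      // cached algorithm always matches the exact shapes it was searched for.
      void Insert(const AlgoKey& key, miopenConvFwdAlgorithm_t algo) {
        std::lock_guard<std::mutex> lock(mutex_);
        cache_.emplace(key, algo);
      }

     private:
      std::mutex mutex_;
      std::map<AlgoKey, miopenConvFwdAlgorithm_t> cache_;
    };

Along the lines of the last paragraph above, the same structure could be promoted to a process-global cache and the key widened to all conv parameters plus the device id.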

@dmnieto
Author

dmnieto commented Jul 11, 2024 via email

@mertalev

(Quoting @tianleiwu's analysis above.)

Hi @tianleiwu, I implemented these changes (using x shape + w shape in the key, not clearing cache, guarding cache with mutex) in this patch. However, based on these results, it seems the threading issue still exists. Do you have any ideas for why this might be?

I also noticed that there is already a std::lock_guard<OrtMutex> lock(s_.mutex); in ComputeInternal, called before UpdateState. Shouldn't that make this whole section serial?
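
For what it's worth, the effect of a lock_guard held across the whole compute path can be shown with a standalone example. This is generic C++, not the onnxruntime code; the names only mirror the pattern discussed here.

    // Standalone illustration: a std::lock_guard held for the entire body of
    // Compute() serializes both the state update and the subsequent "forward"
    // step across threads, analogous to lock(s_.mutex) in ComputeInternal.
    #include <functional>
    #include <iostream>
    #include <mutex>
    #include <thread>
    #include <vector>

    struct SharedState {
      std::mutex mutex;
      int cached_algo = -1;  // stands in for the cached conv algorithm
    };

    void Compute(SharedState& s, int algo_for_this_call) {
      std::lock_guard<std::mutex> lock(s.mutex);  // held until Compute returns
      // "UpdateState": refresh the cached algorithm if needed.
      if (s.cached_algo != algo_for_this_call) {
        s.cached_algo = algo_for_this_call;
      }
      // "Forward": no other thread can change cached_algo here, because it
      // would block on the same mutex until this call returns.
      std::cout << "ran with algo " << s.cached_algo << "\n";
    }

    int main() {
      SharedState s;
      std::vector<std::thread> threads;
      for (int i = 0; i < 4; ++i) {
        threads.emplace_back(Compute, std::ref(s), i);
      }
      for (auto& t : threads) {
        t.join();
      }
      return 0;
    }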

@mertalev

mertalev commented Jan 14, 2025

The specific error is this, which is different from the original error:

Fail: [ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Conv node.
Name:'Conv_426' Status Message: MIOPEN failure 7: miopenStatusUnknownError ; GPU=0 ; hostname=b86317b041da ;
file=/code/onnxruntime/onnxruntime/core/providers/rocm/nn/conv.cc ; line=336 ;
expr=miopenConvolutionForward(miopen_handle, &alpha, s_.x_tensor, s_.x_data, s_.w_desc, s_.w_data, s_.conv_desc, s_.fwd_algo, &beta, s_.y_tensor, s_.y_data, workspace.get(), s_.workspace_bytes);

My best guess is that it's not just a threading issue, but a caching issue where it tries to use an incompatible algo. But it does work when single-threaded, so I'm not sure.

@tianleiwu
Contributor

I also noticed that there is already a std::lock_guard<OrtMutex> lock(s_.mutex); in ComputeInternal, called before UpdateState. Shouldn't that make this whole section serial?

@mertalev, you are right. That mutex should make the op thread-safe. I missed that in my previous review.
