
python3 test_flash_mm.py got error #1

Closed · tiendung opened this issue Aug 1, 2023 · 5 comments

tiendung commented Aug 1, 2023

ERROR: CUDA RT call "cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size )" in line 695 of file mm/csrc/flashmm/mm_block_fwd_cuda.cu failed with invalid device function (98).
max diff for mm block: tensor(2.0590e-05, device='cuda:0', grad_fn=<SelectBackward0>)
average diff for mm block: tensor(2.9658e-06, device='cuda:0', grad_fn=<MeanBackward0>)
max diff: tensor(0.0003, device='cuda:0')
avg diff: tensor(7.4159e-05, device='cuda:0')

I can still run the trainer, and the loss goes down.

DanFu09 (Collaborator) commented Aug 1, 2023

This is usually the result of a mismatch in CUDA versions: https://forums.developer.nvidia.com/t/cudalaunchkernel-returned-status-98-invalid-device-function/169958

Can you try it with the NVIDIA PyTorch docker container? https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch
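As a rough first check (a minimal sketch, not the project's own tooling), you can compare the CUDA version PyTorch was built against, the GPU's compute capability, and the toolkit nvcc reports. Error 98 (invalid device function) usually means the extension binary was not compiled for the GPU/toolkit combination that is actually running:

```python
# Hedged sketch: quick checks for the kind of CUDA mismatch that produces
# "invalid device function (98)".
import subprocess

import torch

# CUDA version PyTorch was built against
print("torch.version.cuda:", torch.version.cuda)

# Compute capability of the visible GPU (the extension must be built for it)
print("GPU compute capability:", torch.cuda.get_device_capability(0))

# CUDA toolkit that nvcc (used to compile the flashmm extension) reports
print(subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout)
```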

tiendung (Author) commented Aug 3, 2023

Is it functioning correctly despite the CUDA mismatch? I'm running mm-bert and the loss is going down as usual.

DanFu09 (Collaborator) commented Aug 3, 2023

The training loop is falling back to regular PyTorch, so that’s why the loss is going down.
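For context, a guarded import like the minimal sketch below (all names are hypothetical, not the repo's actual API) is the usual mechanism behind this kind of silent fallback: if the compiled extension fails to load, the reference PyTorch path is used instead, so training still converges, just more slowly.

```python
# Hedged illustration of a CUDA-extension fallback; names are stand-ins.
import torch


def _reference_forward(x, w):
    # Plain PyTorch path: correct but slower than a fused kernel.
    return x @ w


try:
    import flashmm_cuda_ext  # hypothetical compiled extension name

    def mm_forward(x, w):
        return flashmm_cuda_ext.forward(x, w)  # fused CUDA kernel (hypothetical)
except ImportError:
    mm_forward = _reference_forward  # silent fallback: training still runs


y = mm_forward(torch.randn(4, 8), torch.randn(8, 8))
```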

tiendung (Author) commented Aug 3, 2023

I see. Thanks @DanFu09.

tiendung closed this as completed Aug 3, 2023

tiendung (Author) commented Aug 6, 2023

May I ask one more question, @DanFu09: how much faster is the flash_mm kernel compared to the PyTorch implementation?
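One way to measure this on your own GPU (a rough sketch using a stand-in op, not the actual flash_mm API) is torch.utils.benchmark, which handles CUDA synchronization so the wall-clock comparison is fair:

```python
# Hedged sketch: time a baseline PyTorch op; swap in the flash_mm call from
# test_flash_mm.py for the fused-kernel number on your hardware.
import torch
import torch.utils.benchmark as benchmark

x = torch.randn(8, 4096, 768, device="cuda")
w = torch.randn(768, 768, device="cuda")

t_pytorch = benchmark.Timer(
    stmt="x @ w",          # stand-in for the reference PyTorch path
    globals={"x": x, "w": w},
).timeit(100)

print(t_pytorch)
```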
