python3 test_flash_mm.py got error #1
This is usually the result of a mismatch in CUDA versions: https://forums.developer.nvidia.com/t/cudalaunchkernel-returned-status-98-invalid-device-function/169958 Can you try it with the NVIDIA PyTorch Docker container? https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch
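For what it's worth, a quick way to check for this kind of mismatch is to compare the CUDA version PyTorch was built against and the architectures it supports with the GPU's compute capability. This is a minimal diagnostic sketch using only standard `torch` calls, not anything from this repo:

```python
# Diagnostic sketch: an "invalid device function" (cudaLaunchKernel status 98)
# usually means the extension was compiled for a different GPU architecture
# than the one it is running on.
import torch

if torch.cuda.is_available():
    print("PyTorch version:", torch.__version__)
    print("CUDA version PyTorch was built with:", torch.version.cuda)
    print("GPU:", torch.cuda.get_device_name(0))
    print("Compute capability:", torch.cuda.get_device_capability(0))
    # The compiled kernel must target one of these architectures.
    print("Architectures this torch build supports:", torch.cuda.get_arch_list())
else:
    print("No CUDA device visible to PyTorch.")
```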
Is it functioning correctly despite the CUDA mismatch? I'm running mm-bert and the loss is going down as usual.
The training loop is falling back to regular PyTorch, so that's why the loss is going down.
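The exact fallback mechanism depends on the repo's code, but a common pattern looks like the sketch below. The names `flash_mm` and `mm_fwd` are hypothetical stand-ins, not this project's actual API:

```python
# Hypothetical sketch of a kernel-with-fallback pattern; names are
# illustrative only.
import torch

def mm_reference(x, k):
    # Plain-PyTorch path (FFT-based long convolution); always available.
    L = x.shape[-1]
    x_f = torch.fft.rfft(x, n=2 * L)
    k_f = torch.fft.rfft(k, n=2 * L)
    return torch.fft.irfft(x_f * k_f, n=2 * L)[..., :L]

def mm(x, k):
    try:
        import flash_mm                # hypothetical compiled CUDA extension
        return flash_mm.mm_fwd(x, k)   # hypothetical fast entry point
    except (ImportError, RuntimeError):
        # A kernel built for the wrong architecture fails at launch time
        # ("invalid device function"); the pure-PyTorch path keeps training
        # working, just slower, which is why the loss still goes down.
        return mm_reference(x, k)
```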
I see. Thanks @DanFu09!
May I ask one more question, @DanFu09: how much faster is the flash_mm kernel compared to the PyTorch implementation?
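One way to measure the difference on your own hardware is with `torch.utils.benchmark`. A minimal sketch, where `fast_fn` and `ref_fn` stand in for the flash_mm kernel and the plain-PyTorch path (neither name is from this repo):

```python
# Hypothetical benchmarking sketch: times two implementations of the same op.
import torch.utils.benchmark as benchmark

def compare(fast_fn, ref_fn, *args, runs=100):
    t_fast = benchmark.Timer(stmt="fn(*args)",
                             globals={"fn": fast_fn, "args": args}).timeit(runs)
    t_ref = benchmark.Timer(stmt="fn(*args)",
                            globals={"fn": ref_fn, "args": args}).timeit(runs)
    print(f"kernel: {t_fast.mean * 1e3:.3f} ms, "
          f"pytorch: {t_ref.mean * 1e3:.3f} ms, "
          f"speedup: {t_ref.mean / t_fast.mean:.2f}x")
```

`Timer` handles CUDA synchronization internally, so the reported means are wall-clock times per call.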