Hi Dan & Hermann,

I noticed a bug in FlashFFTConv that appears when implicit padding must be done.
Specifically, if I use a sequence of shape [B, N, L=14113] with a FlashFFTConv module parameterized with length 32768, as:
```python
fftconv_fn = FlashFFTConv(seqlen=32768, dtype=torch.bfloat16, use_32_butterfly=True)
y = fftconv_fn(x.contiguous().to(dtype=fftconv_fn.dtype), k.float().contiguous()).to(dtype=x.dtype)
```
I get the following error:
```
CUDA Runtime Error at: /tmp/pip-req-build-pa3co3yr/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu:1041
misaligned address
terminate called after throwing an instance of 'c10::Error'
what(): CUDA error: misaligned address
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Exception raised from c10_cuda_check_implementation at /opt/pytorch/pytorch/c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x99 (0x7f4984c858f9 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xe0 (0x7f4984c3abb6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x3c2 (0x7f498efd2e12 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: <unknown function> + 0xe5c485 (0x7f4926cb5485 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xe59644 (0x7f4926cb2644 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: <unknown function> + 0x483b00 (0x7f4983e1cb00 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #6: c10::TensorImpl::~TensorImpl() + 0x9 (0x7f4984c61419 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #7: <unknown function> + 0x74b788 (0x7f49840e4788 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #8: THPVariable_subclass_dealloc(_object*) + 0x296 (0x7f49840e4a96 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x136991 (0x5556b8a99991 in /usr/bin/python)
frame #10: <unknown function> + 0x13678c (0x5556b8a9978c in /usr/bin/python)
frame #11: <unknown function> + 0x135c52 (0x5556b8a98c52 in /usr/bin/python)
frame #12: <unknown function> + 0x25b035 (0x5556b8bbe035 in /usr/bin/python)
frame #13: _PyEval_EvalFrameDefault + 0xa33b (0x5556b8aafeeb in /usr/bin/python)
frame #14: _PyFunction_Vectorcall + 0x7c (0x5556b8abd9fc in /usr/bin/python)
frame #15: PyObject_Call + 0x122 (0x5556b8acc492 in /usr/bin/python)
frame #16: _PyEval_EvalFrameDefault + 0x2a27 (0x5556b8aa85d7 in /usr/bin/python)
frame #17: <unknown function> + 0x1687f1 (0x5556b8acb7f1 in /usr/bin/python)
frame #18: _PyEval_EvalFrameDefault + 0x198c (0x5556b8aa753c in /usr/bin/python)
frame #19: _PyFunction_Vectorcall + 0x7c (0x5556b8abd9fc in /usr/bin/python)
frame #20: _PyEval_EvalFrameDefault + 0x8ac (0x5556b8aa645c in /usr/bin/python)
frame #21: _PyFunction_Vectorcall + 0x7c (0x5556b8abd9fc in /usr/bin/python)
frame #22: _PyEval_EvalFrameDefault + 0x2a27 (0x5556b8aa85d7 in /usr/bin/python)
frame #23: _PyFunction_Vectorcall + 0x7c (0x5556b8abd9fc in /usr/bin/python)
frame #24: _PyEval_EvalFrameDefault + 0x6bd (0x5556b8aa626d in /usr/bin/python)
frame #25: <unknown function> + 0x13f9c6 (0x5556b8aa29c6 in /usr/bin/python)
frame #26: PyEval_EvalCode + 0x86 (0x5556b8b98256 in /usr/bin/python)
frame #27: <unknown function> + 0x23ae2d (0x5556b8b9de2d in /usr/bin/python)
frame #28: <unknown function> + 0x15ac59 (0x5556b8abdc59 in /usr/bin/python)
frame #29: _PyEval_EvalFrameDefault + 0x6bd (0x5556b8aa626d in /usr/bin/python)
frame #30: _PyFunction_Vectorcall + 0x7c (0x5556b8abd9fc in /usr/bin/python)
frame #31: _PyEval_EvalFrameDefault + 0x6bd (0x5556b8aa626d in /usr/bin/python)
frame #32: _PyFunction_Vectorcall + 0x7c (0x5556b8abd9fc in /usr/bin/python)
frame #33: <unknown function> + 0x252c2d (0x5556b8bb5c2d in /usr/bin/python)
frame #34: Py_RunMain + 0x128 (0x5556b8bb48c8 in /usr/bin/python)
frame #35: Py_BytesMain + 0x2d (0x5556b8b8b02d in /usr/bin/python)
frame #36: <unknown function> + 0x29d90 (0x7f49927d0d90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #37: __libc_start_main + 0x80 (0x7f49927d0e40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #38: _start + 0x25 (0x5556b8b8af25 in /usr/bin/python)
```
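For completeness, a self-contained version of the failing snippet looks roughly like this; B and N are not specified above, so their values here (and the `.cuda()` placement) are illustrative assumptions:

```python
import torch
from flashfftconv import FlashFFTConv

B, N, L = 2, 8, 14113  # B and N are hypothetical; only L = 14113 is from the report

fftconv_fn = FlashFFTConv(seqlen=32768, dtype=torch.bfloat16, use_32_butterfly=True).cuda()

x = torch.randn(B, N, L, device="cuda", dtype=torch.bfloat16)  # input, shape (B, N, L)
k = torch.randn(N, L, device="cuda", dtype=torch.float32)      # per-channel kernel, shape (N, L)

# L != 32768 // 2, so the kernel must pad implicitly -> misaligned address
y = fftconv_fn(x.contiguous().to(dtype=fftconv_fn.dtype), k.float().contiguous()).to(dtype=x.dtype)
```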
After looking into the problem in a bit more detail, I found that this only happens if implicit padding is required, i.e., if `L != 32768 // 2`. If I instead run this as:
```python
x = torch.nn.functional.pad(x, (0, _next_power_of_two(L) - L))
k = torch.nn.functional.pad(k, (0, _next_power_of_two(L) - L))
y = fftconv_fn(x.contiguous().to(dtype=fftconv_fn.dtype), k.float()).to(dtype=x.dtype)[..., :L]
```
The problem does not appear and everything works fine.
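The `_next_power_of_two` helper above is not shown in the snippet; a minimal sketch of what it is assumed to do (pad L = 14113 up to 16384 = 32768 // 2, the length the module expects):

```python
def _next_power_of_two(n: int) -> int:
    # Smallest power of two >= n, e.g. 14113 -> 16384.
    return 1 << (n - 1).bit_length()
```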
My intuition says that this has something to do with the fact that implicitly padded elements might not exist in memory, leading to the error shown above. However, padding explicitly as in the workaround is presumably much more expensive in memory and compute.
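One plausible mechanism, stated purely as an assumption since I have not read the kernel at monarch_cuda_interface_fwd_bf16.cu: if the CUDA kernel uses 128-bit vectorized loads, an unpadded odd length like L = 14113 puts most row base addresses off 16-byte boundaries:

```python
L = 14113          # unpadded length from the report
ELEM_BYTES = 2     # bytes per bfloat16 element
VEC_BYTES = 16     # width of a hypothetical 128-bit vectorized load

# Row i of a contiguous (B, N, L) bf16 tensor starts at byte offset i * L * ELEM_BYTES.
# With an odd L, that offset is 16-byte aligned only when i is a multiple of 8, so
# vectorized loads from most rows would fault with "misaligned address".
for i in range(4):
    offset = i * L * ELEM_BYTES
    print(f"row {i}: offset {offset} B, 16-byte aligned: {offset % VEC_BYTES == 0}")
```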
It also seems that, although I have not debugged this in detail, this error happens only for FlashFFTConv(seqlen=32768). More experimentation is needed to pinpoint this, though.
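A rough sketch of such an experiment, assuming the same API as above; note that a fatal CUDA error poisons the context, so for reliable results each combination should really run in its own process:

```python
import torch
from flashfftconv import FlashFFTConv

# Sweep (seqlen, L) combinations to see which ones trigger the error.
for seqlen in (8192, 16384, 32768):
    conv = FlashFFTConv(seqlen=seqlen, dtype=torch.bfloat16).cuda()
    for L in (seqlen // 2, seqlen // 2 - 1, seqlen // 2 - 255):
        x = torch.randn(1, 4, L, device="cuda", dtype=torch.bfloat16)
        k = torch.randn(4, L, device="cuda", dtype=torch.float32)
        try:
            conv(x.contiguous(), k.contiguous())
            torch.cuda.synchronize()  # surface asynchronous kernel errors here
            print(f"seqlen={seqlen}, L={L}: ok")
        except RuntimeError as err:
            print(f"seqlen={seqlen}, L={L}: {err}")
```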
I hope this insight might help make the package more robust / efficient :)
Thank you!
Best,
David