diff --git a/examples/19_large_depthwise_conv2d_torch_extension/depthwise_conv2d_implicit_gemm.py b/examples/19_large_depthwise_conv2d_torch_extension/depthwise_conv2d_implicit_gemm.py index b4a1aa1c8a..f8c98b3344 100755 --- a/examples/19_large_depthwise_conv2d_torch_extension/depthwise_conv2d_implicit_gemm.py +++ b/examples/19_large_depthwise_conv2d_torch_extension/depthwise_conv2d_implicit_gemm.py @@ -21,6 +21,7 @@ def forward(ctx, x, w): @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, grad): + grad = grad.contiguous() x, w = ctx.saved_tensors dx = _extension.backward_data_fp32(grad, w) dw = _extension.backward_filter_fp32(grad, x, w) @@ -37,6 +38,7 @@ def forward(ctx, x, w): @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, grad): + grad = grad.contiguous() x, w = ctx.saved_tensors dx = _extension.backward_data_fp16(grad, w) dw = _extension.backward_filter_fp16(grad, x, w)