diff --git a/fbgemm_gpu/fbgemm_gpu/sll/__init__.py b/fbgemm_gpu/fbgemm_gpu/sll/__init__.py
index d79de9a23..271f5fa31 100644
--- a/fbgemm_gpu/fbgemm_gpu/sll/__init__.py
+++ b/fbgemm_gpu/fbgemm_gpu/sll/__init__.py
@@ -34,7 +34,6 @@
 )
 
 from fbgemm_gpu.sll.triton_sll import (  # noqa F401
-    jagged2_to_padded_dense,
     jagged_dense_elementwise_mul_jagged_out,
     triton_jagged_self_substraction_jagged_out,
 )
@@ -268,10 +267,6 @@
     "sll_jagged_self_substraction_jagged_out": {
         "CUDA": triton_jagged_self_substraction_jagged_out,
     },
-    "sll_jagged2_to_padded_dense": {
-        "CUDA": jagged2_to_padded_dense,
-        "AutogradCUDA": jagged2_to_padded_dense,
-    },
     "sll_jagged_dense_elementwise_mul_jagged_out": {
         "CUDA": jagged_dense_elementwise_mul_jagged_out,
         "AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
diff --git a/fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py b/fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py
index 20dd6dc3e..ffac62966 100644
--- a/fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py
+++ b/fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py
@@ -11,6 +11,11 @@
     dense_jagged_cat_jagged_out,
 )
 
+from fbgemm_gpu.sll.triton.triton_jagged2_to_padded_dense import (  # noqa F401
+    jagged2_to_padded_dense,
+    Jagged2ToPaddedDense,  # noqa F401
+)
+
 from fbgemm_gpu.sll.triton.triton_jagged_bmm import (  # noqa F401
     jagged_dense_bmm,
     jagged_jagged_bmm,
 )
@@ -66,6 +71,10 @@
         "CUDA": jagged_jagged_bmm,
         "AutogradCUDA": jagged_jagged_bmm,
     },
+    "sll_jagged2_to_padded_dense": {
+        "CUDA": jagged2_to_padded_dense,
+        "AutogradCUDA": jagged2_to_padded_dense,
+    },
     "sll_array_jagged_bmm_jagged_out": {
         "CUDA": array_jagged_bmm_jagged_out,
         "AutogradCUDA": array_jagged_bmm_jagged_out,
     },
diff --git a/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py
new file mode 100644
index 000000000..dfeabbce3
--- /dev/null
+++ b/fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py
@@ -0,0 +1,222 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from .common import expect_contiguous
+
+
+@triton.jit
+def jagged2_to_padded_dense_kernel(
+    x_ptr,
+    lengths_ptr,
+    offsets_ptr,
+    output_dense_ptr,
+    stride_b,
+    stride_m,
+    stride_n,
+    max_length,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    pid_batch = tl.program_id(2)
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    begin = tl.load(offsets_ptr + pid_batch)
+    seqlen = tl.load(lengths_ptr + pid_batch)
+
+    seqlen = tl.minimum(seqlen, max_length)
+    if seqlen == 0:
+        return
+
+    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    x_ptrs = x_ptr + begin + offs_m[:, None] * seqlen + offs_n[None, :]
+    x = tl.load(x_ptrs, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen)))
+
+    out_ptrs = (
+        output_dense_ptr
+        + pid_batch * stride_b
+        + offs_m[:, None] * stride_m
+        + offs_n[None, :] * stride_n
+    )
+    tl.store(
+        out_ptrs, x, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen))
+    )
+
+
+@triton.jit
+def padded_dense_to_jagged2_kernel(
+    x_ptr,
+    lengths_ptr,
+    offsets_ptr,
+    output_jagged_ptr,
+    stride_b,
+    stride_m,
+    stride_n,
+    max_length,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    pid_batch = tl.program_id(2)
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    begin = tl.load(offsets_ptr + pid_batch)
+    # end = tl.load(offsets_ptr + pid_batch + 1)
+    seqlen = tl.load(lengths_ptr + pid_batch)
+
+    seqlen = tl.minimum(seqlen, max_length)
+
+    if seqlen == 0:
+        return
+
+    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    x_ptrs = (
+        x_ptr
+        + pid_batch * stride_b
+        + offs_m[:, None] * stride_m
+        + offs_n[None, :] * stride_n
+    )
+    x = tl.load(x_ptrs, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen)))
+    out_ptrs = output_jagged_ptr + begin + offs_m[:, None] * seqlen + offs_n[None, :]
+    tl.store(
+        out_ptrs, x, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen))
+    )
+
+
+def jagged2_to_padded_dense_fwd(
+    values: torch.Tensor,
+    lengths: torch.Tensor,
+    offsets: torch.Tensor,
+    max_length: int,
+    padding_value: float,
+) -> torch.Tensor:
+    B = offsets.size(0) - 1
+
+    output_dense = torch.full(
+        (B, max_length, max_length),
+        padding_value,
+        dtype=values.dtype,
+        device=values.device,
+    )
+    BLOCK_M = 32
+    BLOCK_N = 32
+    num_blocks_m = triton.cdiv(max_length, BLOCK_M)
+    num_blocks_n = triton.cdiv(max_length, BLOCK_N)
+    grid = (num_blocks_m, num_blocks_n, B)
+
+    jagged2_to_padded_dense_kernel[grid](
+        values,
+        lengths,
+        offsets,
+        output_dense,
+        output_dense.stride(0),
+        output_dense.stride(1),
+        output_dense.stride(2),
+        max_length,
+        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
+        BLOCK_M,
+        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
+        BLOCK_N,
+    )
+
+    return output_dense
+
+
+def padded_dense_to_jagged2_fwd(
+    values: torch.Tensor,
+    lengths: torch.Tensor,
+    offsets: torch.Tensor,
+    max_length: int,
+) -> torch.Tensor:
+    B = values.size(0)
+    output_jagged = torch.empty(
+        int(offsets[-1]), dtype=values.dtype, device=values.device
+    )
+    BLOCK_M = 32
+    BLOCK_N = 32
+    num_blocks_m = triton.cdiv(max_length, BLOCK_M)
+    num_blocks_n = triton.cdiv(max_length, BLOCK_N)
+    grid = (num_blocks_m, num_blocks_n, B)
+
+    padded_dense_to_jagged2_kernel[grid](
+        values,
+        lengths,
+        offsets,
+        output_jagged,
+        values.stride(0),
+        values.stride(1),
+        values.stride(2),
+        max_length,
+        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
+        BLOCK_M,
+        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
+        BLOCK_N,
+    )
+
+    return output_jagged
+
+
+class Jagged2ToPaddedDense(torch.autograd.Function):
+    @staticmethod
+    # pyre-fixme
+    def forward(
+        ctx,
+        values: torch.Tensor,
+        offsets: torch.Tensor,
+        max_length: int,
+        padding_value: float,
+    ) -> torch.Tensor:
+        lengths_square = offsets[1:] - offsets[0:-1:1]
+        lengths = torch.sqrt(lengths_square).to(torch.int32)
+
+        ctx.max_length = max_length
+        ctx.save_for_backward(lengths, offsets)
+
+        output = jagged2_to_padded_dense_fwd(
+            values, lengths, offsets, max_length, padding_value
+        )
+        return output
+
+    @staticmethod
+    # pyre-fixme
+    def backward(
+        ctx, grad_output: torch.Tensor
+    ) -> Tuple[torch.Tensor, None, None, None]:
+        max_length = ctx.max_length
+        (lengths, offsets) = ctx.saved_tensors
+        grad_in = padded_dense_to_jagged2_fwd(grad_output, lengths, offsets, max_length)
+        return (grad_in, None, None, None)
+
+
+def jagged2_to_padded_dense(
+    values: torch.Tensor,
+    offsets: torch.Tensor,
+    max_length: int,
+    padding_value: float = 0.0,
+) -> torch.Tensor:
+    """
+    values: jagged tensor with size [sum(Ni * Ni)]
+    offsets: offsets for jagged tensor, with size [B + 1]
+    max_length: maximum sequence length in the batch
+    padding_value: value to use for padding
+    return padded dense tensor of size [B, N, N]
+    """
+    values = expect_contiguous(values)
+    offsets = expect_contiguous(offsets)
+
+    return Jagged2ToPaddedDense.apply(values, offsets, max_length, padding_value)
diff --git a/fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py b/fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py
index e4e2dbc5f..a014c2dfb 100644
--- a/fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py
+++ b/fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py
@@ -6,8 +6,6 @@
 
 # pyre-unsafe
 
-from typing import Tuple
-
 import torch
 import triton
 import triton.language as tl
@@ -35,13 +33,6 @@ def next_power_of_two(N: int) -> int:
         return 32
 
 
-def expect_contiguous(x: torch.Tensor) -> torch.Tensor:
-    if not x.is_contiguous():
-        return x.contiguous()
-    else:
-        return x
-
-
 @triton.jit
 def jagged_self_substraction_jagged_out_kernel(
     a_ptr,  # jagged
@@ -75,89 +66,6 @@ def jagged_self_substraction_jagged_out_kernel(
     tl.store(b_ptr + b_offset + pid_index * N + offs, b, mask=mask)
 
 
-@triton.jit
-def jagged2_to_padded_dense_kernel(
-    x_ptr,
-    lengths_ptr,
-    offsets_ptr,
-    output_dense_ptr,
-    stride_b,
-    stride_m,
-    stride_n,
-    max_length,
-    BLOCK_M: tl.constexpr,
-    BLOCK_N: tl.constexpr,
-):
-    pid_batch = tl.program_id(2)
-    pid_m = tl.program_id(0)
-    pid_n = tl.program_id(1)
-
-    begin = tl.load(offsets_ptr + pid_batch)
-    seqlen = tl.load(lengths_ptr + pid_batch)
-
-    seqlen = tl.minimum(seqlen, max_length)
-    if seqlen == 0:
-        return
-
-    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-
-    x_ptrs = x_ptr + begin + offs_m[:, None] * seqlen + offs_n[None, :]
-    x = tl.load(x_ptrs, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen)))
-
-    out_ptrs = (
-        output_dense_ptr
-        + pid_batch * stride_b
-        + offs_m[:, None] * stride_m
-        + offs_n[None, :] * stride_n
-    )
-    tl.store(
-        out_ptrs, x, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen))
-    )
-
-
-@triton.jit
-def padded_dense_to_jagged2_kernel(
-    x_ptr,
-    lengths_ptr,
-    offsets_ptr,
-    output_jagged_ptr,
-    stride_b,
-    stride_m,
-    stride_n,
-    max_length,
-    BLOCK_M: tl.constexpr,
-    BLOCK_N: tl.constexpr,
-):
-    pid_batch = tl.program_id(2)
-    pid_m = tl.program_id(0)
-    pid_n = tl.program_id(1)
-
-    begin = tl.load(offsets_ptr + pid_batch)
-    # end = tl.load(offsets_ptr + pid_batch + 1)
-    seqlen = tl.load(lengths_ptr + pid_batch)
-
-    seqlen = tl.minimum(seqlen, max_length)
-
-    if seqlen == 0:
-        return
-
-    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-
-    x_ptrs = (
-        x_ptr
-        + pid_batch * stride_b
-        + offs_m[:, None] * stride_m
-        + offs_n[None, :] * stride_n
-    )
-    x = tl.load(x_ptrs, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen)))
-    out_ptrs = output_jagged_ptr + begin + offs_m[:, None] * seqlen + offs_n[None, :]
-    tl.store(
-        out_ptrs, x, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen))
-    )
-
-
 @triton.jit
 def jagged_dense_elementwise_mul_jagged_out_kernel(
     a_ptr,  # 1d jagged
@@ -242,79 +150,6 @@ def triton_jagged_self_substraction_jagged_out(
     return jagged_B
 
 
-def jagged2_to_padded_dense_fwd(
-    values: torch.Tensor,
-    lengths: torch.Tensor,
-    offsets: torch.Tensor,
-    max_length: int,
-    padding_value: float,
-) -> torch.Tensor:
-    B = offsets.size(0) - 1
-
-    output_dense = torch.full(
-        (B, max_length, max_length),
-        padding_value,
-        dtype=values.dtype,
-        device=values.device,
-    )
-    BLOCK_M = 32
-    BLOCK_N = 32
-    num_blocks_m = triton.cdiv(max_length, BLOCK_M)
-    num_blocks_n = triton.cdiv(max_length, BLOCK_N)
-    grid = (num_blocks_m, num_blocks_n, B)
-
-    jagged2_to_padded_dense_kernel[grid](
-        values,
-        lengths,
-        offsets,
-        output_dense,
-        output_dense.stride(0),
-        output_dense.stride(1),
-        output_dense.stride(2),
-        max_length,
-        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
-        BLOCK_M,
-        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
-        BLOCK_N,
-    )
-
-    return output_dense
-
-
-def padded_dense_to_jagged2_fwd(
-    values: torch.Tensor,
-    lengths: torch.Tensor,
-    offsets: torch.Tensor,
-    max_length: int,
-) -> torch.Tensor:
-    B = values.size(0)
-    output_jagged = torch.empty(
-        int(offsets[-1]), dtype=values.dtype, device=values.device
-    )
-    BLOCK_M = 32
-    BLOCK_N = 32
-    num_blocks_m = triton.cdiv(max_length, BLOCK_M)
-    num_blocks_n = triton.cdiv(max_length, BLOCK_N)
-    grid = (num_blocks_m, num_blocks_n, B)
-
-    padded_dense_to_jagged2_kernel[grid](
-        values,
-        lengths,
-        offsets,
-        output_jagged,
-        values.stride(0),
-        values.stride(1),
-        values.stride(2),
-        max_length,
-        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
-        BLOCK_M,
-        # pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
-        BLOCK_N,
-    )
-
-    return output_jagged
-
-
 def triton_jagged_dense_elementwise_mul_jagged_out(
     jagged_A,
     dense_B,
@@ -349,38 +184,6 @@ def triton_jagged_dense_elementwise_mul_jagged_out(
     return jagged_C
 
 
-class Jagged2ToPaddedDense(torch.autograd.Function):
-    @staticmethod
-    # pyre-fixme
-    def forward(
-        ctx,
-        values: torch.Tensor,
-        offsets: torch.Tensor,
-        max_length: int,
-        padding_value: float,
-    ) -> torch.Tensor:
-        lengths_square = offsets[1:] - offsets[0:-1:1]
-        lengths = torch.sqrt(lengths_square).to(torch.int32)
-
-        ctx.max_length = max_length
-        ctx.save_for_backward(lengths, offsets)
-
-        output = jagged2_to_padded_dense_fwd(
-            values, lengths, offsets, max_length, padding_value
-        )
-        return output
-
-    @staticmethod
-    # pyre-fixme
-    def backward(
-        ctx, grad_output: torch.Tensor
-    ) -> Tuple[torch.Tensor, None, None, None]:
-        max_length = ctx.max_length
-        (lengths, offsets) = ctx.saved_tensors
-        grad_in = padded_dense_to_jagged2_fwd(grad_output, lengths, offsets, max_length)
-        return (grad_in, None, None, None)
-
-
 class JaggedDenseElementwiseMul(torch.autograd.Function):
     """
     Compute elementwise multiplication between jagged tensor and dense tensor.
@@ -438,25 +241,6 @@ def backward(ctx, grad_output: torch.Tensor):
         return grad_x, None, None, None, None
 
 
-def jagged2_to_padded_dense(
-    values: torch.Tensor,
-    offsets: torch.Tensor,
-    max_length: int,
-    padding_value: float = 0.0,
-) -> torch.Tensor:
-    """
-    values: jagged tensor with size [sum(Ni * Ni)]
-    offsets: offsets for jagged tensor, with size [B + 1]
-    max_length: maximum sequence length in the batch
-    padding_value: value to use for padding
-    return padded dense tensor of size [B, N, N]
-    """
-    values = expect_contiguous(values)
-    offsets = expect_contiguous(offsets)
-
-    return Jagged2ToPaddedDense.apply(values, offsets, max_length, padding_value)
-
-
 def jagged_dense_elementwise_mul_jagged_out(
     x: torch.Tensor,
     y: torch.Tensor,
diff --git a/fbgemm_gpu/test/sll/jagged2_to_padded_dense_test.py b/fbgemm_gpu/test/sll/jagged2_to_padded_dense_test.py
new file mode 100644
index 000000000..e4aae5a3f
--- /dev/null
+++ b/fbgemm_gpu/test/sll/jagged2_to_padded_dense_test.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import unittest
+
+import fbgemm_gpu.sll  # noqa F401
+import torch
+from hypothesis import given, settings, strategies as st
+
+from .common import open_source  # noqa
+
+if open_source:
+    # pyre-ignore[21]
+    from test_utils import gpu_unavailable, running_on_rocm
+else:
+    from fbgemm_gpu.test.test_utils import gpu_unavailable, running_on_rocm
+
+
+class Jagged2ToPaddedDenseTest(unittest.TestCase):
+    @unittest.skipIf(*gpu_unavailable)
+    @unittest.skipIf(*running_on_rocm)
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @given(
+        B=st.integers(1, 10),
+        max_L=st.integers(1, 100),
+        device_type=st.sampled_from(["cpu", "cuda"]),
+    )
+    @settings(deadline=None)
+    def test_jagged2_to_padded_dense(
+        self,
+        B: int,
+        max_L: int,
+        device_type: str,
+    ) -> None:
+        device = torch.device(device_type)
+
+        lengths = torch.randint(1, max_L + 1, (B,), device=device)
+        lengths_square = lengths * lengths
+        offsets = torch.cat(
+            [
+                torch.tensor([0], device=device, dtype=torch.int),
+                lengths_square.cumsum(dim=0),
+            ],
+            dim=0,
+        )
+
+        x = torch.rand(
+            int(lengths_square.sum().item()),
+            requires_grad=True,
+            device=device,
+        )
+
+        def ref_jagged2_to_padded_dense(
+            x: torch.Tensor, offsets: torch.Tensor, max_L: int, padding_value: float
+        ) -> torch.Tensor:
+            B = offsets.size(0) - 1
+            dense_output = torch.full(
+                (B, max_L, max_L),
+                padding_value,
+                dtype=x.dtype,
+                device=x.device,
+            )
+            for b in range(B):
+                begin = offsets[b]
+                end = offsets[b + 1]
+                Ni = int(torch.sqrt(end - begin))
+                if Ni == 0:
+                    continue
+                dense_output[b, 0:Ni, 0:Ni] = x[begin:end].view(Ni, Ni)
+
+            return dense_output
+
+        x_clone = (
+            x.detach().clone().requires_grad_()
+            if x.requires_grad
+            else x.detach().clone()
+        )
+        padding_value = 0.0
+        ref_out = ref_jagged2_to_padded_dense(x, offsets, max_L, padding_value)
+        test_out = torch.ops.fbgemm.sll_jagged2_to_padded_dense(
+            x_clone, offsets, max_L, padding_value
+        )
+        assert torch.allclose(ref_out, test_out)
+
+        # Backward pass
+        dout = torch.rand((B, max_L, max_L), dtype=x.dtype, device=x.device) * 0.1
+        test_out.backward(dout)
+        ref_out.backward(dout)
+
+        assert x.grad is not None
+        assert x_clone.grad is not None
+        assert torch.allclose(x.grad, x_clone.grad)
diff --git a/fbgemm_gpu/test/sll/triton_sll_test.py b/fbgemm_gpu/test/sll/triton_sll_test.py
index d9900243f..1b0b1c461 100644
--- a/fbgemm_gpu/test/sll/triton_sll_test.py
+++ b/fbgemm_gpu/test/sll/triton_sll_test.py
@@ -149,77 +149,3 @@ def model(
 
             ref = a[:-1].unsqueeze(1) - a[1:].unsqueeze(0)
             assert torch.equal(result[offsets_b[i] : offsets_b[i + 1]], ref.flatten())
-
-    @unittest.skipIf(*gpu_unavailable)
-    @unittest.skipIf(*running_on_rocm)
-    # pyre-fixme[56]: Pyre was not able to infer the type of argument
-    @given(
-        B=st.integers(1, 10),
-        max_L=st.integers(1, 100),
-        device_type=st.sampled_from(["cpu", "cuda"]),
-    )
-    @settings(deadline=None)
-    def test_jagged2_to_padded_dense(
-        self,
-        B: int,
-        max_L: int,
-        device_type: str,
-    ) -> None:
-        device = torch.device(device_type)
-
-        lengths = torch.randint(1, max_L + 1, (B,), device=device)
-        lengths_square = lengths * lengths
-        offsets = torch.cat(
-            [
-                torch.tensor([0], device=device, dtype=torch.int),
-                lengths_square.cumsum(dim=0),
-            ],
-            dim=0,
-        )
-
-        x = torch.rand(
-            int(lengths_square.sum().item()),
-            requires_grad=True,
-            device=device,
-        )
-
-        def ref_jagged2_to_padded_dense(
-            x: torch.Tensor, offsets: torch.Tensor, max_L: int, padding_value: float
-        ) -> torch.Tensor:
-            B = offsets.size(0) - 1
-            dense_output = torch.full(
-                (B, max_L, max_L),
-                padding_value,
-                dtype=x.dtype,
-                device=x.device,
-            )
-            for b in range(B):
-                begin = offsets[b]
-                end = offsets[b + 1]
-                Ni = int(torch.sqrt(end - begin))
-                if Ni == 0:
-                    continue
-                dense_output[b, 0:Ni, 0:Ni] = x[begin:end].view(Ni, Ni)
-
-            return dense_output
-
-        x_clone = (
-            x.detach().clone().requires_grad_()
-            if x.requires_grad
-            else x.detach().clone()
-        )
-        padding_value = 0.0
-        ref_out = ref_jagged2_to_padded_dense(x, offsets, max_L, padding_value)
-        test_out = torch.ops.fbgemm.sll_jagged2_to_padded_dense(
-            x_clone, offsets, max_L, padding_value
-        )
-        assert torch.allclose(ref_out, test_out)
-
-        # Backward pass
-        dout = torch.rand((B, max_L, max_L), dtype=x.dtype, device=x.device) * 0.1
-        test_out.backward(dout)
-        ref_out.backward(dout)
-
-        assert x.grad is not None
-        assert x_clone.grad is not None
-        assert torch.allclose(x.grad, x_clone.grad)