Re-organize SLL ops, pt 7 (#3650)
Summary:
Pull Request resolved: #3650

X-link: facebookresearch/FBGEMM#725

- Re-organize `jagged2_to_padded_dense`

Differential Revision: D68967316
q10 authored and facebook-github-bot committed Feb 2, 2025
1 parent 995e6a4 commit ccda7aa
Showing 6 changed files with 327 additions and 295 deletions.
5 changes: 0 additions & 5 deletions fbgemm_gpu/fbgemm_gpu/sll/__init__.py
@@ -34,7 +34,6 @@
)

from fbgemm_gpu.sll.triton_sll import ( # noqa F401
jagged2_to_padded_dense,
jagged_dense_elementwise_mul_jagged_out,
triton_jagged_self_substraction_jagged_out,
)
@@ -268,10 +267,6 @@
"sll_jagged_self_substraction_jagged_out": {
"CUDA": triton_jagged_self_substraction_jagged_out,
},
"sll_jagged2_to_padded_dense": {
"CUDA": jagged2_to_padded_dense,
"AutogradCUDA": jagged2_to_padded_dense,
},
"sll_jagged_dense_elementwise_mul_jagged_out": {
"CUDA": jagged_dense_elementwise_mul_jagged_out,
"AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
9 changes: 9 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py
@@ -11,6 +11,11 @@
dense_jagged_cat_jagged_out,
)

from fbgemm_gpu.sll.triton.triton_jagged2_to_padded_dense import ( # noqa F401
jagged2_to_padded_dense,
Jagged2ToPaddedDense, # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_bmm import ( # noqa F401
jagged_dense_bmm,
jagged_jagged_bmm,
@@ -66,6 +71,10 @@
"CUDA": jagged_jagged_bmm,
"AutogradCUDA": jagged_jagged_bmm,
},
"sll_jagged2_to_padded_dense": {
"CUDA": jagged2_to_padded_dense,
"AutogradCUDA": jagged2_to_padded_dense,
},
"sll_array_jagged_bmm_jagged_out": {
"CUDA": array_jagged_bmm_jagged_out,
"AutogradCUDA": array_jagged_bmm_jagged_out,
222 changes: 222 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py
@@ -0,0 +1,222 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Tuple

import torch
import triton
import triton.language as tl

from .common import expect_contiguous


@triton.jit
def jagged2_to_padded_dense_kernel(
x_ptr,
lengths_ptr,
offsets_ptr,
output_dense_ptr,
stride_b,
stride_m,
stride_n,
max_length,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
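    # Copy each batch's seqlen x seqlen jagged block (stored contiguously,
    # row-major, starting at offsets[batch]) into its padded dense slice.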
pid_batch = tl.program_id(2)
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)

begin = tl.load(offsets_ptr + pid_batch)
seqlen = tl.load(lengths_ptr + pid_batch)

seqlen = tl.minimum(seqlen, max_length)
if seqlen == 0:
return

offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

x_ptrs = x_ptr + begin + offs_m[:, None] * seqlen + offs_n[None, :]
x = tl.load(x_ptrs, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen)))

out_ptrs = (
output_dense_ptr
+ pid_batch * stride_b
+ offs_m[:, None] * stride_m
+ offs_n[None, :] * stride_n
)
tl.store(
out_ptrs, x, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen))
)


@triton.jit
def padded_dense_to_jagged2_kernel(
x_ptr,
lengths_ptr,
offsets_ptr,
output_jagged_ptr,
stride_b,
stride_m,
stride_n,
max_length,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
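    # Inverse of the kernel above: copy each batch's padded dense slice back
    # into its contiguous seqlen x seqlen jagged block.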
pid_batch = tl.program_id(2)
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)

begin = tl.load(offsets_ptr + pid_batch)
# end = tl.load(offsets_ptr + pid_batch + 1)
seqlen = tl.load(lengths_ptr + pid_batch)

seqlen = tl.minimum(seqlen, max_length)

if seqlen == 0:
return

offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

x_ptrs = (
x_ptr
+ pid_batch * stride_b
+ offs_m[:, None] * stride_m
+ offs_n[None, :] * stride_n
)
x = tl.load(x_ptrs, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen)))
out_ptrs = output_jagged_ptr + begin + offs_m[:, None] * seqlen + offs_n[None, :]
tl.store(
out_ptrs, x, mask=((offs_m[:, None] < seqlen) & (offs_n[None, :] < seqlen))
)


def jagged2_to_padded_dense_fwd(
values: torch.Tensor,
lengths: torch.Tensor,
offsets: torch.Tensor,
max_length: int,
padding_value: float,
) -> torch.Tensor:
B = offsets.size(0) - 1

output_dense = torch.full(
(B, max_length, max_length),
padding_value,
dtype=values.dtype,
device=values.device,
)
BLOCK_M = 32
BLOCK_N = 32
num_blocks_m = triton.cdiv(max_length, BLOCK_M)
num_blocks_n = triton.cdiv(max_length, BLOCK_N)
grid = (num_blocks_m, num_blocks_n, B)
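    # One program instance per (BLOCK_M, BLOCK_N) tile of the padded
    # max_length x max_length output, for each of the B batches.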

jagged2_to_padded_dense_kernel[grid](
values,
lengths,
offsets,
output_dense,
output_dense.stride(0),
output_dense.stride(1),
output_dense.stride(2),
max_length,
# pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
BLOCK_M,
# pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
BLOCK_N,
)

return output_dense


def padded_dense_to_jagged2_fwd(
values: torch.Tensor,
lengths: torch.Tensor,
offsets: torch.Tensor,
max_length: int,
) -> torch.Tensor:
B = values.size(0)
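    # offsets[-1] == sum(Ni * Ni), the total number of jagged output values.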
output_jagged = torch.empty(
int(offsets[-1]), dtype=values.dtype, device=values.device
)
BLOCK_M = 32
BLOCK_N = 32
num_blocks_m = triton.cdiv(max_length, BLOCK_M)
num_blocks_n = triton.cdiv(max_length, BLOCK_N)
grid = (num_blocks_m, num_blocks_n, B)

padded_dense_to_jagged2_kernel[grid](
values,
lengths,
offsets,
output_jagged,
values.stride(0),
values.stride(1),
values.stride(2),
max_length,
# pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
BLOCK_M,
# pyre-fixme[6]: Incompatible parameter type [6]: expected `constexpr` but got `int`.
BLOCK_N,
)

return output_jagged


class Jagged2ToPaddedDense(torch.autograd.Function):
@staticmethod
# pyre-fixme
def forward(
ctx,
values: torch.Tensor,
offsets: torch.Tensor,
max_length: int,
padding_value: float,
) -> torch.Tensor:
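        # Each batch stores Ni * Ni values, so consecutive offset differences
        # give Ni^2; the square root recovers the per-batch side length Ni.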
lengths_square = offsets[1:] - offsets[0:-1:1]
lengths = torch.sqrt(lengths_square).to(torch.int32)

ctx.max_length = max_length
ctx.save_for_backward(lengths, offsets)

output = jagged2_to_padded_dense_fwd(
values, lengths, offsets, max_length, padding_value
)
return output

@staticmethod
# pyre-fixme
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[torch.Tensor, None, None, None]:
max_length = ctx.max_length
(lengths, offsets) = ctx.saved_tensors
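        # The backward of padding is a gather: copy the dense gradient back
        # into the jagged layout, discarding the padded positions.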
grad_in = padded_dense_to_jagged2_fwd(grad_output, lengths, offsets, max_length)
return (grad_in, None, None, None)


def jagged2_to_padded_dense(
values: torch.Tensor,
offsets: torch.Tensor,
max_length: int,
padding_value: float = 0.0,
) -> torch.Tensor:
"""
values: jagged tensor with size [sum(Ni * Ni)]
offsets: offsets for jagged tensor, with size [B + 1]
max_length: maximum sequence length in the batch
padding_value: value to use for padding
return padded dense tensor of size [B, N, N]
"""
values = expect_contiguous(values)
offsets = expect_contiguous(offsets)

return Jagged2ToPaddedDense.apply(values, offsets, max_length, padding_value)
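
A minimal usage sketch for the relocated op, calling the Python wrapper directly. It assumes a CUDA device with Triton available; the batch sizes and values below are illustrative, not taken from the source.

import torch

from fbgemm_gpu.sll.triton import jagged2_to_padded_dense

# Two batches with square jagged blocks of side N0 = 2 and N1 = 3,
# stored back-to-back as sum(Ni * Ni) = 4 + 9 = 13 values.
values = torch.arange(13, dtype=torch.float32, device="cuda")
offsets = torch.tensor([0, 4, 13], dtype=torch.int32, device="cuda")

# Pad each Ni x Ni block into a dense [B, max_length, max_length] tensor.
out = jagged2_to_padded_dense(values, offsets, max_length=3, padding_value=0.0)
print(out.shape)  # torch.Size([2, 3, 3])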