Re-organize SLL ops, pt 6
Summary:
X-link: facebookresearch/FBGEMM#722

- Re-organize `dense_jagged_cat_jagged_out`

Differential Revision: D68936183
q10 authored and facebook-github-bot committed Jan 31, 2025
1 parent 94be803 commit cca506b
Showing 4 changed files with 83 additions and 66 deletions.
9 changes: 4 additions & 5 deletions fbgemm_gpu/fbgemm_gpu/sll/__init__.py
@@ -35,9 +35,8 @@

from fbgemm_gpu.sll.triton_sll import (  # noqa F401
    array_jagged_bmm_jagged_out,
    dense_jagged_cat_jagged_out,
    # dense_jagged_cat_jagged_out,
    jagged2_to_padded_dense,
    # jagged_dense_bmm,
    jagged_dense_elementwise_mul_jagged_out,
    jagged_jagged_bmm_jagged_out,
    triton_jagged_self_substraction_jagged_out,
@@ -269,9 +268,9 @@

# pyre-ignore[5]
sll_gpu_registrations = {
    "sll_dense_jagged_cat_jagged_out": {
        "CUDA": dense_jagged_cat_jagged_out,
    },
    # "sll_dense_jagged_cat_jagged_out": {
    #     "CUDA": dense_jagged_cat_jagged_out,
    # },
    "sll_jagged_self_substraction_jagged_out": {
        "CUDA": triton_jagged_self_substraction_jagged_out,
    },
7 changes: 7 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py
@@ -7,6 +7,10 @@

# pyre-strict

from fbgemm_gpu.sll.triton.triton_dense_jagged_cat_jagged_out import (
    dense_jagged_cat_jagged_out,
)

from fbgemm_gpu.sll.triton.triton_jagged_bmm import (  # noqa F401
    jagged_dense_bmm,
    jagged_jagged_bmm,
@@ -43,6 +47,9 @@

# pyre-ignore[5]
op_registrations = {
    "sll_dense_jagged_cat_jagged_out": {
        "CUDA": dense_jagged_cat_jagged_out,
    },
    "sll_jagged_dense_bmm": {
        "CUDA": jagged_dense_bmm,
        "AutogradCUDA": jagged_dense_bmm,
72 changes: 72 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
import triton
import triton.language as tl


@triton.jit
def dense_jagged_cat_jagged_out_kernel(
    a_ptr,  # dense
    b_ptr,  # jagged
    c_ptr,  # jagged
    b_offsets_ptr,
    c_offsets_ptr,
    max_seq_len,
    BLOCK_SIZE: tl.constexpr,
):
    pid_batch = tl.program_id(0)
    b_start = tl.load(b_offsets_ptr + pid_batch)
    b_end = tl.load(b_offsets_ptr + pid_batch + 1)
    c_start = b_start + pid_batch
    N = b_end - b_start
    N = tl.minimum(N, max_seq_len)

    a = tl.load(a_ptr + pid_batch)
    tl.store(c_ptr + c_start, a)

    offs_k = tl.arange(0, BLOCK_SIZE)
    for k in range(0, N, BLOCK_SIZE):
        b_offset = k + offs_k
        b_ptrs = b_ptr + b_start + b_offset
        b = tl.load(b_ptrs, mask=b_offset < N, other=0.0)
        tl.store(c_ptr + c_start + 1 + b_offset, b, mask=b_offset < N)
    tl.store(c_offsets_ptr + pid_batch, b_start + pid_batch)


def dense_jagged_cat_jagged_out(
    a: torch.Tensor,
    b: torch.Tensor,
    b_offsets: torch.Tensor,
    max_seq_len: int,
):
    assert a.is_contiguous()
    assert b.is_contiguous()
    assert b_offsets.is_contiguous()
    B = a.size(0)
    BLOCK_SIZE = 128
    c = torch.zeros(b.size(0) + a.size(0), dtype=a.dtype, device=a.device)
    c_offsets = torch.empty(
        b_offsets.size(0), dtype=b_offsets.dtype, device=b_offsets.device
    )  # B + 1

    dense_jagged_cat_jagged_out_kernel[(B,)](
        a,
        b,
        c,
        b_offsets,
        c_offsets,
        max_seq_len,
        # pyre-fixme[6]: For 7th argument expected `constexpr` but got `int`.
        BLOCK_SIZE,
    )

    c_offsets[-1] = b_offsets[-1] + B

    return c, c_offsets
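For context (not part of the diff): a minimal usage sketch of the relocated op. The tensor values and the `max_seq_len` choice below are illustrative, and a CUDA device with Triton installed is assumed. Each batch's dense scalar `a[i]` is written in front of that batch's jagged segment of `b`, and the returned offsets are shifted by the batch index.

```python
import torch

from fbgemm_gpu.sll.triton import dense_jagged_cat_jagged_out

# Two jagged sequences of lengths 3 and 2, plus one dense scalar per batch.
a = torch.tensor([10.0, 20.0], device="cuda")
b = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cuda")
b_offsets = torch.tensor([0, 3, 5], dtype=torch.int64, device="cuda")

c, c_offsets = dense_jagged_cat_jagged_out(a, b, b_offsets, max_seq_len=8)
# c         -> [10., 1., 2., 3., 20., 4., 5.]  (a[i] prepended to each sequence)
# c_offsets -> [0, 4, 7]                       (b_offsets[i] + i; last entry is b_offsets[-1] + B)
```

As the kernel is written, sequences longer than `max_seq_len` are truncated on copy, and the corresponding tail slots of `c` keep their zero initialization.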
61 changes: 0 additions & 61 deletions fbgemm_gpu/fbgemm_gpu/sll/triton_sll.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,35 +42,6 @@ def expect_contiguous(x: torch.Tensor) -> torch.Tensor:
    return x


@triton.jit
def dense_jagged_cat_jagged_out_kernel(
    a_ptr,  # dense
    b_ptr,  # jagged
    c_ptr,  # jagged
    b_offsets_ptr,
    c_offsets_ptr,
    max_seq_len,
    BLOCK_SIZE: tl.constexpr,
):
    pid_batch = tl.program_id(0)
    b_start = tl.load(b_offsets_ptr + pid_batch)
    b_end = tl.load(b_offsets_ptr + pid_batch + 1)
    c_start = b_start + pid_batch
    N = b_end - b_start
    N = tl.minimum(N, max_seq_len)

    a = tl.load(a_ptr + pid_batch)
    tl.store(c_ptr + c_start, a)

    offs_k = tl.arange(0, BLOCK_SIZE)
    for k in range(0, N, BLOCK_SIZE):
        b_offset = k + offs_k
        b_ptrs = b_ptr + b_start + b_offset
        b = tl.load(b_ptrs, mask=b_offset < N, other=0.0)
        tl.store(c_ptr + c_start + 1 + b_offset, b, mask=b_offset < N)
    tl.store(c_offsets_ptr + pid_batch, b_start + pid_batch)


@triton.jit
def jagged_self_substraction_jagged_out_kernel(
    a_ptr,  # jagged
@@ -423,38 +394,6 @@ def jagged_jagged_bmm_jagged_out_kernel(
    tl.store(c_ptrs, c, mask=c_mask)


def dense_jagged_cat_jagged_out(
    a: torch.Tensor,
    b: torch.Tensor,
    b_offsets: torch.Tensor,
    max_seq_len: int,
):
    assert a.is_contiguous()
    assert b.is_contiguous()
    assert b_offsets.is_contiguous()
    B = a.size(0)
    BLOCK_SIZE = 128
    c = torch.zeros(b.size(0) + a.size(0), dtype=a.dtype, device=a.device)
    c_offsets = torch.empty(
        b_offsets.size(0), dtype=b_offsets.dtype, device=b_offsets.device
    )  # B + 1

    dense_jagged_cat_jagged_out_kernel[(B,)](
        a,
        b,
        c,
        b_offsets,
        c_offsets,
        max_seq_len,
        # pyre-fixme[6]: For 7th argument expected `constexpr` but got `int`.
        BLOCK_SIZE,
    )

    c_offsets[-1] = b_offsets[-1] + B

    return c, c_offsets


def triton_jagged_self_substraction_jagged_out(
    jagged_A: torch.Tensor,
    offsets_a: torch.Tensor,
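Taken together, the four files retire the copy of `dense_jagged_cat_jagged_out` in the monolithic `triton_sll.py` and re-home it in its own module under `fbgemm_gpu/sll/triton/`. A sketch of the import paths before and after this commit (illustrative, not part of the diff):

```python
# Old location, removed by this commit:
# from fbgemm_gpu.sll.triton_sll import dense_jagged_cat_jagged_out

# New location, re-exported through the Triton subpackage __init__:
from fbgemm_gpu.sll.triton import dense_jagged_cat_jagged_out
```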
