Re-organize SLL ops, pt 6 (#3647)
Summary:
Pull Request resolved: #3647

X-link: facebookresearch/FBGEMM#722

- Re-organize `dense_jagged_cat_jagged_out`

Differential Revision: D68936183
q10 authored and facebook-github-bot committed Feb 2, 2025
1 parent 7d0c24d commit 8a4e309
Showing 6 changed files with 649 additions and 619 deletions.
15 changes: 0 additions & 15 deletions fbgemm_gpu/fbgemm_gpu/sll/__init__.py
@@ -34,12 +34,8 @@
)

from fbgemm_gpu.sll.triton_sll import ( # noqa F401
array_jagged_bmm_jagged_out,
dense_jagged_cat_jagged_out,
jagged2_to_padded_dense,
# jagged_dense_bmm,
jagged_dense_elementwise_mul_jagged_out,
jagged_jagged_bmm_jagged_out,
triton_jagged_self_substraction_jagged_out,
)

@@ -269,9 +265,6 @@

# pyre-ignore[5]
sll_gpu_registrations = {
"sll_dense_jagged_cat_jagged_out": {
"CUDA": dense_jagged_cat_jagged_out,
},
"sll_jagged_self_substraction_jagged_out": {
"CUDA": triton_jagged_self_substraction_jagged_out,
},
@@ -283,14 +276,6 @@
"CUDA": jagged_dense_elementwise_mul_jagged_out,
"AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
},
"sll_array_jagged_bmm_jagged_out": {
"CUDA": array_jagged_bmm_jagged_out,
"AutogradCUDA": array_jagged_bmm_jagged_out,
},
"sll_jagged_jagged_bmm_jagged_out": {
"CUDA": jagged_jagged_bmm_jagged_out,
"AutogradCUDA": jagged_jagged_bmm_jagged_out,
},
}

for op_name, dispatches in sll_cpu_registrations.items():
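For context, the loop shown above (`for op_name, dispatches in sll_cpu_registrations.items(): ...`) hints at how these dispatch tables are consumed. Below is a minimal sketch of that pattern; the `fbgemm` namespace, the `Library` handle, and the loop body are assumptions for illustration and are not code from this commit (the op schemas are assumed to be defined elsewhere via `lib.define`).

# Minimal sketch of consuming a registration table such as sll_gpu_registrations.
# Assumption: the op schemas already exist under the "fbgemm" namespace;
# the exact namespace and loop body here are illustrative only.
import torch

lib = torch.library.Library("fbgemm", "FRAGMENT")  # attach impls to an existing namespace

def register_impls(registrations: dict) -> None:
    for op_name, dispatches in registrations.items():
        for dispatch_key, fn in dispatches.items():
            # dispatch_key is e.g. "CUDA" or "AutogradCUDA"
            lib.impl(op_name, fn, dispatch_key)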
23 changes: 23 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py
@@ -7,13 +7,25 @@

# pyre-strict

from fbgemm_gpu.sll.triton.triton_dense_jagged_cat_jagged_out import (
dense_jagged_cat_jagged_out,
)

from fbgemm_gpu.sll.triton.triton_jagged_bmm import ( # noqa F401
jagged_dense_bmm,
jagged_jagged_bmm,
JaggedDenseBmm, # noqa F401
JaggedJaggedBmm, # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_bmm_jagged_out import ( # noqa F401
array_jagged_bmm_jagged_out,
ArrayJaggedBmmNopadding, # noqa F401
jagged_jagged_bmm_jagged_out,
JaggedJaggedBmmNoPadding, # noqa F401
triton_array_jagged_bmm_jagged_out, # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_add import ( # noqa F401
jagged_dense_elementwise_add,
JaggedDenseAdd, # noqa F401
@@ -43,6 +55,9 @@

# pyre-ignore[5]
op_registrations = {
"sll_dense_jagged_cat_jagged_out": {
"CUDA": dense_jagged_cat_jagged_out,
},
"sll_jagged_dense_bmm": {
"CUDA": jagged_dense_bmm,
"AutogradCUDA": jagged_dense_bmm,
@@ -51,6 +66,14 @@
"CUDA": jagged_jagged_bmm,
"AutogradCUDA": jagged_jagged_bmm,
},
"sll_array_jagged_bmm_jagged_out": {
"CUDA": array_jagged_bmm_jagged_out,
"AutogradCUDA": array_jagged_bmm_jagged_out,
},
"sll_jagged_jagged_bmm_jagged_out": {
"CUDA": jagged_jagged_bmm_jagged_out,
"AutogradCUDA": jagged_jagged_bmm_jagged_out,
},
"sll_jagged_softmax": {
"CUDA": jagged_softmax,
"AutogradCUDA": jagged_softmax,
72 changes: 72 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
import triton
import triton.language as tl


@triton.jit
def dense_jagged_cat_jagged_out_kernel(
a_ptr, # dense
b_ptr, # jagged
c_ptr, # jagged
b_offsets_ptr,
c_offsets_ptr,
max_seq_len,
BLOCK_SIZE: tl.constexpr,
):
pid_batch = tl.program_id(0)
b_start = tl.load(b_offsets_ptr + pid_batch)
b_end = tl.load(b_offsets_ptr + pid_batch + 1)
c_start = b_start + pid_batch
N = b_end - b_start
N = tl.minimum(N, max_seq_len)

a = tl.load(a_ptr + pid_batch)
tl.store(c_ptr + c_start, a)

offs_k = tl.arange(0, BLOCK_SIZE)
for k in range(0, N, BLOCK_SIZE):
b_offset = k + offs_k
b_ptrs = b_ptr + b_start + b_offset
b = tl.load(b_ptrs, mask=b_offset < N, other=0.0)
tl.store(c_ptr + c_start + 1 + b_offset, b, mask=b_offset < N)
tl.store(c_offsets_ptr + pid_batch, b_start + pid_batch)


def dense_jagged_cat_jagged_out(
a: torch.Tensor,
b: torch.Tensor,
b_offsets: torch.Tensor,
max_seq_len: int,
):
assert a.is_contiguous()
assert b.is_contiguous()
assert b_offsets.is_contiguous()
B = a.size(0)
BLOCK_SIZE = 128
c = torch.zeros(b.size(0) + a.size(0), dtype=a.dtype, device=a.device)
c_offsets = torch.empty(
b_offsets.size(0), dtype=b_offsets.dtype, device=b_offsets.device
) # B + 1

dense_jagged_cat_jagged_out_kernel[(B,)](
a,
b,
c,
b_offsets,
c_offsets,
max_seq_len,
# pyre-fixme[6]: For 7th argument expected `constexpr` but got `int`.
BLOCK_SIZE,
)

c_offsets[-1] = b_offsets[-1] + B

return c, c_offsets
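As an illustrative usage sketch (not part of this commit): `dense_jagged_cat_jagged_out` prepends the per-batch dense value `a[i]` to jagged row `i` of `b`, so each output offset satisfies `c_offsets[i] == b_offsets[i] + i`. The shapes, device, and values below are assumptions chosen for illustration; the Triton kernel requires a CUDA device.

import torch
from fbgemm_gpu.sll.triton import dense_jagged_cat_jagged_out

B = 3
a = torch.randn(B, device="cuda")                      # one dense value per batch
b_offsets = torch.tensor([0, 2, 2, 5], device="cuda")  # jagged row lengths 2, 0, 3
b = torch.randn(5, device="cuda")                      # flattened jagged values
c, c_offsets = dense_jagged_cat_jagged_out(a, b, b_offsets, max_seq_len=8)
# Row i of c is [a[i], b_row_i...]; offsets shift by one extra slot per preceding batch.
assert c_offsets.tolist() == [0, 3, 4, 8]
assert c.numel() == b.numel() + B

Once registered through the dispatch table above, the same kernel would presumably also be reachable as `torch.ops.fbgemm.sll_dense_jagged_cat_jagged_out`, assuming the `fbgemm` namespace used for the other SLL ops.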
(Diffs for the remaining 3 changed files are not shown.)
