Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-organize SLL ops, pt 6 #3647

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions fbgemm_gpu/fbgemm_gpu/sll/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,8 @@
)

from fbgemm_gpu.sll.triton_sll import ( # noqa F401
array_jagged_bmm_jagged_out,
dense_jagged_cat_jagged_out,
jagged2_to_padded_dense,
# jagged_dense_bmm,
jagged_dense_elementwise_mul_jagged_out,
jagged_jagged_bmm_jagged_out,
triton_jagged_self_substraction_jagged_out,
)

Expand Down Expand Up @@ -269,9 +265,6 @@

# pyre-ignore[5]
sll_gpu_registrations = {
"sll_dense_jagged_cat_jagged_out": {
"CUDA": dense_jagged_cat_jagged_out,
},
"sll_jagged_self_substraction_jagged_out": {
"CUDA": triton_jagged_self_substraction_jagged_out,
},
Expand All @@ -283,14 +276,6 @@
"CUDA": jagged_dense_elementwise_mul_jagged_out,
"AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
},
"sll_array_jagged_bmm_jagged_out": {
"CUDA": array_jagged_bmm_jagged_out,
"AutogradCUDA": array_jagged_bmm_jagged_out,
},
"sll_jagged_jagged_bmm_jagged_out": {
"CUDA": jagged_jagged_bmm_jagged_out,
"AutogradCUDA": jagged_jagged_bmm_jagged_out,
},
}

for op_name, dispatches in sll_cpu_registrations.items():
Expand Down
24 changes: 24 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,26 @@

# pyre-strict

from fbgemm_gpu.sll.triton.triton_dense_jagged_cat_jagged_out import (
dense_jagged_cat_jagged_out,
)

from fbgemm_gpu.sll.triton.triton_jagged_bmm import ( # noqa F401
jagged_dense_bmm,
jagged_jagged_bmm,
JaggedDenseBmm, # noqa F401
JaggedJaggedBmm, # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_bmm_jagged_out import ( # noqa F401
array_jagged_bmm_jagged_out,
ArrayJaggedBmmNopadding, # noqa F401
jagged_jagged_bmm_jagged_out,
JaggedJaggedBmmNoPadding, # noqa F401
triton_array_jagged_bmm_jagged_out, # noqa F401
triton_jagged_jagged_bmm_jagged_out, # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_add import ( # noqa F401
jagged_dense_elementwise_add,
JaggedDenseAdd, # noqa F401
Expand Down Expand Up @@ -43,6 +56,9 @@

# pyre-ignore[5]
op_registrations = {
"sll_dense_jagged_cat_jagged_out": {
"CUDA": dense_jagged_cat_jagged_out,
},
"sll_jagged_dense_bmm": {
"CUDA": jagged_dense_bmm,
"AutogradCUDA": jagged_dense_bmm,
Expand All @@ -51,6 +67,14 @@
"CUDA": jagged_jagged_bmm,
"AutogradCUDA": jagged_jagged_bmm,
},
"sll_array_jagged_bmm_jagged_out": {
"CUDA": array_jagged_bmm_jagged_out,
"AutogradCUDA": array_jagged_bmm_jagged_out,
},
"sll_jagged_jagged_bmm_jagged_out": {
"CUDA": jagged_jagged_bmm_jagged_out,
"AutogradCUDA": jagged_jagged_bmm_jagged_out,
},
"sll_jagged_softmax": {
"CUDA": jagged_softmax,
"AutogradCUDA": jagged_softmax,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
import triton
import triton.language as tl


@triton.jit
def dense_jagged_cat_jagged_out_kernel(
a_ptr, # dense
b_ptr, # jagged
c_ptr, # jagged
b_offsets_ptr,
c_offsets_ptr,
max_seq_len,
BLOCK_SIZE: tl.constexpr,
):
pid_batch = tl.program_id(0)
b_start = tl.load(b_offsets_ptr + pid_batch)
b_end = tl.load(b_offsets_ptr + pid_batch + 1)
c_start = b_start + pid_batch
N = b_end - b_start
N = tl.minimum(N, max_seq_len)

a = tl.load(a_ptr + pid_batch)
tl.store(c_ptr + c_start, a)

offs_k = tl.arange(0, BLOCK_SIZE)
for k in range(0, N, BLOCK_SIZE):
b_offset = k + offs_k
b_ptrs = b_ptr + b_start + b_offset
b = tl.load(b_ptrs, mask=b_offset < N, other=0.0)
tl.store(c_ptr + c_start + 1 + b_offset, b, mask=b_offset < N)
tl.store(c_offsets_ptr + pid_batch, b_start + pid_batch)


def dense_jagged_cat_jagged_out(
a: torch.Tensor,
b: torch.Tensor,
b_offsets: torch.Tensor,
max_seq_len: int,
):
assert a.is_contiguous()
assert b.is_contiguous()
assert b_offsets.is_contiguous()
B = a.size(0)
BLOCK_SIZE = 128
c = torch.zeros(b.size(0) + a.size(0), dtype=a.dtype, device=a.device)
c_offsets = torch.empty(
b_offsets.size(0), dtype=b_offsets.dtype, device=b_offsets.device
) # B + 1

dense_jagged_cat_jagged_out_kernel[(B,)](
a,
b,
c,
b_offsets,
c_offsets,
max_seq_len,
# pyre-fixme[6]: For 7th argument expected `constexpr` but got `int`.
BLOCK_SIZE,
)

c_offsets[-1] = b_offsets[-1] + B

return c, c_offsets
Loading
Loading