Commit 4965f35: Fold ops registration code, pt 3 (#3641)
Summary:
X-link: facebookresearch/FBGEMM#717

Pull Request resolved: #3641

- Fold out ops registration code in SLL ops

Reviewed By: sryap

Differential Revision: D68911389

fbshipit-source-id: 49649316b92f064063dee7079fe0c83a39c850ee
q10 authored and facebook-github-bot committed Jan 31, 2025
1 parent 3049ebe commit 4965f35
Showing 3 changed files with 158 additions and 169 deletions.
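In short: the per-op lib.register(...) calls in fbgemm_gpu/sll/__init__.py are folded into two dispatch tables, sll_cpu_registrations and sll_gpu_registrations, which are then registered in a loop, with the GPU table applied only when CUDA is available; the meta_* helpers move from cpu_sll.py into meta_sll.py. A condensed sketch of the before/after shape, using sll_jagged_dense_bmm as the example (the other SLL ops follow the same pattern):

# Before: one lib.register() call per op, mixing CPU and CUDA dispatch keys.
lib.register(
    "sll_jagged_dense_bmm",
    {
        "CUDA": jagged_dense_bmm,
        "AutogradCUDA": jagged_dense_bmm,
        "CPU": cpu_jagged_dense_bmm,
        "AutogradCPU": cpu_jagged_dense_bmm,
    },
)

# After: per-backend dispatch tables registered in a loop; the GPU table is
# applied only when CUDA is available.
sll_cpu_registrations = {
    "sll_jagged_dense_bmm": {
        "CPU": cpu_jagged_dense_bmm,
        "AutogradCPU": cpu_jagged_dense_bmm,
    },
}

sll_gpu_registrations = {
    "sll_jagged_dense_bmm": {
        "CUDA": jagged_dense_bmm,
        "AutogradCUDA": jagged_dense_bmm,
    },
}

for op_name, dispatches in sll_cpu_registrations.items():
    lib.register(op_name, dispatches)

if torch.cuda.is_available():
    for op_name, dispatches in sll_gpu_registrations.items():
        lib.register(op_name, dispatches)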
183 changes: 86 additions & 97 deletions fbgemm_gpu/fbgemm_gpu/sll/__init__.py
@@ -7,6 +7,8 @@

# pyre-strict

import torch

from fbgemm_gpu.sll.cpu_sll import ( # noqa F401
cpu_array_jagged_bmm_jagged_out,
cpu_dense_jagged_cat_jagged_out,
@@ -21,14 +23,14 @@
cpu_jagged_jagged_bmm_jagged_out,
cpu_jagged_self_substraction_jagged_out,
cpu_jagged_softmax,
meta_jagged_dense_elementwise_mul_jagged_out,
meta_jagged_self_substraction_jagged_out,
)

from fbgemm_gpu.sll.meta_sll import ( # noqa F401
meta_array_jagged_bmm_jagged_out,
meta_jagged2_softmax,
meta_jagged_dense_elementwise_mul_jagged_out,
meta_jagged_jagged_bmm_jagged_out,
meta_jagged_self_substraction_jagged_out,
)

from fbgemm_gpu.sll.triton_sll import ( # noqa F401
@@ -208,144 +210,131 @@
"""
)

# NOTE: here we register the op for AutogradCUDA/CPU and CUDA/CPU with the same function
# however, this is not ideal because in the inference case, we don't need the autograd forward
# to save the context because we don't need to do backward.
lib.register(
"sll_jagged_dense_bmm",
{
"CUDA": jagged_dense_bmm,
"AutogradCUDA": jagged_dense_bmm,
# NOTE: here we register the op for AutogradCUDA/CPU and CUDA/CPU with the same
# function however, this is not ideal because in the inference case, we don't
# need the autograd forward to save the context because we don't need to do
# backward.

# pyre-ignore[5]
sll_cpu_registrations = {
"sll_jagged_dense_bmm": {
"CPU": cpu_jagged_dense_bmm,
"AutogradCPU": cpu_jagged_dense_bmm,
},
)

lib.register(
"sll_jagged_jagged_bmm",
{
"CUDA": jagged_jagged_bmm,
"AutogradCUDA": jagged_jagged_bmm,
"sll_jagged_jagged_bmm": {
"CPU": cpu_jagged_jagged_bmm,
"AutogradCPU": cpu_jagged_jagged_bmm,
},
)

lib.register(
"sll_dense_jagged_cat_jagged_out",
{
"CUDA": dense_jagged_cat_jagged_out,
"sll_dense_jagged_cat_jagged_out": {
"CPU": cpu_dense_jagged_cat_jagged_out,
},
)

lib.register(
"sll_jagged_self_substraction_jagged_out",
{
"CUDA": triton_jagged_self_substraction_jagged_out,
"sll_jagged_self_substraction_jagged_out": {
"CPU": cpu_jagged_self_substraction_jagged_out,
"Meta": meta_jagged_self_substraction_jagged_out,
},
)

lib.register(
"sll_jagged2_to_padded_dense",
{
"CUDA": jagged2_to_padded_dense,
"AutogradCUDA": jagged2_to_padded_dense,
"sll_jagged2_to_padded_dense": {
"CPU": cpu_jagged2_to_padded_dense,
"AutogradCPU": cpu_jagged2_to_padded_dense,
},
)

lib.register(
"sll_jagged_dense_elementwise_mul_jagged_out",
{
"CUDA": jagged_dense_elementwise_mul_jagged_out,
"AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
"sll_jagged_dense_elementwise_mul_jagged_out": {
"CPU": cpu_jagged_dense_elementwise_mul_jagged_out,
"AutogradCPU": cpu_jagged_dense_elementwise_mul_jagged_out,
"Meta": meta_jagged_dense_elementwise_mul_jagged_out,
},
)

lib.register(
"sll_jagged_softmax",
{
"CUDA": jagged_softmax,
"AutogradCUDA": jagged_softmax,
"sll_jagged_softmax": {
"CPU": cpu_jagged_softmax,
"AutogradCPU": cpu_jagged_softmax,
},
)

lib.register(
"sll_jagged2_softmax",
{
"CUDA": jagged2_softmax,
"AutogradCUDA": jagged2_softmax,
"sll_jagged2_softmax": {
"CPU": cpu_jagged2_softmax,
"AutogradCPU": cpu_jagged2_softmax,
"AutogradMeta": meta_jagged2_softmax,
},
)

lib.register(
"sll_array_jagged_bmm_jagged_out",
{
"CUDA": array_jagged_bmm_jagged_out,
"AutogradCUDA": array_jagged_bmm_jagged_out,
"sll_array_jagged_bmm_jagged_out": {
"CPU": cpu_array_jagged_bmm_jagged_out,
"AutogradCPU": cpu_array_jagged_bmm_jagged_out,
"AutogradMeta": meta_array_jagged_bmm_jagged_out,
},
)

lib.register(
"sll_jagged_jagged_bmm_jagged_out",
{
"CUDA": jagged_jagged_bmm_jagged_out,
"AutogradCUDA": jagged_jagged_bmm_jagged_out,
"sll_jagged_jagged_bmm_jagged_out": {
"CPU": cpu_jagged_jagged_bmm_jagged_out,
"AutogradCPU": cpu_jagged_jagged_bmm_jagged_out,
"AutogradMeta": meta_jagged_jagged_bmm_jagged_out,
},
)

lib.register(
"sll_jagged_flash_attention_basic",
{
"CUDA": jagged_flash_attention_basic,
"AutogradCUDA": jagged_flash_attention_basic,
"sll_jagged_flash_attention_basic": {
"CPU": cpu_jagged_flash_attention_basic,
"AutogradCPU": cpu_jagged_flash_attention_basic,
},
)

lib.register(
"sll_jagged_dense_elementwise_add",
{
"CUDA": jagged_dense_elementwise_add,
"AutogradCUDA": jagged_dense_elementwise_add,
"sll_jagged_dense_elementwise_add": {
"CPU": cpu_jagged_dense_elementwise_add,
"AutogradCPU": cpu_jagged_dense_elementwise_add,
},
)

lib.register(
"sll_jagged_dense_flash_attention",
{
"CUDA": jagged_dense_flash_attention,
"AutogradCUDA": jagged_dense_flash_attention,
"sll_jagged_dense_flash_attention": {
"CPU": cpu_jagged_dense_flash_attention,
"AutogradCPU": cpu_jagged_dense_flash_attention,
},
)
}

lib.register(
"sll_multi_head_jagged_flash_attention",
{
# pyre-ignore[5]
sll_gpu_registrations = {
"sll_jagged_dense_bmm": {
"CUDA": jagged_dense_bmm,
"AutogradCUDA": jagged_dense_bmm,
},
"sll_jagged_jagged_bmm": {
"CUDA": jagged_jagged_bmm,
"AutogradCUDA": jagged_jagged_bmm,
},
"sll_dense_jagged_cat_jagged_out": {
"CUDA": dense_jagged_cat_jagged_out,
},
"sll_jagged_self_substraction_jagged_out": {
"CUDA": triton_jagged_self_substraction_jagged_out,
},
"sll_jagged2_to_padded_dense": {
"CUDA": jagged2_to_padded_dense,
"AutogradCUDA": jagged2_to_padded_dense,
},
"sll_jagged_dense_elementwise_mul_jagged_out": {
"CUDA": jagged_dense_elementwise_mul_jagged_out,
"AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
},
"sll_jagged_softmax": {
"CUDA": jagged_softmax,
"AutogradCUDA": jagged_softmax,
},
"sll_jagged2_softmax": {
"CUDA": jagged2_softmax,
"AutogradCUDA": jagged2_softmax,
},
"sll_array_jagged_bmm_jagged_out": {
"CUDA": array_jagged_bmm_jagged_out,
"AutogradCUDA": array_jagged_bmm_jagged_out,
},
"sll_jagged_jagged_bmm_jagged_out": {
"CUDA": jagged_jagged_bmm_jagged_out,
"AutogradCUDA": jagged_jagged_bmm_jagged_out,
},
"sll_jagged_flash_attention_basic": {
"CUDA": jagged_flash_attention_basic,
"AutogradCUDA": jagged_flash_attention_basic,
},
"sll_jagged_dense_elementwise_add": {
"CUDA": jagged_dense_elementwise_add,
"AutogradCUDA": jagged_dense_elementwise_add,
},
"sll_jagged_dense_flash_attention": {
"CUDA": jagged_dense_flash_attention,
"AutogradCUDA": jagged_dense_flash_attention,
},
"sll_multi_head_jagged_flash_attention": {
"CUDA": multi_head_jagged_flash_attention,
"AutogradCUDA": multi_head_jagged_flash_attention,
},
)
}

for op_name, dispatches in sll_cpu_registrations.items():
lib.register(op_name, dispatches)

if torch.cuda.is_available():
for op_name, dispatches in sll_gpu_registrations.items():
lib.register(op_name, dispatches)
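The lib object and its register() helper are defined elsewhere in fbgemm_gpu and are not part of this diff. As a rough illustration only, assuming register() maps each dispatch key to an implementation much like torch.library.Library.impl does, the table-driven pattern can be reproduced in a few lines; the "sll_example" namespace and the "add_one" op below are made up for the sketch.

import torch

# Hypothetical re-creation of the table-driven registration pattern; the real
# fbgemm_gpu helper may differ in details such as schema handling.
example_lib = torch.library.Library("sll_example", "DEF")
example_lib.define("add_one(Tensor x) -> Tensor")

def add_one_cpu(x: torch.Tensor) -> torch.Tensor:
    # Actual CPU computation.
    return x + 1

def add_one_meta(x: torch.Tensor) -> torch.Tensor:
    # Meta kernel: produces only shape/dtype information, no real storage.
    return torch.empty_like(x)

example_registrations = {
    "add_one": {
        "CPU": add_one_cpu,
        "Meta": add_one_meta,
    },
}

# Same loop shape as above: one registration per (op, dispatch key) pair.
for op_name, dispatches in example_registrations.items():
    for dispatch_key, fn in dispatches.items():
        example_lib.impl(op_name, fn, dispatch_key)

# The op is then reachable through the torch.ops namespace.
print(torch.ops.sll_example.add_one(torch.zeros(3)))  # tensor([1., 1., 1.])

Keeping the registrations as plain dicts also makes it easy to gate a whole group of them at once, which is what the torch.cuda.is_available() check above does for the GPU table.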
72 changes: 0 additions & 72 deletions fbgemm_gpu/fbgemm_gpu/sll/cpu_sll.py
@@ -213,19 +213,6 @@ def cpu_jagged_self_substraction_jagged_out(
return jagged_B


def meta_jagged_self_substraction_jagged_out(
jagged_A: torch.Tensor,
offsets_a: torch.Tensor,
offsets_b: torch.Tensor,
max_seq_len: int,
) -> torch.Tensor:
return torch.empty(
[torch.library.get_ctx().new_dynamic_size()],
dtype=jagged_A.dtype,
device=jagged_A.device,
)


def cpu_jagged2_to_padded_dense(
values: torch.Tensor,
offsets: torch.Tensor,
@@ -352,65 +339,6 @@ def cpu_jagged_dense_elementwise_mul_jagged_out(
)


class MetaJaggedDenseElementwiseMul(torch.autograd.Function):
@staticmethod
# pyre-fixme
def forward(
ctx, # pyre-ignore [2]
x: torch.Tensor,
y: torch.Tensor,
x_seq_lengths: torch.Tensor,
x_offsets: torch.Tensor,
max_seq_len: int,
) -> torch.Tensor:
ctx.max_seq_len = max_seq_len

ctx.save_for_backward(
x,
y,
x_seq_lengths,
x_offsets,
)

total_L = x.size(0)
jagged_C = torch.zeros((total_L), device=x.device, dtype=x.dtype)

return jagged_C

@staticmethod
# pyre-fixme
def backward(ctx, grad_output: torch.Tensor):
(
x,
y,
x_seq_lengths,
x_offsets,
) = ctx.saved_tensors

total_L = grad_output.size(0)
jagged_C = torch.zeros(
(total_L), device=grad_output.device, dtype=grad_output.dtype
)

return jagged_C, None, None, None, None


def meta_jagged_dense_elementwise_mul_jagged_out(
x: torch.Tensor,
y: torch.Tensor,
x_seq_lengths: torch.Tensor,
x_offsets: torch.Tensor,
max_seq_len: int,
) -> torch.Tensor:
return MetaJaggedDenseElementwiseMul.apply(
x,
y,
x_seq_lengths,
x_offsets,
max_seq_len,
)
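These meta implementations are not dropped: per the import changes in __init__.py above, meta_jagged_self_substraction_jagged_out and meta_jagged_dense_elementwise_mul_jagged_out now come from fbgemm_gpu.sll.meta_sll instead of cpu_sll.py. For context on the idiom used by the first one: its output length depends on the offsets data rather than on input shapes alone, so the meta kernel asks the tracing context for an unbacked dynamic size via torch.library.get_ctx().new_dynamic_size(). A minimal, hypothetical sketch of that idiom for a made-up op, assuming PyTorch 2.4+ where torch.library.register_fake is available (earlier releases expose the same mechanism as torch.library.impl_abstract):

import torch

# Made-up op with a data-dependent output length, similar in spirit to
# sll_jagged_self_substraction_jagged_out. "shape_example" and "nonzero_values"
# are illustrative names only.
shape_lib = torch.library.Library("shape_example", "DEF")
shape_lib.define("nonzero_values(Tensor x) -> Tensor")

def nonzero_values_cpu(x: torch.Tensor) -> torch.Tensor:
    # Eager kernel: how many elements survive depends on the data.
    return x[x != 0]

shape_lib.impl("nonzero_values", nonzero_values_cpu, "CPU")

@torch.library.register_fake("shape_example::nonzero_values")
def nonzero_values_fake(x: torch.Tensor) -> torch.Tensor:
    # At trace time the true length is unknown, so request an unbacked symbolic
    # size from the fake-impl context, as the meta kernel above does.
    n = torch.library.get_ctx().new_dynamic_size()
    return torch.empty([n], dtype=x.dtype, device=x.device)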


class JaggedSoftmaxCPU(torch.autograd.Function):
@staticmethod
# pyre-fixme