Issue #10: Kernel Fusion using torch.jit #36

Open · wants to merge 21 commits into base: main

Changes from 1 commit
Unit tests and some stubs:
sami-bg committed Nov 20, 2023
commit d9d08b95a2144b6b2070faa6b0435528b724341f
2 changes: 1 addition & 1 deletion tests/nn/data_parallel/test_data_parallel.py
@@ -59,7 +59,7 @@ def run_parallelize_a_transformers_and_inference(
     assert torch.allclose(outputs["loss"], REF_LOSS)


-def test_data_parllel_fused_bias_gelu_bias_dropout_fwd():
+def test_data_parallel_fused_bias_gelu_bias_dropout_fwd():
     # TODO
     pass

2 changes: 1 addition & 1 deletion tests/nn/tensor_parallel/test_tensor_parallel.py
@@ -94,7 +94,7 @@ def get_leaf_modules(model):
     assert torch.allclose(generated_tokens, REF_GENERATED_TOKENS)


-def test_data_parllel_fused_bias_gelu_bias_dropout_fwd():
+def test_tensor_parallel_fused_bias_gelu_bias_dropout_fwd():
     # TODO
     pass

28 changes: 28 additions & 0 deletions tests/nn/test_fusion.py
@@ -0,0 +1,28 @@
+import pytest
+import torch
+from torch.nn import Dropout, GELU
+
+from pipegoose.nn.fusion import FusedBiasDropout, FusedBiasGelu
+
+
+def test_FusedBiasDropout():
+    dropout = FusedBiasDropout()
+    input = torch.randn(20, 16)
+    args = (0.5, False)
+    expected = Dropout(*args)(input)
+    actual = dropout(input)
+
+    assert actual.size() == expected.size()
+    assert torch.allclose(actual, expected)
+    assert FusedBiasDropout(*args).represents == Dropout
+
+
+def test_FusedBiasGelu():
+    gelu = FusedBiasGelu()
+    input = torch.randn(20, 16)
+    expected = GELU()(input)
+    actual = gelu(input)
+
+    assert actual.size() == expected.size()
+    assert torch.allclose(actual, expected)
+    assert FusedBiasGelu().represents == GELU
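
The tests above exercise FusedBiasDropout and FusedBiasGelu from pipegoose.nn.fusion, whose implementation is not part of this commit ("some stubs"). For orientation only, a minimal sketch of what torch.jit-based fused modules of this shape could look like follows; the helper names, the zero default bias, and the exact forward signatures are assumptions, not the pipegoose API.

from typing import Optional

import torch
import torch.nn.functional as F
from torch.nn import Dropout, GELU


@torch.jit.script
def _fused_bias_gelu(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Bias add and GELU in one scripted function so the TorchScript fuser
    # can combine the element-wise ops into a single kernel.
    return F.gelu(x + bias)


@torch.jit.script
def _fused_bias_dropout(x: torch.Tensor, bias: torch.Tensor, p: float, training: bool) -> torch.Tensor:
    # Bias add followed by dropout, scripted for the same reason.
    return F.dropout(x + bias, p=p, training=training)


class FusedBiasGelu(torch.nn.Module):
    # `represents` records the stock module this fused op stands in for,
    # matching the attribute the tests check.
    represents = GELU

    def forward(self, input: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if bias is None:
            # Hypothetical default: a zero bias, so the module can be called like GELU().
            bias = torch.zeros(input.size(-1), dtype=input.dtype)
        return _fused_bias_gelu(input, bias)


class FusedBiasDropout(torch.nn.Module):
    # Mirrors torch.nn.Dropout's (p, inplace) constructor so a call like
    # FusedBiasDropout(0.5, False) in the test would be accepted.
    represents = Dropout

    def __init__(self, p: float = 0.5, inplace: bool = False):
        super().__init__()
        self.p = p

    def forward(self, input: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if bias is None:
            bias = torch.zeros(input.size(-1), dtype=input.dtype)
        return _fused_bias_dropout(input, bias, self.p, self.training)

Scripting the bias add together with the activation or dropout is the point of the PR title: TorchScript's fuser can emit one element-wise kernel instead of launching a separate kernel per op. Note that comparing a dropout module against torch.nn.Dropout in training mode is stochastic, so an allclose check is only meaningful with dropout disabled or a fixed RNG seed.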