
Commit 8ff705b ("init")
0 parents
20 files changed: +494 -0 lines

.clang-format (+157 lines)

# copied from https://github.com/microsoft/DeepSpeed

---
# Refer to the following link for the explanation of each params:
# http://releases.llvm.org/8.0.0/tools/clang/docs/ClangFormatStyleOptions.html
Language: Cpp
# BasedOnStyle: Google
AccessModifierOffset: -4
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlines: Left
AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: true
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: true
AllowShortLoopsOnASingleLine: true
# This is deprecated
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
  AfterClass: false
  AfterControlStatement: false
  AfterEnum: false
  AfterFunction: false
  AfterNamespace: false
  AfterObjCDeclaration: false
  AfterStruct: false
  AfterUnion: false
  AfterExternBlock: false
  BeforeCatch: false
  BeforeElse: false
  IndentBraces: false
  # disabling the below splits, else, they'll just add to the vertical length of source files!
  SplitEmptyFunction: false
  SplitEmptyRecord: false
  SplitEmptyNamespace: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: WebKit
BreakBeforeInheritanceComma: false
BreakInheritanceList: BeforeColon
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakConstructorInitializers: BeforeColon
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: true
ColumnLimit: 119
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
# Kept the below 2 to be the same as `IndentWidth` to keep everything uniform
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true
ForEachMacros:
  - foreach
  - Q_FOREACH
  - BOOST_FOREACH
IncludeBlocks: Preserve
IncludeCategories:
  - Regex: '^<ext/.*\.h>'
    Priority: 2
  - Regex: '^<.*\.h>'
    Priority: 1
  - Regex: '^<.*'
    Priority: 2
  - Regex: '.*'
    Priority: 3
IncludeIsMainRegex: '([-_](test|unittest))?$'
IndentCaseLabels: true
IndentPPDirectives: None
IndentWidth: 4
IndentWrappedFunctionNames: false
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBinPackProtocolList: Never
ObjCBlockIndentWidth: 4
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PenaltyBreakAssignment: 4
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
RawStringFormats:
  - Language: Cpp
    Delimiters:
      - cc
      - CC
      - cpp
      - Cpp
      - CPP
      - 'c++'
      - 'C++'
    CanonicalDelimiter: ''
  - Language: TextProto
    Delimiters:
      - pb
      - PB
      - proto
      - PROTO
    EnclosingFunctions:
      - EqualsProto
      - EquivToProto
      - PARSE_PARTIAL_TEXT_PROTO
      - PARSE_TEST_PROTO
      - PARSE_TEXT_PROTO
      - ParseTextOrDie
      - ParseTextProtoOrDie
    CanonicalDelimiter: ''
    BasedOnStyle: google
# Enabling comment reflow causes doxygen comments to be messed up in their formats!
ReflowComments: true
SortIncludes: true
SortUsingDeclarations: true
SpaceAfterCStyleCast: false
SpaceAfterTemplateKeyword: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeCpp11BracedList: false
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: true
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
StatementMacros:
  - Q_UNUSED
  - QT_REQUIRE_VERSION
# Be consistent with indent-width, even for people who use tab for indentation!
TabWidth: 4
UseTab: Never

.gitignore (+8 lines)

__pycache__
.pytest_cache
.vscode
*.so
*.pyc
/appwrapper.yaml
*.egg-info/
build/

.pre-commit-config.yaml (+17 lines)

repos:
  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        name: isort (python)
  - repo: https://github.com/psf/black
    rev: 23.10.0
    hooks:
      - id: black
        args: [--line-length=119,--target-version=py36]
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v17.0.3
    hooks:
      - id: clang-format
        types_or: [c++, c, cuda]
        args: [-style=file:.clang-format]

Makefile (+11 lines)

install:
	pip install .

install-dev:
	pip install -e .

test:
	pytest tests

style:
	pre-commit run --all-files

README.md (+5 lines)

# Efficient GPU kernels written in both CUDA and Triton

<p align="center">
  <img src="assets/logo.jpeg" width="300px" height="300px">
</p>

assets/logo.jpeg (binary file, 326 KB)

kernels/__init__.py (+12 lines)

from .utils import compile_helpers
from .vector_addition import (
    VectorAddition_CUDA,
    VectorAddition_PyTorch,
    VectorAddition_Triton,
    vector_addition_cuda,
    vector_addition_pytorch,
    vector_addition_triton,
)


compile_helpers()
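
Since importing the package both re-exports the three implementations and JIT-compiles the CUDA extension, end-to-end usage presumably looks like the sketch below. This is my example rather than a file from this commit, and it assumes a CUDA-capable machine plus the 1-D, contiguous, matching-dtype inputs that the C++ binding further down enforces.

import torch
from kernels import vector_addition_cuda, vector_addition_pytorch, vector_addition_triton

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")

# all three backends compute the same elementwise sum
expected = vector_addition_pytorch(x, y)
torch.testing.assert_close(vector_addition_cuda(x, y), expected)
torch.testing.assert_close(vector_addition_triton(x, y), expected)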

kernels/utils.h (+9 lines)

#include <torch/extension.h>

// C++ interface
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " is not on CUDA device")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " is not a contiguous tensor")

#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x);

kernels/utils.py (+16 lines)

import os

from torch.utils.cpp_extension import load as load_cpp_extension


def compile_helpers() -> None:
    load_cpp_extension(
        "vector_addition_cuda",
        sources=[
            os.path.join(os.path.dirname(__file__), "vector_addition/cuda_kernel/vector_addition.cpp"),
            os.path.join(os.path.dirname(__file__), "vector_addition/cuda_kernel/vector_addition.cu"),
        ],
        with_cuda=True,
        extra_cflags=["-O3", "-Wall", "-shared", "-fPIC", "-fdiagnostics-color"],
        verbose=True,
    )
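
torch.utils.cpp_extension.load builds the listed sources on the first call (and caches the result), exposing the extension under the name passed as its first argument; the CUDA autograd wrapper later in this commit relies on that by deferring import vector_addition_cuda until forward time. A minimal sketch of calling the helper directly, outside the package import:

# Sketch only; in this repo `import kernels` already runs compile_helpers().
from kernels.utils import compile_helpers

compile_helpers()            # JIT-compiles vector_addition.cpp / .cu, slow on the first run
import vector_addition_cuda  # the extension name passed to load_cpp_extension above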

kernels/vector_addition/__init__.py (+3 lines)

from .cuda_kernel import VectorAddition_CUDA, vector_addition_cuda
from .pytorch import VectorAddition_PyTorch, vector_addition_pytorch
from .triton_kernel import VectorAddition_Triton, vector_addition_triton
@@ -0,0 +1,23 @@
from typing import Tuple

import torch
import torch.nn as nn


class _VectorAddition_CUDA(torch.autograd.Function):
    def forward(ctx, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        import vector_addition_cuda

        return vector_addition_cuda.vector_addition_forward(x, y)

    def backward(ctx, output_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return output_grad, output_grad


def vector_addition_cuda(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return _VectorAddition_CUDA.apply(x, y)


class VectorAddition_CUDA(nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return vector_addition_cuda(x, y)
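
Addition's gradient is the identity for both inputs (d(x+y)/dx = d(x+y)/dy = 1), so backward simply forwards output_grad to both arguments. A quick sanity check, assuming the extension compiled:

import torch
from kernels import vector_addition_cuda

x = torch.randn(8, device="cuda", requires_grad=True)
y = torch.randn(8, device="cuda", requires_grad=True)
vector_addition_cuda(x, y).sum().backward()

# with a sum() loss the upstream gradient is all ones, and it passes through unchanged
assert torch.equal(x.grad, torch.ones_like(x))
assert torch.equal(y.grad, torch.ones_like(y))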
kernels/vector_addition/cuda_kernel/vector_addition.cpp (+24 lines)

#include <torch/extension.h>
#include "../../utils.h"

// CUDA kernel declarations
torch::Tensor vector_addition_forward_kernel_launcher(torch::Tensor x, torch::Tensor y, const int BLOCK_SIZE);

torch::Tensor vector_addition_forward(torch::Tensor x, torch::Tensor y)
{
    CHECK_INPUT(x);
    CHECK_INPUT(y);

    TORCH_CHECK(x.dim() == 1, "tensor should be 1 dimensional")
    TORCH_CHECK(y.dim() == 1, "tensor should be 1 dimensional")

    TORCH_CHECK(x.numel() == y.numel(), "both tensors should have same number of elements");
    TORCH_CHECK(x.type() == y.type(), "both tensors should have same dtype");

    return vector_addition_forward_kernel_launcher(x, y, 1024);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("vector_addition_forward", &vector_addition_forward, "Vector addition forward (CUDA)");
}
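
The binding validates its inputs with the CHECK_INPUT macro from kernels/utils.h plus the TORCH_CHECKs above before launching with a fixed block size of 1024, so malformed inputs fail fast on the Python side. A hypothetical failure case (a sketch; the error text is paraphrased from the macros, not copied output):

import torch
from kernels import vector_addition_cuda

try:
    vector_addition_cuda(torch.randn(16), torch.randn(16))  # CPU tensors
except RuntimeError as error:
    print(error)  # raised by CHECK_CUDA, e.g. "x is not on CUDA device"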
kernels/vector_addition/cuda_kernel/vector_addition.cu (+37 lines)

#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

template <typename scalar_t>
__global__ void vector_addition_forward_kernel(const scalar_t* x,
                                               const scalar_t* y,
                                               scalar_t* output,
                                               const int num_elements)
{
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index < num_elements) output[index] = x[index] + y[index];
}

torch::Tensor vector_addition_forward_kernel_launcher(torch::Tensor x, torch::Tensor y, const int BLOCK_SIZE)
{
    int num_elements = x.numel();
    torch::Tensor output = torch::empty_like(x);

    int blocks = (int)ceil((float)num_elements / BLOCK_SIZE);

    if (at::isReducedFloatingType(x.scalar_type())) {
        AT_DISPATCH_REDUCED_FLOATING_TYPES(
            x.scalar_type(), "vector_addition_forward_kernel", ([&] {
                vector_addition_forward_kernel<scalar_t><<<blocks, BLOCK_SIZE>>>(
                    x.data<scalar_t>(), y.data<scalar_t>(), output.data<scalar_t>(), num_elements);
            }));
    } else {
        AT_DISPATCH_FLOATING_TYPES(
            x.scalar_type(), "vector_addition_forward_kernel", ([&] {
                vector_addition_forward_kernel<scalar_t><<<blocks, BLOCK_SIZE>>>(
                    x.data<scalar_t>(), y.data<scalar_t>(), output.data<scalar_t>(), num_elements);
            }));
    }

    return output;
}
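
The launcher rounds the element count up to a whole number of blocks and relies on the index < num_elements guard in the kernel to mask the surplus threads of the final block. The launch arithmetic, worked through in plain Python for a size that is not a multiple of the block size:

import math

num_elements, BLOCK_SIZE = 3000, 1024
blocks = math.ceil(num_elements / BLOCK_SIZE)  # 3, same as the (int)ceil(...) in the launcher
total_threads = blocks * BLOCK_SIZE            # 3072 threads are launched
idle_threads = total_threads - num_elements    # 72 threads fail the bounds check and write nothing
print(blocks, total_threads, idle_threads)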

kernels/vector_addition/pytorch.py (+11 lines)

import torch
import torch.nn as nn


def vector_addition_pytorch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y


class VectorAddition_PyTorch(nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return vector_addition_pytorch(x, y)
@@ -0,0 +1,48 @@
from typing import Tuple

import torch
import torch.nn as nn
import triton
import triton.language as tl


@triton.jit
def _vector_addition_forward(x_ptr, y_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)

    block_start = pid * BLOCK_SIZE
    block_indices = block_start + tl.arange(0, BLOCK_SIZE)

    mask = block_indices < num_elements

    x = tl.load(x_ptr + block_indices, mask=mask)
    y = tl.load(y_ptr + block_indices, mask=mask)

    output = x + y

    tl.store(output_ptr + block_indices, output, mask=mask)


class _VectorAddition_Triton(torch.autograd.Function):
    def forward(ctx, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 1
        output = torch.empty_like(x)

        num_elements = x.numel()
        grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)

        _vector_addition_forward[grid](x, y, output, num_elements, BLOCK_SIZE=1024)

        return output

    def backward(ctx, output_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return output_grad, output_grad


def vector_addition_triton(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return _VectorAddition_Triton.apply(x, y)


class VectorAddition_Triton(nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return vector_addition_triton(x, y)
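
The Triton version mirrors the CUDA launch: the grid callable is evaluated with the kernel's meta-parameters, so triton.cdiv(num_elements, meta["BLOCK_SIZE"]) is the same ceil-division as in the CUDA launcher, and the mask handed to tl.load/tl.store plays the role of the index < num_elements guard. A small sketch of what the grid evaluates to for the hard-coded BLOCK_SIZE=1024:

import triton

num_elements = 3000
grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)
print(grid({"BLOCK_SIZE": 1024}))  # (3,): three program instances, the last one partially masked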
