Commit 2d0a83c

Merge pull request #579 from InfiniTensor/issue/578
issue/578: Add `empty_like`, and have `silu` call `ntops`'s `silu` by default when its argument is `None`
1 parent cf6dff1 commit 2d0a83c

4 files changed (+57 -31 lines)


.github/workflows/ruff.yml

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+name: Ruff
+on: [push, pull_request]
+jobs:
+  ruff:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: chartboost/ruff-action@v1
+        with:
+          src: './python/'
+      - uses: chartboost/ruff-action@v1
+        with:
+          src: './python/'
+          args: format --check
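
The workflow runs Ruff twice on ./python/: once as a lint pass and once with `args: format --check` to verify formatting without rewriting files. A rough local equivalent, sketched in Python and assuming the `ruff` CLI is installed (e.g. `pip install ruff`):

import subprocess

# Lint ./python/ (mirrors the first ruff-action step).
subprocess.run(["ruff", "check", "./python/"], check=True)

# Check formatting only, without modifying files (mirrors the `format --check` step).
subprocess.run(["ruff", "format", "--check", "./python/"], check=True)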

python/infinicore/__init__.py

Lines changed: 6 additions & 2 deletions
@@ -1,5 +1,6 @@
 import contextlib
 
+import infinicore.nn as nn
 from infinicore.device import device
 from infinicore.dtype import (
     bfloat16,
@@ -33,19 +34,21 @@
 from infinicore.tensor import (
     Tensor,
     empty,
+    empty_like,
     from_blob,
     ones,
     strided_empty,
     strided_from_blob,
     zeros,
 )
 
-from infinicore import nn as nn
-
 __all__ = [
+    # Modules.
+    "nn",
     # Classes.
     "device",
     "dtype",
+    "Tensor",
     # Data Types.
     "bfloat16",
     "bool",
@@ -75,6 +78,7 @@
     "matmul",
     "rearrange",
     "empty",
+    "empty_like",
     "from_blob",
     "ones",
     "strided_empty",

python/infinicore/nn/functional.py

Lines changed: 26 additions & 28 deletions
@@ -1,68 +1,66 @@
 import infinicore
 from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
 
 __all__ = ["causal_softmax", "rms_norm", "silu", "swiglu"]
 
 
-def causal_softmax(
-    input: infinicore.Tensor,
-    out=None
-) -> infinicore.Tensor:
-    r"""Apply a causal softmax function.
-    """
+def causal_softmax(input: Tensor, out=None) -> Tensor:
+    r"""Apply a causal softmax function."""
 
     if out is None:
-        return infinicore.Tensor(_infinicore.causal_softmax(input._underlying))
+        return Tensor(_infinicore.causal_softmax(input._underlying))
 
     _infinicore.causal_softmax_(out._underlying, input._underlying)
 
     return out
 
 
 def rms_norm(
-    input: infinicore.Tensor,
-    normalized_shape: list[int],
-    weight: infinicore.Tensor,
-    eps: float = 1e-5,
-    out=None
-) -> infinicore.Tensor:
-    r"""Apply Root Mean Square Layer Normalization.
-    """
-
-    assert normalized_shape == weight.shape, "normalized_shape does not match weight.shape."
+    input: Tensor,
+    normalized_shape: list[int],
+    weight: Tensor,
+    eps: float = 1e-5,
+    *,
+    out=None,
+) -> Tensor:
+    r"""Apply Root Mean Square Layer Normalization."""
+
+    assert normalized_shape == weight.shape, (
+        "normalized_shape does not match weight.shape."
+    )
 
     if out is None:
-        return infinicore.Tensor(
-            _infinicore.rms_norm(input._underlying, weight._underlying, eps)
-        )
+        return Tensor(_infinicore.rms_norm(input._underlying, weight._underlying, eps))
 
     _infinicore.rms_norm_(out._underlying, input._underlying, weight._underlying, eps)
 
     return out
 
 
-def silu(input: infinicore.Tensor, inplace: bool = False, out=None) -> infinicore.Tensor:
-    r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise.
-    """
+def silu(input: Tensor, inplace: bool = False, *, out=None) -> Tensor:
+    r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise."""
+
+    if infinicore.use_ntops and input.device.type in ("cuda", "musa") and out is None:
+        return infinicore.ntops.torch.silu(input, inplace=inplace)
 
     if inplace:
         _infinicore.silu_(input._underlying, input._underlying)
         return input
 
     if out is None:
-        return infinicore.Tensor(_infinicore.silu(input._underlying))
+        return Tensor(_infinicore.silu(input._underlying))
 
     _infinicore.silu_(out._underlying, input._underlying)
 
     return out
 
 
-def swiglu(input: infinicore.Tensor, other: infinicore.Tensor, out=None):
-    r"""Apply the Swish-Gated Linear Unit (SwiGLU) function, element-wise.
-    """
+def swiglu(input: Tensor, other: Tensor, *, out=None):
+    r"""Apply the Swish-Gated Linear Unit (SwiGLU) function, element-wise."""
 
     if out is None:
-        return infinicore.Tensor(_infinicore.swiglu(input._underlying, other._underlying))
+        return Tensor(_infinicore.swiglu(input._underlying, other._underlying))
 
     _infinicore.swiglu_(out._underlying, input._underlying, other._underlying)
 
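
Beyond swapping the `infinicore.Tensor` annotations for the directly imported `Tensor`, the notable API change is that `out` becomes keyword-only for `rms_norm`, `silu`, and `swiglu`, and `silu` now short-circuits to `infinicore.ntops.torch.silu` on `cuda`/`musa` devices when `use_ntops` is set and no `out` is given. A usage sketch; the tensor construction arguments are assumptions, not from this diff:

import infinicore
from infinicore.nn import functional as F

# Assumed shapes, dtype, and device for illustration only.
x = infinicore.empty((4, 8), dtype=infinicore.bfloat16, device=infinicore.device("cpu"))
buf = infinicore.empty_like(x)

y = F.silu(x)        # out=None: may route to ntops on cuda/musa when use_ntops is set
F.silu(x, out=buf)   # `out` must now be passed by keyword
F.silu(x, True)      # `inplace` stays positional and overwrites x in place

w = infinicore.empty((8,), dtype=x.dtype, device=x.device)
F.rms_norm(x, w.shape, w, out=buf)   # same keyword-only `out` rule as silu and swiglu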

python/infinicore/tensor.py

Lines changed: 11 additions & 1 deletion
@@ -77,7 +77,7 @@ def view(self, shape):
 
     def debug(self, filename=None):
         """Print tensor data or save to file for debugging
-
+
         Args:
             filename: Optional filename to save raw binary data. If None, prints to stdout.
         """
@@ -93,6 +93,16 @@ def empty(size, *, dtype=None, device=None, pin_memory=False):
     )
 
 
+def empty_like(input, *, dtype=None, device=None):
+    if dtype is None:
+        dtype = input.dtype
+
+    if device is None:
+        device = input.device
+
+    return empty(input.size(), dtype=dtype, device=device)
+
+
 def strided_empty(size, strides, *, dtype=None, device=None, pin_memory=False):
     return Tensor(
         _infinicore.strided_empty(
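
`empty_like` defaults `dtype` and `device` to those of `input` and allocates through `empty(input.size(), ...)`, so the result is a fresh buffer with the same shape, not a copy of `input`'s data or strides. A brief sketch; the concrete dtype/device values are assumptions, not from this diff:

import infinicore
from infinicore.tensor import empty, empty_like

# Assumed constructor arguments for illustration only.
x = empty((2, 3), dtype=infinicore.bfloat16, device=infinicore.device("cpu"))

y = empty_like(x)                            # inherits x.dtype and x.device
z = empty_like(x, dtype=infinicore.float32)  # override just the dtype; float32 assumed to be exported like bfloat16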
