1 | 1 | import infinicore |
2 | 2 | from infinicore.lib import _infinicore |
| 3 | +from infinicore.tensor import Tensor |
3 | 4 |
4 | 5 | __all__ = ["causal_softmax", "rms_norm", "silu", "swiglu"] |
5 | 6 |
6 | 7 |
7 | | -def causal_softmax( |
8 | | - input: infinicore.Tensor, |
9 | | - out=None |
10 | | -) -> infinicore.Tensor: |
11 | | - r"""Apply a causal softmax function. |
12 | | - """ |
| 8 | +def causal_softmax(input: Tensor, out=None) -> Tensor: |
| 9 | + r"""Apply a causal softmax function.""" |
13 | 10 |
14 | 11 | if out is None: |
15 | | - return infinicore.Tensor(_infinicore.causal_softmax(input._underlying)) |
| 12 | + return Tensor(_infinicore.causal_softmax(input._underlying)) |
16 | 13 |
17 | 14 | _infinicore.causal_softmax_(out._underlying, input._underlying) |
18 | 15 |
19 | 16 | return out |
20 | 17 |
21 | 18 |
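For reference, a causal softmax is a row-wise softmax over attention scores in which each query position can only attend to keys at or before its own position. The NumPy sketch below restates that behaviour under the assumption that the kernel masks entry (i, j) whenever j > i + (total_len - seq_len); the exact semantics of _infinicore.causal_softmax are not shown in this diff.

    import numpy as np

    def causal_softmax_reference(scores):
        # scores: (..., seq_len, total_len); assume query i may attend to keys
        # 0 .. i + (total_len - seq_len) and everything later is masked out.
        seq_len, total_len = scores.shape[-2], scores.shape[-1]
        offset = total_len - seq_len
        mask = np.triu(np.ones((seq_len, total_len), dtype=bool), k=offset + 1)
        masked = np.where(mask, -np.inf, scores)
        masked = masked - masked.max(axis=-1, keepdims=True)  # numerical stability
        exp = np.exp(masked)
        return exp / exp.sum(axis=-1, keepdims=True)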
22 | 19 | def rms_norm( |
23 | | - input: infinicore.Tensor, |
24 | | - normalized_shape: list[int], |
25 | | - weight: infinicore.Tensor, |
26 | | - eps: float = 1e-5, |
27 | | - out=None |
28 | | -) -> infinicore.Tensor: |
29 | | - r"""Apply Root Mean Square Layer Normalization. |
30 | | - """ |
31 | | - |
32 | | - assert normalized_shape == weight.shape, "normalized_shape does not match weight.shape." |
| 20 | + input: Tensor, |
| 21 | + normalized_shape: list[int], |
| 22 | + weight: Tensor, |
| 23 | + eps: float = 1e-5, |
| 24 | + *, |
| 25 | + out=None, |
| 26 | +) -> Tensor: |
| 27 | + r"""Apply Root Mean Square Layer Normalization.""" |
| 28 | + |
| 29 | + assert normalized_shape == weight.shape, ( |
| 30 | + "normalized_shape does not match weight.shape." |
| 31 | + ) |
33 | 32 |
34 | 33 | if out is None: |
35 | | - return infinicore.Tensor( |
36 | | - _infinicore.rms_norm(input._underlying, weight._underlying, eps) |
37 | | - ) |
| 34 | + return Tensor(_infinicore.rms_norm(input._underlying, weight._underlying, eps)) |
38 | 35 |
39 | 36 | _infinicore.rms_norm_(out._underlying, input._underlying, weight._underlying, eps) |
40 | 37 |
41 | 38 | return out |
42 | 39 |
43 | 40 |
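As a point of reference, RMS layer normalization divides the input by the root mean square of its elements over the normalized dimensions and then multiplies by the learned weight. The NumPy sketch below follows that standard formulation and assumes the normalized dimensions are the trailing axes covered by weight, which the assert above ties to normalized_shape; it is an illustration, not the kernel's implementation.

    import numpy as np

    def rms_norm_reference(x, weight, eps=1e-5):
        # Normalize over the trailing axes covered by weight
        # (the assert requires normalized_shape == weight.shape).
        axes = tuple(range(x.ndim - weight.ndim, x.ndim))
        rms = np.sqrt(np.mean(np.square(x), axis=axes, keepdims=True) + eps)
        return (x / rms) * weight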
44 | | -def silu(input: infinicore.Tensor, inplace: bool = False, out=None) -> infinicore.Tensor: |
45 | | - r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise. |
46 | | - """ |
| 41 | +def silu(input: Tensor, inplace: bool = False, *, out=None) -> Tensor: |
| 42 | + r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise.""" |
| 43 | + |
| 44 | + if infinicore.use_ntops and input.device.type in ("cuda", "musa") and out is None: |
| 45 | + return infinicore.ntops.torch.silu(input, inplace=inplace) |
47 | 46 |
48 | 47 | if inplace: |
49 | 48 | _infinicore.silu_(input._underlying, input._underlying) |
50 | 49 | return input |
51 | 50 |
52 | 51 | if out is None: |
53 | | - return infinicore.Tensor(_infinicore.silu(input._underlying)) |
| 52 | + return Tensor(_infinicore.silu(input._underlying)) |
54 | 53 |
55 | 54 | _infinicore.silu_(out._underlying, input._underlying) |
56 | 55 |
57 | 56 | return out |
58 | 57 |
59 | 58 |
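SiLU (also known as swish) multiplies the input by its own sigmoid, silu(x) = x * sigmoid(x); the ntops branch above is just an optional fast path taken when no out tensor is requested and the input lives on a cuda or musa device. A minimal NumPy restatement of the formula, for reference only:

    import numpy as np

    def silu_reference(x):
        # silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
        return x / (1.0 + np.exp(-x))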
60 | | -def swiglu(input: infinicore.Tensor, other: infinicore.Tensor, out=None): |
61 | | - r"""Apply the Swish-Gated Linear Unit (SwiGLU) function, element-wise. |
62 | | - """ |
| 59 | +def swiglu(input: Tensor, other: Tensor, *, out=None): |
| 60 | + r"""Apply the Swish-Gated Linear Unit (SwiGLU) function, element-wise.""" |
63 | 61 |
64 | 62 | if out is None: |
65 | | - return infinicore.Tensor(_infinicore.swiglu(input._underlying, other._underlying)) |
| 63 | + return Tensor(_infinicore.swiglu(input._underlying, other._underlying)) |
66 | 64 |
67 | 65 | _infinicore.swiglu_(out._underlying, input._underlying, other._underlying) |
68 | 66 |
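SwiGLU combines the two operands by gating one with the SiLU of the other, roughly out = silu(gate) * up element-wise. Which of input and other plays the gate role is not spelled out in this diff, so the argument order in the NumPy sketch below is an assumption:

    import numpy as np

    def swiglu_reference(gate, up):
        # swiglu = silu(gate) * up; the gate/up roles are assumed here,
        # not taken from this diff.
        return (gate / (1.0 + np.exp(-gate))) * up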