Skip to content

Commit aec7ab8

Browse files
committed
feat: 添加嵌入算子的正反向
Signed-off-by: YdrMaster <[email protected]>
1 parent 7d052e9 commit aec7ab8

File tree

2 files changed

+252
-0
lines changed

2 files changed

+252
-0
lines changed
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
2+
# `EmbeddingBackward`
3+
4+
`EmbeddingBackward`,即[**嵌入**算子](/infiniop/ops/embedding/README.md)的反向算子。用于训练大模型的词嵌入和加性位置嵌入。
5+
6+
`EmbeddingBackward` 算子支持 1 个或 2 个相同的步骤,根据“号码”从将输出的梯度叠加到嵌入表的梯度,其公式表述为:
7+
8+
$$ \begin{equation} d_{table1} = \alpha_1 \cdot dy[i_1] \end{equation} $$
9+
10+
$$ \begin{equation} d_{table2} = \alpha_2 \cdot dy[i_2] \end{equation} $$
11+
12+
- 通常 $α$ 为 1;
13+
- $ table2 $ 可以不使用,则公式 $(2)$ 不存在;
14+
15+
## 接口
16+
17+
### 计算
18+
19+
```c
20+
infiniStatus_t infiniopEmbeddingBackward(
21+
infiniopEmbeddingBackwardDescriptor_t desc,
22+
void *dtable1,
23+
void *dtable2,
24+
const void *dy,
25+
const void *i1,
26+
const void *i2,
27+
void *stream
28+
);
29+
```
30+
31+
<div style="background-color: lightblue; padding: 1px;"> 参数: </div>
32+
33+
- `desc`:
34+
已使用 `infiniopEmbeddingBackwardDescriptor_t()` 初始化的算子描述符;
35+
- `dtable1`:
36+
第 1 个嵌入表的梯度;
37+
- `dtable2`:
38+
第 2 个嵌入表的梯度,不使用则为空;
39+
- `dy`:
40+
输出结果的梯度;
41+
- `i1`:
42+
第 1 个嵌入序号;
43+
- `i2`:
44+
第 2 个嵌入序号,不使用则为空;
45+
- `stream`:
46+
计算流/队列;
47+
48+
<div style="background-color: lightblue; padding: 1px;"> 返回值:</div>
49+
50+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_BAD_PARAM`], [`INFINI_STATUS_BAD_DEVICE`], [`INFINI_STATUS_EXECUTION_FAILED`].
51+
52+
### 创建算子描述
53+
54+
```c
55+
infiniStatus_t infiniopCreateEmbeddingBackwardDescriptor(
56+
infiniopHandle_t handle,
57+
infiniopEmbeddingBackwardDescriptor_t *desc_ptr,
58+
infiniopTensorDescriptor_t dtable1_desc,
59+
infiniopTensorDescriptor_t dtable2_desc,
60+
infiniopTensorDescriptor_t dy_desc,
61+
infiniopTensorDescriptor_t i1_desc,
62+
infiniopTensorDescriptor_t i2_desc,
63+
float alpha1,
64+
float alpha2,
65+
char dtable1_acc,
66+
char dtable2_acc
67+
);
68+
```
69+
70+
<div style="background-color: lightblue; padding: 1px;"> 参数:</div>
71+
72+
- `handle`:
73+
`infiniopHandle_t` 类型的硬件控柄。详情请看:[`InfiniopHandle_t`]
74+
- `desc_ptr`:
75+
`infiniopCreateEmbeddingBackwardDescriptor` 指针,指向将被初始化的算子描述符地址;
76+
- `dtable1_desc` - $\{ dT | (N1, D) | (..., 1) \}$:
77+
算子输入 `table1` 的张量描述;
78+
- `dtable2_desc` - $\{ dT | (N2, D) | (..., 1) \}$:
79+
算子输入 `table2` 的张量描述;
80+
- `dy_desc` - $\{ dT | (N, D) | (..., 1) \}$:
81+
算子输出 `y` 的张量描述;
82+
- `i1_desc` - $\{ dI | (N) | (1) \}$:
83+
算子输入 `i1` 的张量描述;
84+
- `i2_desc` - $\{ dI | (N) | (1) \}$:
85+
算子输入 `i2` 的张量描述,为空表示不使用 $ table2 $, `alpha2` 必须同时为 0;
86+
- `alpha1` - float:
87+
第 1 项嵌入的缩放因子;
88+
- `alpha2` - float:
89+
第 2 项嵌入的缩放因子,取 0 表示不使用 $ table2 $,`i2_desc` 必须同时为空;
90+
- `dtable1_acc` - char:
91+
第 1 项嵌入是否叠加梯度,0 表示不叠加;
92+
- `dtable2_acc` - float:
93+
第 2 项嵌入是否叠加梯度,0 表示不叠加;
94+
95+
<div style="background-color: lightblue; padding: 1px;"> 参数限制:</div>
96+
97+
- $dT$: 任意代数类型;
98+
- $dT_i$: 任意整型;
99+
100+
<div style="background-color: lightblue; padding: 1px;"> 返回值:</div>
101+
102+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_BAD_PARAM`], [`INFINI_STATUS_BAD_TENSOR_SHAPE`], [`INFINI_STATUS_BAD_TENSOR_DTYPE`], [`INFINI_STATUS_BAD_TENSOR_STRIDES`], [`INFINI_STATUS_BAD_DEVICE`].
103+
104+
### 销毁算子描述符
105+
106+
```c
107+
infiniStatus_t infiniopDestroyEmbeddingBackwardDescriptor(
108+
infiniopEmbeddingBackwardDescriptor_t desc
109+
);
110+
```
111+
112+
<div style="background-color: lightblue; padding: 1px;"> 参数: </div>
113+
114+
- `desc`:
115+
输入。待销毁的算子描述符;
116+
117+
<div style="background-color: lightblue; padding: 1px;"> 返回值: </div>
118+
119+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_BAD_DEVICE`].
120+
121+
<!-- 链接 -->
122+
[`InfiniopHandle_t`]: /infiniop/handle/README.md
123+
124+
[`INFINI_STATUS_SUCCESS`]:/common/status/README.md#INFINI_STATUS_SUCCESS
125+
[`INFINI_STATUS_BAD_PARAM`]:/common/status/README.md#INFINI_STATUS_BAD_PARAM
126+
[`INFINI_STATUS_BAD_DEVICE`]:/common/status/README.md#INFINI_STATUS_BAD_DEVICE
127+
[`INFINI_STATUS_EXECUTION_FAILED`]:/common/status/README.md#INFINI_STATUS_EXECUTION_FAILED
128+
[`INFINI_STATUS_BAD_TENSOR_SHAPE`]:/common/status/README.md#INFINI_STATUS_BAD_TENSOR_SHAPE
129+
[`INFINI_STATUS_BAD_TENSOR_DTYPE`]:/common/status/README.md#INFINI_STATUS_BAD_TENSOR_DTYPE
130+
[`INFINI_STATUS_BAD_TENSOR_STRIDES`]:/common/status/README.md#INFINI_STATUS_BAD_TENSOR_STRIDES

infiniop/ops/embedding/README.md

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
2+
# `Embedding`
3+
4+
`Embedding`,即**嵌入**算子。用于对大模型进行词嵌入和加性位置嵌入。
5+
6+
`Embedding` 算子支持 1 个或 2 个相同的步骤,根据“号码”从嵌入表中获取嵌入向量并叠加到输出,其公式表述为:
7+
8+
$$ Y = \alpha_1 \cdot table_1[i_1] + \alpha_2 \cdot table_2[i_2] $$
9+
10+
- 通常 $α$ 为 1;
11+
- $ table2 $ 可以不使用,则公式变为 $ Y = \alpha_1 \cdot table_1[i_1] $;
12+
13+
## 接口
14+
15+
### 计算
16+
17+
```c
18+
infiniStatus_t infiniopEmbedding(
19+
infiniopEmbeddingDescriptor_t desc,
20+
void *y,
21+
const void *table1,
22+
const void *table2,
23+
const void *i1,
24+
const void *i2,
25+
void *stream
26+
);
27+
```
28+
29+
<div style="background-color: lightblue; padding: 1px;"> 参数: </div>
30+
31+
- `desc`:
32+
已使用 `infiniopCreateEmbeddingDescriptor()` 初始化的算子描述符;
33+
- `y`:
34+
计算输出结果;
35+
- `table1`:
36+
第 1 个嵌入表;
37+
- `table2`:
38+
第 2 个嵌入表,不使用则为空;
39+
- `i1`:
40+
第 1 个嵌入序号;
41+
- `i2`:
42+
第 2 个嵌入序号,不使用则为空;
43+
- `stream`:
44+
计算流/队列;
45+
46+
<div style="background-color: lightblue; padding: 1px;"> 返回值:</div>
47+
48+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_BAD_PARAM`], [`INFINI_STATUS_BAD_DEVICE`], [`INFINI_STATUS_EXECUTION_FAILED`].
49+
50+
### 创建算子描述
51+
52+
```c
53+
infiniStatus_t infiniopCreateEmbeddingDescriptor(
54+
infiniopHandle_t handle,
55+
infiniopEmbeddingDescriptor_t *desc_ptr,
56+
infiniopTensorDescriptor_t y_desc,
57+
infiniopTensorDescriptor_t table1_desc,
58+
infiniopTensorDescriptor_t table2_desc,
59+
infiniopTensorDescriptor_t i1_desc,
60+
infiniopTensorDescriptor_t i2_desc,
61+
float alpha1,
62+
float alpha2
63+
);
64+
```
65+
66+
<div style="background-color: lightblue; padding: 1px;"> 参数:</div>
67+
68+
- `handle`:
69+
`infiniopHandle_t` 类型的硬件控柄。详情请看:[`InfiniopHandle_t`]
70+
- `desc_ptr`:
71+
`infiniopCreateEmbeddingDescriptor` 指针,指向将被初始化的算子描述符地址;
72+
- `y_desc` - $\{ dT | (N, D) | (..., 1) \}$:
73+
算子输出 `y` 的张量描述;
74+
- `table1_desc` - $\{ dT | (N1, D) | (..., 1) \}$:
75+
算子输入 `table1` 的张量描述;
76+
- `table2_desc` - $\{ dT | (N2, D) | (..., 1) \}$:
77+
算子输入 `table2` 的张量描述;
78+
- `i1_desc` - $\{ dI | (N) | (1) \}$:
79+
算子输入 `i1` 的张量描述;
80+
- `i2_desc` - $\{ dI | (N) | (1) \}$:
81+
算子输入 `i2` 的张量描述,为空表示不使用 $ table2 $, `alpha2` 必须同时为 0;
82+
- `alpha1` - float:
83+
第 1 项嵌入的缩放因子;
84+
- `alpha2` - float:
85+
第 2 项嵌入的缩放因子,取 0 表示不使用 $ table2 $,`i2_desc` 必须同时为空;
86+
87+
参数限制:
88+
89+
- $dT$: 任意代数类型;
90+
- $dT_i$: 任意整型;
91+
92+
<div style="background-color: lightblue; padding: 1px;"> 返回值:</div>
93+
94+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_BAD_PARAM`], [`INFINI_STATUS_BAD_TENSOR_SHAPE`], [`INFINI_STATUS_BAD_TENSOR_DTYPE`], [`INFINI_STATUS_BAD_TENSOR_STRIDES`], [`INFINI_STATUS_BAD_DEVICE`].
95+
96+
### 销毁算子描述符
97+
98+
```c
99+
infiniStatus_t infiniopDestroyEmbeddingDescriptor(
100+
infiniopEmbeddingDescriptor_t desc
101+
);
102+
```
103+
104+
<div style="background-color: lightblue; padding: 1px;"> 参数: </div>
105+
106+
- `desc`:
107+
输入。待销毁的算子描述符;
108+
109+
<div style="background-color: lightblue; padding: 1px;"> 返回值: </div>
110+
111+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_BAD_DEVICE`].
112+
113+
<!-- 链接 -->
114+
[`InfiniopHandle_t`]: /infiniop/handle/README.md
115+
116+
[`INFINI_STATUS_SUCCESS`]:/common/status/README.md#INFINI_STATUS_SUCCESS
117+
[`INFINI_STATUS_BAD_PARAM`]:/common/status/README.md#INFINI_STATUS_BAD_PARAM
118+
[`INFINI_STATUS_BAD_DEVICE`]:/common/status/README.md#INFINI_STATUS_BAD_DEVICE
119+
[`INFINI_STATUS_EXECUTION_FAILED`]:/common/status/README.md#INFINI_STATUS_EXECUTION_FAILED
120+
[`INFINI_STATUS_BAD_TENSOR_SHAPE`]:/common/status/README.md#INFINI_STATUS_BAD_TENSOR_SHAPE
121+
[`INFINI_STATUS_BAD_TENSOR_DTYPE`]:/common/status/README.md#INFINI_STATUS_BAD_TENSOR_DTYPE
122+
[`INFINI_STATUS_BAD_TENSOR_STRIDES`]:/common/status/README.md#INFINI_STATUS_BAD_TENSOR_STRIDES

0 commit comments

Comments
 (0)