
Commit c089a2a

Add rotate_half implementation for fused_rope (#56401)
* add rotate_half in fused_rope
* add position_ids in fused_rope
* modified examples about fused_rope
* add set_device in examples
1 parent be9cb94 commit c089a2a
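For orientation, here is a minimal usage sketch of the op after this change. The argument names and the default use_neox_rotary_style=True come from the YAML signatures below; the Python import path and the sin/cos table construction are assumptions, not taken from this diff:

# Hypothetical usage sketch; argument names follow the YAML op signature
# in this commit, the import path is an assumption.
import paddle
from paddle.incubate.nn.functional import fused_rotary_position_embedding

paddle.set_device("gpu")  # the commit message notes set_device was added to the examples

batch, seq_len, num_heads, head_dim = 2, 8, 2, 16
q = paddle.randn([batch, seq_len, num_heads, head_dim])
k = paddle.randn([batch, seq_len, num_heads, head_dim])
v = paddle.randn([batch, seq_len, num_heads, head_dim])

# A standard RoPE table of shape [1, seq_len, 1, head_dim]; with position_ids
# the table may cover more positions than seq_len (rows are gathered per token).
pos = paddle.arange(seq_len, dtype="float32")
freqs = 1.0 / 10000.0 ** (paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)
angles = paddle.outer(pos, freqs)                  # [seq_len, head_dim // 2]
angles = paddle.concat([angles, angles], axis=-1)  # [seq_len, head_dim]
sin = paddle.sin(angles).reshape([1, seq_len, 1, head_dim])
cos = paddle.cos(angles).reshape([1, seq_len, 1, head_dim])

# position_ids is [batch_size, seq_len], int64.
position_ids = paddle.tile(paddle.arange(seq_len, dtype="int64").unsqueeze(0), [batch, 1])

out_q, out_k, out_v = fused_rotary_position_embedding(
    q, k, v, sin=sin, cos=cos,
    position_ids=position_ids,
    use_neox_rotary_style=False,  # False selects the new rotate_half path
)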

11 files changed: +459 -114 lines


paddle/phi/api/yaml/fused_backward.yaml

Lines changed: 3 additions & 3 deletions

@@ -17,10 +17,10 @@
   support_dygraph_mode : true

 - backward_op : fused_rotary_position_embedding_grad
-  forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
-  args : (Tensor sin, Tensor cos, Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad)
+  forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
+  args : (Tensor sin, Tensor cos, Tensor position_ids, Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad, bool use_neox_rotary_style)
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
-  optional : sin, cos, out_k_grad, out_v_grad, k_grad, v_grad
+  optional : sin, cos, position_ids, out_k_grad, out_v_grad, k_grad, v_grad
   infer_meta :
     func : FusedRopeGradInferMeta
   kernel :

paddle/phi/api/yaml/fused_ops.yaml

Lines changed: 2 additions & 2 deletions

@@ -149,11 +149,11 @@
   optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index

 - op : fused_rotary_position_embedding
-  args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos)
+  args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style = true)
   output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
   infer_meta :
     func : FusedRopeInferMeta
-  optional : k,v,sin,cos, out_k, out_v
+  optional : k, v, sin, cos, position_ids, out_k, out_v
   kernel :
     func : fused_rotary_position_embedding
     data_type : q
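
The new use_neox_rotary_style flag chooses between two rotation layouts, matching the two CUDA kernels dispatched in the .cu files below: True keeps the existing behavior (rotate adjacent coordinate pairs), False selects the new rotate_half path (rotate the two halves of head_dim). A plain-NumPy sketch of the two layouts for intuition; the function names mirror the kernel names in this diff, not their fused implementation:

import numpy as np

def rotate_every_two(x):
    # use_neox_rotary_style=True: adjacent pairs (x0, x1) -> (-x1, x0).
    out = np.empty_like(x)
    out[..., 0::2] = -x[..., 1::2]
    out[..., 1::2] = x[..., 0::2]
    return out

def rotate_half(x):
    # use_neox_rotary_style=False: halves (x_lo, x_hi) -> (-x_hi, x_lo).
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)

def apply_rope(x, sin, cos, use_neox_rotary_style=True):
    # Both layouts compute x * cos + rotate(x) * sin; only the pairing of
    # coordinates (and the expected sin/cos layout) differs.
    rot = rotate_every_two(x) if use_neox_rotary_style else rotate_half(x)
    return x * cos + rot * sin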

paddle/phi/infermeta/backward.cc

Lines changed: 2 additions & 0 deletions

@@ -1219,9 +1219,11 @@ void IndexPutGradInferMeta(const MetaTensor& x,

 void FusedRopeGradInferMeta(const MetaTensor& sin,
                             const MetaTensor& cos,
+                            const MetaTensor& position_ids,
                             const MetaTensor& dout_q,
                             const MetaTensor& dout_k,
                             const MetaTensor& dout_v,
+                            bool use_neox_rotary_style,
                             MetaTensor* dq,
                             MetaTensor* dk,
                             MetaTensor* dv) {

paddle/phi/infermeta/backward.h

Lines changed: 2 additions & 0 deletions

@@ -186,9 +186,11 @@ void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,

 void FusedRopeGradInferMeta(const MetaTensor& sin,
                             const MetaTensor& cos,
+                            const MetaTensor& position_ids,
                             const MetaTensor& dout_q,
                             const MetaTensor& dout_k,
                             const MetaTensor& dout_v,
+                            bool use_neox_rotary_style,
                             MetaTensor* dq,
                             MetaTensor* dk,
                             MetaTensor* dv);

paddle/phi/infermeta/multiary.cc

Lines changed: 2 additions & 0 deletions

@@ -4041,6 +4041,8 @@ void FusedRopeInferMeta(const MetaTensor& q,
                         const MetaTensor& v,
                         const MetaTensor& sin,
                         const MetaTensor& cos,
+                        const MetaTensor& position_ids,
+                        bool use_neox_rotary_style,
                         MetaTensor* out_q,
                         MetaTensor* out_k,
                         MetaTensor* out_v) {

paddle/phi/infermeta/multiary.h

Lines changed: 2 additions & 0 deletions

@@ -807,6 +807,8 @@ void FusedRopeInferMeta(const MetaTensor& q,
                         const MetaTensor& v,
                         const MetaTensor& sin,
                         const MetaTensor& cos,
+                        const MetaTensor& position_ids,
+                        bool use_neox_rotary_style,
                         MetaTensor* out_q,
                         MetaTensor* out_k,
                         MetaTensor* out_v);

paddle/phi/kernels/fusion/gpu/fused_rope_grad_kernel.cu

Lines changed: 36 additions & 12 deletions

@@ -27,9 +27,11 @@ template <typename T, typename Context>
 void FusedRopeGradKernel(const Context& dev_ctx,
                          const paddle::optional<DenseTensor>& sin,
                          const paddle::optional<DenseTensor>& cos,
+                         const paddle::optional<DenseTensor>& position_ids,
                          const DenseTensor& dout_q,
                          const paddle::optional<DenseTensor>& dout_k,
                          const paddle::optional<DenseTensor>& dout_v,
+                         bool use_neox_rotary_style,
                          DenseTensor* dq,
                          DenseTensor* dk,
                          DenseTensor* dv) {
@@ -58,6 +60,7 @@ void FusedRopeGradKernel(const Context& dev_ctx,
   phi::Array<T*, 3> outs_data;
   phi::Array<const T*, 3> ins_data;
   phi::Array<const T*, 2> sin_cos_data;
+  const int64_t* position_ids_data = NULL;

   ins_data[0] = dout_q.data<T>();
   outs_data[0] = dq->data<T>();
@@ -86,21 +89,42 @@ void FusedRopeGradKernel(const Context& dev_ctx,
     sin_cos_data[1] = cos->data<T>();

     flag_sin_cos = true;
+
+    if (position_ids.get_ptr()) {
+      position_ids_data = position_ids->data<int64_t>();
+    }
   }

   int sign = -1;
-  VectorizedFusedRopeKernel<T, MPType, vec_size>
-      <<<grid, block, 0, stream>>>(ins_data,
-                                   sin_cos_data,
-                                   flag_sin_cos,
-                                   sign,
-                                   batch_size,
-                                   seq_len,
-                                   num_heads,
-                                   head_dim,
-                                   outs_data,
-                                   num_inputs,
-                                   div_c);
+  if (use_neox_rotary_style) {
+    VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
+        <<<grid, block, 0, stream>>>(ins_data,
+                                     sin_cos_data,
+                                     position_ids_data,
+                                     flag_sin_cos,
+                                     sign,
+                                     batch_size,
+                                     seq_len,
+                                     num_heads,
+                                     head_dim,
+                                     outs_data,
+                                     num_inputs,
+                                     div_c);
+  } else {
+    VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size>
+        <<<grid, block, 0, stream>>>(ins_data,
+                                     sin_cos_data,
+                                     position_ids_data,
+                                     flag_sin_cos,
+                                     sign,
+                                     batch_size,
+                                     seq_len,
+                                     num_heads,
+                                     head_dim,
+                                     outs_data,
+                                     num_inputs,
+                                     div_c);
+  }
 }

 }  // namespace fusion
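
Note that the backward dispatches the same two rotation kernels as the forward, only with sign = -1: the map applied by RoPE is an orthogonal rotation, so its gradient is the same rotation by the negated angle applied to the incoming gradient. A self-contained NumPy check of that identity for the rotate_half layout (illustrative only; this reads the sign convention off the diff, it is not the fused kernel):

import numpy as np

def rotate_half(x):
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)

rng = np.random.default_rng(0)
d = 8
theta = rng.normal(size=d // 2)
sin = np.sin(np.concatenate([theta, theta]))
cos = np.cos(np.concatenate([theta, theta]))

def rope(x):        # forward, sign = +1
    return x * cos + rotate_half(x) * sin

def rope_grad(g):   # backward, sign = -1: the same formula with -sin
    return g * cos - rotate_half(g) * sin

dout = rng.normal(size=d)
# The forward is linear, y = R @ x; the gradient is R^T @ dout. Recover
# R^T @ dout column by column from the forward map and compare.
r_t_dout = np.array([dout @ rope(np.eye(d)[i]) for i in range(d)])
assert np.allclose(rope_grad(dout), r_t_dout)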

paddle/phi/kernels/fusion/gpu/fused_rope_kernel.cu

Lines changed: 78 additions & 21 deletions

@@ -30,6 +30,8 @@ void FusedRopeKernel(const Context& dev_ctx,
                      const paddle::optional<DenseTensor>& v,
                      const paddle::optional<DenseTensor>& sin,
                      const paddle::optional<DenseTensor>& cos,
+                     const paddle::optional<DenseTensor>& position_ids,
+                     bool use_neox_rotary_style,
                      DenseTensor* out_q,
                      DenseTensor* out_k,
                      DenseTensor* out_v) {
@@ -59,6 +61,7 @@ void FusedRopeKernel(const Context& dev_ctx,
   phi::Array<T*, 3> outs_data;
   phi::Array<const T*, 3> ins_data;
   phi::Array<const T*, 2> sin_cos_data;
+  const int64_t* position_ids_data = NULL;

   ins_data[0] = q.data<T>();
   outs_data[0] = out_q->data<T>();
@@ -109,15 +112,52 @@ void FusedRopeKernel(const Context& dev_ctx,
             "The batch_size and num_heads of sin and cos must be 1."));
       }
       int sin_seq_len_dim = (dims_size) == 4 ? 1 : 0;
-      PADDLE_ENFORCE_EQ((sin_dims[dims_size - 1] == head_dim &&
-                         sin_dims[sin_seq_len_dim] == seq_len),
-                        true,
-                        phi::errors::InvalidArgument(
-                            "The seq_len and head_dim of sin and cos "
-                            "must be the same as those of q. But recieved sin's "
-                            "shape is {%s}, q's shape is {%s}.",
-                            sin_dims,
-                            q.dims()));
+
+      if (position_ids.get_ptr()) {
+        PADDLE_ENFORCE_EQ(
+            (sin_dims[dims_size - 1] == head_dim &&
+             sin_dims[sin_seq_len_dim] >= seq_len),
+            true,
+            phi::errors::InvalidArgument(
+                "The seq_len of sin and cos must be greater than or equal to "
+                "this of q. The head_dim of sin and cos must be the same as this "
+                "of q. But recieved sin's "
+                "shape is {%s}, q's shape is {%s}.",
+                sin_dims,
+                q.dims()));
+
+        auto position_ids_dims = position_ids.get_ptr()->dims();
+        PADDLE_ENFORCE_EQ(position_ids_dims.size(),
+                          2,
+                          phi::errors::InvalidArgument(
+                              "The dims of position_ids is expected to "
+                              "be 2, but recieved %d.",
+                              position_ids_dims.size()));
+
+        PADDLE_ENFORCE_EQ(
+            (position_ids_dims[0] == batch_size &&
+             position_ids_dims[1] == seq_len),
+            true,
+            phi::errors::InvalidArgument(
+                "The batch_size and seq_len of position_ids must be the same as "
+                "those of q. But recieved position_ids's "
+                "shape is {%s}, q's shape is {%s}.",
+                position_ids_dims,
+                q.dims()));
+
+        position_ids_data = position_ids->data<int64_t>();
+      } else {
+        PADDLE_ENFORCE_EQ(
+            (sin_dims[dims_size - 1] == head_dim &&
+             sin_dims[sin_seq_len_dim] == seq_len),
+            true,
+            phi::errors::InvalidArgument(
+                "The seq_len and head_dim of sin and cos "
+                "must be the same as those of q. But recieved sin's "
+                "shape is {%s}, q's shape is {%s}.",
+                sin_dims,
+                q.dims()));
+      }

       sin_cos_data[0] = sin->data<T>();
       sin_cos_data[1] = cos->data<T>();
@@ -126,18 +166,35 @@ void FusedRopeKernel(const Context& dev_ctx,
   }

   int sign = 1;
-  VectorizedFusedRopeKernel<T, MPType, vec_size>
-      <<<grid, block, 0, stream>>>(ins_data,
-                                   sin_cos_data,
-                                   flag_sin_cos,
-                                   sign,
-                                   batch_size,
-                                   seq_len,
-                                   num_heads,
-                                   head_dim,
-                                   outs_data,
-                                   num_inputs,
-                                   div_c);
+  if (use_neox_rotary_style) {
+    VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
+        <<<grid, block, 0, stream>>>(ins_data,
+                                     sin_cos_data,
+                                     position_ids_data,
+                                     flag_sin_cos,
+                                     sign,
+                                     batch_size,
+                                     seq_len,
+                                     num_heads,
+                                     head_dim,
+                                     outs_data,
+                                     num_inputs,
+                                     div_c);
+  } else {
+    VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size>
+        <<<grid, block, 0, stream>>>(ins_data,
+                                     sin_cos_data,
+                                     position_ids_data,
+                                     flag_sin_cos,
+                                     sign,
+                                     batch_size,
+                                     seq_len,
+                                     num_heads,
+                                     head_dim,
+                                     outs_data,
+                                     num_inputs,
+                                     div_c);
+  }
 }
 }  // namespace fusion
 }  // namespace phi
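
The relaxed shape check above (sin/cos seq_len only needs to be >= q's seq_len when position_ids is passed) reflects how the kernels consume the tables: with position_ids, each token gathers its sin/cos row by absolute position instead of reading rows sequentially, which is what decoding against a KV cache needs. A NumPy sketch of that gather semantics (an illustration of the indexing only, not the CUDA kernel):

import numpy as np

batch, seq_len, num_heads, head_dim = 2, 4, 2, 8
max_pos = 16  # the sin/cos tables may cover more positions than seq_len

# Tables indexed by absolute position: [max_pos, head_dim].
table_angles = np.random.default_rng(0).normal(size=(max_pos, head_dim))
sin_table, cos_table = np.sin(table_angles), np.cos(table_angles)

# position_ids: [batch_size, seq_len], int64 -- e.g. offsets into a KV cache.
position_ids = np.array([[0, 1, 2, 3], [7, 8, 9, 10]], dtype=np.int64)

# Gather per-token rows, then broadcast over heads:
# sin/cos become [batch, seq_len, 1, head_dim].
sin = sin_table[position_ids][:, :, None, :]
cos = cos_table[position_ids][:, :, None, :]

q = np.random.default_rng(1).normal(size=(batch, seq_len, num_heads, head_dim))

def rotate_half(x):
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)

out_q = q * cos + rotate_half(q) * sin  # rotate_half path (use_neox_rotary_style=False)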
