Skip to content

Commit 1bc6a87

Browse files
committed
添加rope
1 parent 247636a commit 1bc6a87

File tree

3 files changed

+32
-80
lines changed

3 files changed

+32
-80
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ itertools = "0.13"
4141
env_logger = "0.11"
4242
build-script-cfg = "0.0"
4343

44-
operators = { git = "https://github.com/onenewcode/operators-rs", rev = "f4a83f7", default-features = false }
44+
operators = { path = "/home/ztf/operators-rs/operators"}
4545

4646
search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "f69b160" }
4747
search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "e8362c3" }

models/llama/common/src/compute.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,7 @@ where
480480
cos_layout: cos.layout(),
481481
cos_base: cos.base(),
482482
theta: self.meta.theta,
483+
rope_type: rope::RopeType::Rope,
483484
},
484485
workspace,
485486
queue_alloc,

models/minicpm3/common/src/compute.rs

Lines changed: 30 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ use operators::{
99
add::{self, Add},
1010
all_reduce::{self, AllReduce, ReduceOp},
1111
attention::{self, Attention},
12-
attention_kv_cached::{AttnKVCached},
12+
attention_kv_cached::AttnKVCached,
1313
fuesd_softmax::AttnMask,
1414
mat_mul::{self, MatMul},
1515
rearrange::{self, Rearrange},
1616
rms_norm::{self, RmsNorm},
17-
rope::{self, Rope, SinCosTable},
17+
rope::{self, Rope, Seq, SinCosTable},
1818
swiglu::{self, Swiglu},
1919
ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode, Workspace,
2020
};
@@ -151,7 +151,6 @@ where
151151
dkv_lora,
152152
dv,
153153
dt_embd,
154-
155154
..
156155
} = self.meta;
157156
// llama.cpp 定义死
@@ -175,8 +174,22 @@ where
175174
let attn = tensor(&[nt, nh, dv]);
176175
let (buf, workspace) = workspace.split_at_mut(*attn.get());
177176
let mut attn = attn.map(|_| buf);
178-
179177
let queue = queue_alloc.queue();
178+
179+
let sin = sin_cos.clone().index(0, 0);
180+
let cos = sin_cos.index(0, 1);
181+
182+
let pos = Tensor::new(self.dt_pos, &[nt]).map(|_| {
183+
Ops::Rope::build_pos(
184+
self.dt_pos,
185+
nt,
186+
requests.iter().map(|req| Seq {
187+
pos: req.pos,
188+
len: req.seq_len,
189+
}),
190+
queue_alloc,
191+
)
192+
});
180193
// 缩放
181194
let inplace = unsafe { x.map_slice_static() };
182195
self.scale(&mut x, &inplace, scale_emb, workspace, queue_alloc)?;
@@ -232,95 +245,26 @@ where
232245

233246
split_mut!(kv => k_nope ,v ; [dnope , dv ] @ 2);
234247

235-
/// longrope
236-
pub fn longrope(
237-
embd: &mut [f32],
238-
pos: f32,
239-
theta: f32,
240-
long_factor: &[f32],
241-
short_factor: &[f32],
242-
max_pos: f32,
243-
origin_max_pos: f32,
244-
) {
245-
use std::slice::from_raw_parts_mut;
246-
// 计算 scaling_factor
247-
let scaling_factor =
248-
1.0 + ((max_pos / origin_max_pos).ln() / origin_max_pos.ln()).sqrt();
249-
let factor = if pos > origin_max_pos {
250-
long_factor
251-
} else {
252-
short_factor
253-
};
254-
let dh = embd.len() / 2;
255-
let embd =
256-
unsafe { from_raw_parts_mut(embd.as_mut_ptr().cast::<[f32; 2]>(), dh) };
257-
for (i, pair) in embd.iter_mut().enumerate() {
258-
let theta = theta.powf(-(i as f32 / dh as f32));
259-
let freq = pos * theta * factor.get(i).unwrap().recip();
260-
let (sin, cos) = freq.sin_cos();
261-
let (sin, cos) = (sin * scaling_factor, cos * scaling_factor);
262-
let [a, b] = *pair;
263-
*pair = [a * cos - b * sin, a * sin + b * cos];
264-
}
265-
}
266-
let cast = |t: *const f32| -> &'static [f32] {
267-
unsafe { std::slice::from_raw_parts(t, dh / 2) }
268-
};
269-
let [long_factor, short_factor] = self.weights.factor(queue);
270-
let long_factor = cast(long_factor.base().cast());
271-
let short_factor = cast(short_factor.base().cast());
272-
273248
// k [1, 3840]
274249
let k = tensor(&[nt, nh, dk]);
275250
let (buf, workspace) = workspace.split_at_mut(*k.get());
276251
let k = k.map(|_| buf);
277252

278253
split_mut!(k => k_nope_r ,k_rope_r ; [dnope, dh] @ 2);
279254

280-
let pos = requests.last().unwrap().pos as f32;
281-
let (max_pos, origin_max_pos) = (100f32, 100f32);
282-
283-
// q 嵌入
284-
(0..nh).for_each(|i| {
285-
let tmp_q = unsafe {
286-
std::slice::from_raw_parts_mut(
287-
q_rope.base_mut().cast::<f32>().add(i * 32),
288-
32,
289-
)
290-
};
291-
longrope(
292-
tmp_q,
293-
pos,
294-
self.meta.theta,
295-
long_factor,
296-
short_factor,
297-
max_pos,
298-
origin_max_pos,
299-
);
300-
});
301-
// k 嵌入
302-
303-
let k_rope_1 =
304-
unsafe { std::slice::from_raw_parts_mut(k_rope.base_mut().cast::<f32>(), 32) };
305-
longrope(
306-
k_rope_1,
307-
pos,
308-
self.meta.theta,
309-
long_factor,
310-
short_factor,
311-
max_pos,
312-
origin_max_pos,
313-
);
314-
315-
// 进行广播和拷贝
316-
let k_rope = k_rope.tile(1, &[1, dh]).broadcast(1, nh);
255+
self.rope(&mut q_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
256+
let mut k_rope = k_rope.tile(1, &[1, dh]);
257+
self.rope(&mut k_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
258+
let k_rope = k_rope.broadcast(1, nh);
317259
self.rearrange(&mut k_rope_r, &k_rope, workspace, queue_alloc)?;
318260
self.rearrange(&mut k_nope_r, &k_nope, workspace, queue_alloc)?;
319261

262+
let pos = requests.last().unwrap().pos as f32;
320263
let mut q = q3.transpose(&[1, 0]);
321264
let k = k.map_slice().transpose(&[1, 0]);
322265
let v = v.map_slice_mut().transpose(&[1, 0]);
323266
let mut attn = unsafe { attn.map_slice_mut().transpose(&[1, 0]) };
267+
let pos = requests.last().unwrap().pos as f32;
324268
self.attnention(
325269
&mut q,
326270
&k,
@@ -490,6 +434,7 @@ where
490434
Cos: Deref<Target = [ByteOf<Ops::Hardware>]>,
491435
QA: QueueAlloc<Hardware = Ops::Hardware>,
492436
{
437+
let [long, short] = self.weights.factor(queue_alloc.queue());
493438
self.rope.launch(
494439
&rope::Args {
495440
t_layout: t.layout(),
@@ -501,6 +446,12 @@ where
501446
cos_layout: cos.layout(),
502447
cos_base: cos.base(),
503448
theta: self.meta.theta,
449+
rope_type: rope::RopeType::Long {
450+
long: long.base(),
451+
short: short.base(),
452+
max_pos: 100,
453+
origin_pos: 100,
454+
},
504455
},
505456
workspace,
506457
queue_alloc,

0 commit comments

Comments (0)