Skip to content

Commit 551e6d5

Browse files
authored
[dlinfer]rm rope reshape (#2984)
1 parent a6d4b4a commit 551e6d5

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,12 @@ def apply_rotary_pos_emb(
1515
) -> Tuple[Tensor, Tensor]:
1616
query_states = query_states.contiguous()
1717
key_states = key_states.contiguous()
18-
bs = query_states.shape[0]
1918
query_states_reshaped = query_states.unsqueeze(0)
2019
key_states_reshaped = key_states.unsqueeze(0)
21-
cos_reshaped = cos.reshape(1, bs, 1, -1)
22-
sin_reshaped = sin.reshape(1, bs, 1, -1)
2320
query_states_reshaped, key_states_reshaped = \
2421
ext_ops.apply_rotary_pos_emb(query_states_reshaped,
2522
key_states_reshaped,
26-
cos_reshaped, sin_reshaped,
23+
cos, sin,
2724
None, None)
2825
if q_embed is None:
2926
q_embed = query_states_reshaped.view(query_states.shape)

0 commit comments

Comments
 (0)