diff --git a/llama/model.py b/llama/model.py
index f6e6d4f2a..44962740b 100755
--- a/llama/model.py
+++ b/llama/model.py
@@ -190,8 +190,8 @@ def forward(
 
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
-        self.cache_k = self.cache_k.index_copy(1, input_indexes, xk)
-        self.cache_v = self.cache_v.index_copy(1, input_indexes, xv)
+        self.cache_k.index_copy_(1, input_indexes, xk)
+        self.cache_v.index_copy_(1, input_indexes, xv)
 
         keys = self.cache_k[:, :]
         values = self.cache_v[:, :]