Hi,

Thanks for open-sourcing your code and model weights. As the title says, I am trying to use the TK kernels with the pre-linearized Llama 3.1 8B model, but I am unable to reproduce the numbers from the paper. I am using `hazyresearch/lolcats-llama-3.1-8b-distill` and `hazyresearch/lolcats-llama-3.1-8b-ft-lora` with the script below. I had to make a few changes to the `load_model_from_checkpoint` function, since it expects the checkpoint directory to also contain the config files (a rough sketch of my workaround follows the script). I also had to change the forward call in `linear_window_attention_tk_gen.py`, as it was not calling the TK kernels correctly; the diff of my changes is below. With `attention_type: lolcats_llama_window_tk` I can reproduce the paper's PIQA number, but that attention type does not use the TK kernels, so I switched to `attention_type: lolcats_llama_window_tk_gen`. With that setting, PIQA accuracy drops to 50%, which is basically random chance. I am not sure where the problem is.
```bash
python lm_eval_harness/eval_lm_harness.py \
--model_type lolcats_ckpt \
--batch_size 256 \
--model_config distill_llama3_1_8b_lk_smd_wtk64_fd64_w01 \
--finetune_config finetune_lora_qkvo_alpaca_clean \
--attn_mlp_checkpoint_path '/root/.cache/huggingface/hub/models--hazyresearch--lolcats-llama-3.1-8b-distill/snapshots/5b4f8f9c7a89b637c12ebd0df19831a40c17e68d/model.pt' \
--finetune_checkpoint_path '/root/.cache/huggingface/hub/models--hazyresearch--lolcats-llama-3.1-8b-ft-lora/snapshots/c8536c21acd0402fc5e1ae2825326a7f59bfb678/model.pt' \
--task piqa --num_shots 0 --no_cache --verbose
```
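For reference, here is roughly what my `load_model_from_checkpoint` workaround looks like. This is only a sketch: the idea is to resolve the model and finetune configs from the repo's `configs/` tree instead of expecting them next to the downloaded checkpoint. The exact file names and the use of OmegaConf are assumptions based on the CLI flags above, not necessarily the upstream layout.

```python
# Sketch of the workaround (hypothetical paths/names): load the YAML configs
# from the repo's configs/ directory rather than from the checkpoint dir.
from omegaconf import OmegaConf

# Config file names inferred from --model_config / --finetune_config above;
# they may differ in the actual repo layout.
model_config = OmegaConf.load(
    'configs/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml'
)
finetune_config = OmegaConf.load(
    'configs/experiment/finetune_lora_qkvo_alpaca_clean.yaml'
)

# With the configs loaded separately, the checkpoint paths passed on the CLI
# only need to point at the model.pt files themselves.
```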
```diff
- tk_window_hedgehog_attention(q.contiguous(), k.contiguous(), v.contiguous(),
-                              self.y_true, self.k_state, self.kv_state,
-                              q_map, k_map, alphas, betas)
+ self.y_true, self.kv_state, self.k_state = tk_window_hedgehog_attention(
+     q.contiguous(), k.contiguous(), v.contiguous(),
+     q_map, k_map,
+     alphas, betas,
+ )
- past_key_value.update_with_kv(self.kv_state, self.k_state.unsqueeze(-2), k, v, self.layer_idx)
+ past_key_value.update_with_kv(self.kv_state, self.k_state, k, v, self.layer_idx)
```
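In case it helps with debugging, this is the kind of sanity check I would run to localize the divergence: feed identical inputs through attention layers built with the two `attention_type` settings and compare their outputs. This is only a sketch; `attn_tk` and `attn_tk_gen` are placeholders for the two constructed modules, not names from the repo, and the call convention is illustrative.

```python
# Hypothetical parity check between the two attention implementations.
# attn_tk / attn_tk_gen stand in for layers built with
# attention_type = lolcats_llama_window_tk and lolcats_llama_window_tk_gen.
import torch

torch.manual_seed(0)
# Llama 3.1 8B hidden size is 4096; batch 1, short sequence for a quick check.
hidden = torch.randn(1, 128, 4096, dtype=torch.bfloat16, device='cuda')

with torch.no_grad():
    out_tk = attn_tk(hidden)[0]       # path that reproduces the paper's PIQA number
    out_gen = attn_tk_gen(hidden)[0]  # path where PIQA drops to ~50%

# A large max difference here would confirm the _tk_gen forward pass itself
# (rather than eval-harness plumbing) is where things go wrong.
print((out_tk.float() - out_gen.float()).abs().max())
```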