diff --git a/lit_llama/utils.py b/lit_llama/utils.py index a09ada20..bc224576 100644 --- a/lit_llama/utils.py +++ b/lit_llama/utils.py @@ -28,7 +28,7 @@ def llama_model_lookup(checkpoint: dict) -> str: Checks the width of the lm_head.weight matrix, as these uniquely identify the model. """ - embedding_size = checkpoint['transformer.wte.weight'].shape[1] + embedding_size = checkpoint['lm_head.weight'].shape[1] return llama_model_sizes[embedding_size]