diff --git a/classifier.py b/classifier.py index d660aed..67751db 100644 --- a/classifier.py +++ b/classifier.py @@ -6,9 +6,10 @@ from config import LlamaConfig from llama import load_pretrained from tokenizer import Tokenizer +from typing import List class LlamaZeroShotClassifier(torch.nn.Module): - def __init__(self, config: LlamaConfig, tokenizer: Tokenizer, label_names: list[str]): + def __init__(self, config: LlamaConfig, tokenizer: Tokenizer, label_names: List[str]): super(LlamaZeroShotClassifier, self).__init__() self.num_labels = config.num_labels self.llama = load_pretrained(config.pretrained_model_path) diff --git a/llama.py b/llama.py index f101fef..39f75a6 100644 --- a/llama.py +++ b/llama.py @@ -191,9 +191,9 @@ def forward(self, x): 1) layer normalization of the input (via Root Mean Square layer normalization) 2) self-attention on the layer-normalized input 3) a residual connection (i.e., add the input to the output of the self-attention) - 3) layer normalization on the output of the self-attention - 4) a feed-forward network on the layer-normalized output of the self-attention - 5) add a residual connection from the unnormalized self-attention output to the + 4) layer normalization on the output of the self-attention + 5) a feed-forward network on the layer-normalized output of the self-attention + 6) add a residual connection from the unnormalized self-attention output to the output of the feed-forward network ''' # todo diff --git a/rope_test.py b/rope_test.py index 24dc031..07af55c 100644 --- a/rope_test.py +++ b/rope_test.py @@ -2,6 +2,7 @@ import numpy as np from rope import apply_rotary_emb +from typing import Tuple seed = 0 @@ -17,7 +18,7 @@ def construct_key() -> torch.Tensor: ''' return 3 * torch.ones([1, 2, 2, 4]) -def test_apply_rotary_emb() -> tuple[torch.Tensor, torch.Tensor]: +def test_apply_rotary_emb() -> Tuple[torch.Tensor, torch.Tensor]: rng = np.random.default_rng(seed) torch.manual_seed(seed) model = torch.nn.Linear(3, 2, bias=False) diff --git a/structure.md b/structure.md index 8d126c5..1ee78c7 100644 --- a/structure.md +++ b/structure.md @@ -37,9 +37,9 @@ The desired outputs are ### To be implemented Components that require your implementations are comment with ```#todo```. The detailed instructions can be found in their corresponding code blocks -* ```llama.Attention.forward``` +* ```llama.Attention.compute_query_key_value_scores``` * ```llama.RMSNorm.norm``` -* ```llama.Llama.forward``` +* ```llama.LlamaLayer.forward``` * ```llama.Llama.generate``` * ```rope.apply_rotary_emb``` (this one may be tricky! you can use `rope_test.py` to test your implementation) * ```optimizer.AdamW.step```