Commit

more changes
rsuderman committed Feb 25, 2025
1 parent d664d88 commit 9a02c48
Showing 3 changed files with 2 additions and 2 deletions.
1 change: 0 additions & 1 deletion sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -103,7 +103,6 @@ def forward(
     ):
         assert bool(start_index is not None) ^ bool(embedding_batch_mask is not None)

-        print("c", h[0, :, 0])
         x = self.attn_norm(h)
         bs, batch_seq_len, _ = x.shape
1 change: 1 addition & 0 deletions sharktank/sharktank/layers/token_embedding.py
@@ -24,4 +24,5 @@ def __init__(
         self.dtype = dtype

     def forward(self, input: torch.Tensor):
+        print("embedding", self.weight[0, 0])
         return ops.embedding_lookup(input, self.weight, dtype=self.dtype)
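The added line is a temporary debug trace of the first weight element. As an aside, a minimal sketch (not from this repo) of how the same trace can be taken without editing the layer, using a standard PyTorch forward hook; the Embedding module here is an illustrative stand-in for the sharktank layer:

import torch

embedding = torch.nn.Embedding(num_embeddings=32, embedding_dim=8)

def trace_weight(module, inputs, output):
    # Mirrors the print added in this diff: first element of the weight table.
    print("embedding", module.weight[0, 0])

handle = embedding.register_forward_hook(trace_weight)
embedding(torch.tensor([1, 2, 3]))  # prints once per forward call
handle.remove()  # detach the hook once debugging is done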
2 changes: 1 addition & 1 deletion sharktank/sharktank/models/grok/toy_grok.py
@@ -19,7 +19,6 @@


 def generate(seed):
-    torch.manual_seed(seed)
     dtype = torch.float16
     block_seq_stride = 16
     max_blocks = 8
@@ -52,6 +51,7 @@ def generate(seed):
         attention_dtype=dtype,
     )

+    torch.manual_seed(seed)
     theta = make_random_grok_theta(
         config=config,
         vocab_size=vocabulary_size,
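The seed call moves from the top of generate to just before the random theta is built, presumably so the generated weights stay reproducible even if the config setup in between consumes RNG state. A minimal sketch of the idea (illustrative names, not the sharktank API):

import torch

def make_weights():
    # Stand-in for make_random_grok_theta: draws from the global RNG.
    return torch.rand(4)

def generate(seed):
    _ = torch.rand(1)  # setup code that happens to consume RNG state
    torch.manual_seed(seed)  # pin the state right before the draw
    return make_weights()

assert torch.equal(generate(0), generate(0))  # reproducible across calls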
