Add a toy llama test for numerics
rsuderman committed Feb 24, 2025
1 parent 1fbac63 commit 2cefce1
Showing 2 changed files with 58 additions and 4 deletions.
sharktank/sharktank/models/llama/toy_llama.py (12 changes: 8 additions & 4 deletions)
@@ -18,10 +18,8 @@
 parser.add_argument("-o", "--output", default="/tmp/toy_llama.irpa")


-def main():
-    args = parser.parse_args()
-    torch.manual_seed(args.seed)
-
+def generate(seed):
+    torch.manual_seed(seed)
     dtype = torch.float16
     block_seq_stride = 16
     max_blocks = 8
@@ -56,6 +54,12 @@ def main():
         config=config,
         vocab_size=vocabulary_size,
     )
+    return theta, config
+
+
+def main():
+    args = parser.parse_args()
+    theta, config = generate(args.seed)

     config_dict = config.hp.to_gguf_props()
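The first change factors model construction out of `main()` into a reusable `generate(seed)` that returns the toy model's `theta` and `config`, so tests can build the model in-process rather than going through the CLI and an IRPA file. A minimal usage sketch (the names are the ones introduced in this commit; the seed value is illustrative):

from sharktank.models.llama.llama import PagedLlamaModelV1
from sharktank.models.llama.toy_llama import generate

# generate() calls torch.manual_seed internally, so a fixed seed
# reproduces the same toy weights and config every time.
theta, config = generate(12345)
model = PagedLlamaModelV1(theta=theta, config=config)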
sharktank/tests/models/llama/test_llama.py (50 changes: 50 additions & 0 deletions)
@@ -0,0 +1,50 @@
+# Copyright 2025 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+
+from sharktank.models.llama.llama import PagedLlamaModelV1
+from sharktank.models.llama.toy_llama import generate
+from sharktank.utils.create_cache import create_paged_kv_cache
+
+import pytest
+import torch
+
+def test_llama():
+    theta, config = generate(12345)
+    model = PagedLlamaModelV1(theta=theta, config=config)
+
+    ids = [0, 208, 214, 29, 19, 86, 176, 120, 120, 80, 120, 208, 37, 157, 191, 137]
+    seq_len = len(ids)
+
+    blocks = (seq_len - 1) // config.block_seq_stride
+    blocks = blocks + 1
+    padded_length = blocks * config.block_seq_stride
+    padding = padded_length - seq_len
+    ids = ids + [0] * padding
+
+    ids = torch.asarray([ids], dtype=torch.int64)
+    block_ids = torch.asarray([[i for i in range(blocks)]]).to(torch.int64)
+
+    cache = create_paged_kv_cache(config)
+    cache_state = model.cache.allocate(
+        page_count=config.hp.context_length // config.block_seq_stride
+    )
+
+    logits = model.prefill(
+        tokens=ids,
+        attention_mask=None,
+        cache_state=cache_state,
+        seq_block_ids=block_ids
+    )
+
+    # Remove padding
+    ids = ids[:, :seq_len]
+    logits = logits[:, :seq_len, :]
+
+    ids = ids[0, 1:]
+    logits = logits[0, :-1]
+    cross_entropy = torch.nn.functional.cross_entropy(logits, ids)
+    assert pytest.approx(0.577, 1e-2) == cross_entropy
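The block arithmetic in the test rounds the prompt up to a whole number of cache blocks with a ceil-division idiom. Worked through with this commit's values (a 16-token prompt and block_seq_stride = 16 from toy_llama.py), the prompt fills exactly one block and no padding is appended:

# Variable names mirror the test; values come from this commit.
seq_len = 16                                     # len(ids)
block_seq_stride = 16                            # set in generate()

blocks = (seq_len - 1) // block_seq_stride + 1   # ceil(16 / 16) = 1
padded_length = blocks * block_seq_stride        # 1 * 16 = 16
padding = padded_length - seq_len                # 0: already block-aligned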

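The final assertion is a next-token cross-entropy check: the logits at position t are scored against the token at position t + 1, which is why the test drops the last logit row and the first id before calling cross_entropy. A self-contained sketch of that alignment on stand-in tensors (shapes and values are illustrative, not the model's):

import torch

logits = torch.randn(1, 5, 7)           # (batch, seq, vocab) stand-in
ids = torch.tensor([[3, 1, 4, 1, 5]])   # (batch, seq) stand-in

shifted_logits = logits[0, :-1]         # positions 0..3 predict tokens 1..4
targets = ids[0, 1:]                    # tokens 1..4
loss = torch.nn.functional.cross_entropy(shifted_logits, targets)

Because generate(12345) makes the toy weights deterministic, the test can pin this loss to roughly 0.577 within a relative tolerance of 1e-2 via pytest.approx.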