diff --git a/models/common/rmsnorm.py b/models/common/rmsnorm.py
index 6078de7102ca..62bf1954d1f5 100644
--- a/models/common/rmsnorm.py
+++ b/models/common/rmsnorm.py
@@ -82,7 +82,7 @@ def __init__(
             torch_weight,
             device=device,
             dtype=weight_dtype,
-            layout=ttnn.ROW_MAJOR_LAYOUT,
+            layout=ttnn.TILE_LAYOUT,
             memory_config=weight_memory_config,
             cache_file_name=cache_name,
             mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None,
@@ -93,7 +93,7 @@ def __init__(
                 torch_weight,
                 device=device,
                 dtype=weight_dtype,
-                layout=ttnn.ROW_MAJOR_LAYOUT,
+                layout=ttnn.TILE_LAYOUT,
                 memory_config=weight_memory_config,
                 cache_file_name=cache_name,
                 mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, 2), mesh_shape=list(device.shape))
@@ -125,6 +125,11 @@ def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) ->
         else:
             assert not out_sharded, "Non-sharded version of RMSNorm cannot output a sharded tensor"
 
+        if x.shape[-1] % weight.shape[-1] == 0:
+            # Reshape the weight to [1, 1, 1, -1] only when x's last dimension is divisible by the
+            # weight's last dimension; skipping the reshape otherwise avoids padding errors in RMSNorm.
+            weight = ttnn.reshape(weight, [1, 1, 1, -1])
+
         x = norm(
             x,
             epsilon=self.eps,
diff --git a/models/experimental/gemma3_1b/tests/test_attention.py b/models/experimental/gemma3_1b/tests/test_attention.py
deleted file mode 100644
index d488150ccd27..000000000000
--- a/models/experimental/gemma3_1b/tests/test_attention.py
+++ /dev/null
@@ -1,277 +0,0 @@
-""" Test for Gemma-3-1b-it Attention """
-# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
-
-# SPDX-License-Identifier: Apache-2.0
-import os
-
-import pytest
-import torch
-from loguru import logger
-
-import ttnn
-from models.experimental.gemma3_1b.tt.attention import Attention
-from models.tt_transformers.tt.common import PagedAttentionConfig, precompute_freqs
-from models.tt_transformers.tt.rope import RotarySetup
-from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull
-
-from models.tt_transformers.tt.model_config import ModelArgs
-
-
-@torch.no_grad()
-@skip_for_grayskull("Requires wormhole_b0 to run")
-@pytest.mark.parametrize(
-    "mesh_device",
-    [
-        {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(
-            os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids())
-        )
-    ],
-    indirect=True,
-)
-@pytest.mark.parametrize(
-    "paged_attention",
-    (
-        True,
-        False,
-    ),
-    ids=(
-        "paged_attention",
-        "default_attention",
-    ),
-)
-@pytest.mark.parametrize(
-    "page_params",
-    [{"page_block_size": 32, "page_max_num_blocks": 1024}],
-)
-@pytest.mark.parametrize(
-    "batch_size",
-    (1,),
-)
-@pytest.mark.parametrize(
-    "max_seq_len",
-    (256,),  # For decode-only unit test, there's no need to run with large sequence lengths
-)
-def test_attention_inference(
-    max_seq_len,
-    batch_size,
-    paged_attention,
-    page_params,
-    mesh_device,
-    reset_seeds,
-):
-    dtype = ttnn.bfloat16
-    pcc = 0.99
-
-    model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len)
-    model_args.n_layers = 1  # For the unit test, just run a single layer
-
-    state_dict = model_args.load_state_dict()
-
-    first_layer_prefix = model_args.get_state_dict_prefix("Attention", 0) + "."
- # Ref model needs partial state dict, but our models use full state dict keys as cached weight names - partial_state_dict = { - k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) - } - - reference_model = model_args.reference_attention() - reference_model.load_state_dict(partial_state_dict) - - seq_len = 1 - - generation_start_pos = 0 - generation_length = 10 - all_tests_pass = True - - # Setup RoPE transformation matrices - rope_setup = RotarySetup( - mesh_device, - batch_size, - model_args.head_dim, - model_args.max_seq_len, - model_args.rope_theta, - model_args.rope_scaling_factor, - model_args.orig_context_len, - ) - - transformation_mats = rope_setup.get_both_trans_mats() - - page_table_tt = None - paged_attention_config = None - - if paged_attention: - paged_attention_config = PagedAttentionConfig( - block_size=page_params["page_block_size"], - max_num_blocks=page_params["page_max_num_blocks"], - ) - - # Implied shuffling of blocks - permutation = torch.randperm(paged_attention_config.max_num_blocks) - # Page table which maps virtual blocks to physical - reverse_permutation = torch.argsort(permutation) - page_table = reverse_permutation.reshape( - model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size - ) - page_table_tt = ttnn.from_torch( - page_table, - device=mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, -2) if (model_args.is_galaxy and batch_size > 1) else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - tt_model = Attention( - mesh_device, - state_dict, - weight_cache_path=model_args.weight_cache_path(dtype), - layer_num=0, - dtype=dtype, - transformation_mats=transformation_mats, - configuration=model_args, - paged_attention_config=paged_attention_config, - ) - - cos, sin = precompute_freqs( - model_args.head_dim, - model_args.max_seq_len * 2, - model_args.rope_theta, - model_args.rope_scaling_factor, - model_args.orig_context_len, - ) - freqs_cis = torch.complex(cos, sin) - - # Initial positions - current_pos = torch.tensor([generation_start_pos for _ in range(batch_size)]) - current_pos_tensor = ttnn.from_torch( - current_pos, - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - for i in range(generation_length): - # 70B attention block typically sees tensors with mean 0 and std 0.03 - 0.05 in layer 1 - pt_attention_input = torch.randn(batch_size, seq_len, model_args.dim) # Qwen2.5 0.5B sees 0.1 to 2.1 - - tt_attention_input = pt_attention_input.clone() - - attention_input = model_args.prepare_residual_tensor_decode( - tt_attention_input, - model_args.model_config["SHARDED_ATTN_INPUT_MEMCFG"], - force_replicated=False if model_args.is_galaxy else True, - ) - - # Get cos/sin matrices for the current position of each user - rot_mats = rope_setup.get_rot_mats(current_pos) - - tt_out = tt_model( - attention_input, - current_pos_tensor, - rot_mats=rot_mats, - mode="decode", - page_table=page_table_tt, - ) - # multi-device attention module returns replicated output - tt_out = ttnn.to_torch( - tt_out, - mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), mesh_shape=model_args.cluster_shape), - ) - tt_output_torch = tt_out[:, 0:1, : model_args.max_batch_size, : model_args.dim].view(-1, 1, model_args.dim) - - 
# In this test all users have the same position (if using batch > 1) - freqs_cis_i = freqs_cis[current_pos[0], :].unsqueeze(0) - - reference_output = reference_model(pt_attention_input, current_pos[0], freqs_cis_i, mask=None) - - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc) - - logger.info(comp_allclose(reference_output, tt_output_torch)) - logger.info(f"PCC: {pcc_message}") - if passing: - logger.info(f"[pos={current_pos[0]}] Attention Passed!") - else: - logger.warning(f"[pos={current_pos[0]}] Attention Failed!") - all_tests_pass = False - - # Increment position - current_pos = torch.tensor([generation_start_pos + i + 1 for _ in range(batch_size)]) - current_pos_tensor = ttnn.from_torch( - current_pos, - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - check_kv_cache = True - if check_kv_cache: - # PyTorch output -------------------------------------------------------------------- - pytorch_layer_present = [ - reference_model.cache_k.clone().permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] - reference_model.cache_v.clone().permute(0, 2, 1, 3), # [batch_size, n_kv_heads, seq, head_dim] - ] - # TT hardware execution ------------------------------------------------------------- - if paged_attention: - tt_layer_present = [ - ( - ttnn.to_torch( - cache, - mesh_composer=ttnn.ConcatMesh2dToTensor( - mesh_device, - dims=(1, 3) if model_args.is_galaxy else (0, 1), - mesh_shape=model_args.cluster_shape, - ), - )[reverse_permutation][:, : model_args.n_kv_heads, :, : model_args.head_dim] - .reshape( - model_args.max_batch_size, - paged_attention_config.max_num_blocks // model_args.max_batch_size, - model_args.n_kv_heads, - paged_attention_config.block_size, - model_args.head_dim, - ) - .transpose(1, 2) - .reshape(model_args.max_batch_size, model_args.n_kv_heads, -1, model_args.head_dim)[ - :batch_size, ... - ] - ) - for cache in tt_model.layer_past - ] - else: - tt_layer_present = [ - ttnn.to_torch( - cache, - mesh_composer=ttnn.ConcatMesh2dToTensor( - mesh_device, - dims=(1, 0) if model_args.is_galaxy else (0, 1), - mesh_shape=model_args.cluster_shape, - ), - )[:batch_size, :, :, :] - for cache in tt_model.layer_past - ] - for label, cache_pt, cache_tt in zip(["K", "V"], pytorch_layer_present, tt_layer_present): - cache_length_to_check = min(model_args.max_seq_len, generation_start_pos + i + 1) - cache_pt = cache_pt[:, :, generation_start_pos:cache_length_to_check, :] - cache_tt = cache_tt[:, :, generation_start_pos:cache_length_to_check, :] - does_pass, output_pcc = comp_pcc(cache_pt, cache_tt, pcc) - logger.info(f"{label} cache output: {output_pcc}") - if does_pass: - logger.info(f"{label} cache Passed!") - else: - logger.warning(f"{label} Cache Failed! PCC value is lower than {pcc}") - all_tests_pass = False - - if all_tests_pass: - logger.info("Attention output Passed!") - else: - logger.warning("Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" diff --git a/models/experimental/gemma3_1b/tests/test_decoder.py b/models/experimental/gemma3_1b/tests/test_decoder.py deleted file mode 100644 index 4bce0d46159f..000000000000 --- a/models/experimental/gemma3_1b/tests/test_decoder.py +++ /dev/null @@ -1,198 +0,0 @@ -""" Test for Gemma-3-1b-it Decoder """ -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
- -# SPDX-License-Identifier: Apache-2.0 - -import torch -import pytest -from loguru import logger -import os -import ttnn -from models.tt_transformers.tt.model_config import ModelArgs -from models.experimental.gemma3_1b.tt.decoder import TransformerBlock -from models.utility_functions import ( - comp_pcc, - comp_allclose, -) -from models.utility_functions import skip_for_grayskull -from models.tt_transformers.tt.common import PagedAttentionConfig -from models.tt_transformers.tt.rope import RotarySetup - - -@torch.no_grad() -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -@pytest.mark.parametrize( - "paged_attention", - (True, False), - ids=("paged_attention", "default_attention"), -) -@pytest.mark.parametrize( - "page_params", - [{"page_block_size": 32, "page_max_num_blocks": 1024}], -) -@pytest.mark.parametrize( - "batch_size", - (1,), -) -@pytest.mark.parametrize( - "max_seq_len", - (256,), # For decode-only unit test, there's no need to run with large sequence lengths -) -def test_decoder_inference( - max_seq_len, - batch_size, - paged_attention, - page_params, - mesh_device, - reset_seeds, -): - dtype = ttnn.bfloat16 - - model_args = ModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) - model_args.n_layers = 1 - - state_dict = model_args.load_state_dict() - - reference_model = model_args.reference_decoder() - - generation_start_pos = 0 - generation_length = 10 - all_tests_pass = True - - rope_setup = RotarySetup( - mesh_device, - batch_size, - model_args.head_dim, - model_args.max_seq_len, - model_args.rope_theta, - model_args.rope_scaling_factor, - model_args.orig_context_len, - ) - transformation_mats = rope_setup.get_both_trans_mats() - - # Prepare page table for paged attention - page_table_tt = None - paged_attention_config = None - - if paged_attention: - paged_attention_config = PagedAttentionConfig( - block_size=page_params["page_block_size"], - max_num_blocks=page_params["page_max_num_blocks"], - ) - # Implied shuffling of blocks - permutation = torch.randperm(paged_attention_config.max_num_blocks) - # Page table which maps virtual blocks to physical - reverse_permutation = torch.argsort(permutation) - page_table = reverse_permutation.reshape( - model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size - ) - page_table_tt = ttnn.from_torch( - page_table, - device=mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, -2) if (model_args.is_galaxy and batch_size > 1) else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - # Initialize TT model - tt_model = TransformerBlock( - args=model_args, - mesh_device=mesh_device, - dtype=dtype, - state_dict=state_dict, - layer_num=0, - weight_cache_path=model_args.weight_cache_path(dtype), - transformation_mats=transformation_mats, - paged_attention_config=paged_attention_config, - ) - - seqlen = 1 - - # Initial positions - current_pos = torch.tensor([generation_start_pos for _ in range(batch_size)]) - current_pos_tensor = ttnn.from_torch( - current_pos, - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - for i in 
range(generation_length): - pt_decode_input = (torch.rand(batch_size, seqlen, model_args.dim) * 2) - 1 - logger.info(f"[Decoder] Generating token {i}") - - tt_decode_input = pt_decode_input.clone() - - decode_input = model_args.prepare_residual_tensor_decode( - tt_decode_input, - # ttnn.DRAM_MEMORY_CONFIG, - model_args.model_config["DECODE_RESIDUAL_MEMCFG"], - ) - - # Get cos/sin matrices for the current position of each user - rot_mat_global = rope_setup.get_rot_mats(current_pos) - rot_mat_local = rope_setup.get_rot_mats(current_pos) - - # Run TT model - tt_out = tt_model( - decode_input, - current_pos_tensor, - rot_mats=[rot_mat_global, rot_mat_local], - mode="decode", - page_table=page_table_tt, - ) - tt_out = ttnn.to_torch( - tt_out, - mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), mesh_shape=model_args.cluster_shape), - ) - - tt_output_torch = tt_out[:, 0:1, : model_args.max_batch_size, : model_args.dim].view(-1, 1, model_args.dim) - - # Reference model - ref_output = reference_model(pt_decode_input, current_pos[0], None, mask=None) - - passing, pcc_message = comp_pcc(ref_output, tt_output_torch) - - logger.info(comp_allclose(ref_output, tt_output_torch)) - logger.info(f"PCC: {pcc_message}") - - if passing: - logger.info("Decoder Block Passed!") - else: - logger.warning("Decoder Block Failed!") - all_tests_pass = False - - # Increment position - current_pos = torch.tensor([generation_start_pos + i + 1 for _ in range(batch_size)]) - current_pos_tensor = ttnn.from_torch( - current_pos, - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - if all_tests_pass: - logger.info(f"All {generation_length} decode iterations Passed!") - else: - logger.warning("One or more iterations of decode Failed!") - assert all_tests_pass, f"PCC value is lower than {0.99} for some of the outputs. Check Warnings!" diff --git a/models/experimental/gemma3_1b/tests/test_mlp.py b/models/experimental/gemma3_1b/tests/test_mlp.py deleted file mode 100644 index 7e25632dfaf8..000000000000 --- a/models/experimental/gemma3_1b/tests/test_mlp.py +++ /dev/null @@ -1,98 +0,0 @@ -""" Test for Gemma-3-1b-it MLP """ -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
- -# SPDX-License-Identifier: Apache-2.0 - -from loguru import logger - -import torch -import pytest -import os -import ttnn - -from models.experimental.gemma3_1b.tt.mlp import MLP -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull - -from models.tt_transformers.tt.model_config import ModelArgs - - -@torch.no_grad() -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("device"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -@pytest.mark.parametrize( - "seq_len", - (1152,), -) -@pytest.mark.parametrize( - "batch_size", - (1,), -) -def test_mlp_inference(seq_len, batch_size, reset_seeds, device): - dtype = ttnn.bfloat16 - mode = "decode" if seq_len <= 32 else "prefill" - - tt_model_args = ModelArgs(device, max_batch_size=batch_size, max_seq_len=128) - - tt_model_args.n_layers = 1 - state_dict = tt_model_args.load_state_dict() - - # # Ref model needs partial state dict, but our models use full state dict keys as cached weight names - first_layer_prefix = "layers.0.feed_forward" - partial_state_dict = { - k[len(first_layer_prefix) + 1 :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) - } - - reference_model = tt_model_args.reference_mlp() # Gemma3 MLP - reference_model.load_state_dict(partial_state_dict) - - tt_model = MLP( - mesh_device=device, - args=tt_model_args, - state_dict=state_dict, - weight_cache_path=tt_model_args.weight_cache_path(dtype), - layer_num=0, - dtype=dtype, - model_config=tt_model_args.get_model_config(), - state_dict_prefix=first_layer_prefix, - ) - torch_input = torch.randn(1, 1, seq_len) - reference_output = reference_model(torch_input) - - tt_input = ttnn.from_torch( - torch_input, - device=device, - mesh_mapper=ttnn.ShardTensor2dMesh( - device, dims=(None, 3) if tt_model_args.is_galaxy else (None, None), mesh_shape=tt_model_args.cluster_shape - ), # When both dims are None, the mapper used is `ReplicateTensorToMesh` - dtype=dtype, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - layout=ttnn.TILE_LAYOUT, - ) - - logger.info("Run MLP") - tt_output = tt_model(tt_input, mode) - - tt_output_torch = ttnn.to_torch( - tt_output, - mesh_composer=ttnn.ConcatMesh2dToTensor(device, dims=(1, 3), mesh_shape=tt_model_args.cluster_shape), - ) - - pcc_required = 0.99 - passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required) - - logger.info(comp_allclose(reference_output, tt_output_torch[0])) - logger.info(f"PCC: {pcc_message}") - if passing: - logger.info("MLP Passed!") - else: - logger.warning("MLP Failed!") - - assert passing, f"MLP output does not meet PCC requirement {pcc_required}: {pcc_message}." diff --git a/models/experimental/gemma3_1b/tests/test_model.py b/models/experimental/gemma3_1b/tests/test_model.py deleted file mode 100644 index f0c3730eda57..000000000000 --- a/models/experimental/gemma3_1b/tests/test_model.py +++ /dev/null @@ -1,334 +0,0 @@ -""" Test for Gemma-3-1b-it End-to-End Model""" -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
- -# SPDX-License-Identifier: Apache-2.0 - -import torch -import pytest -from loguru import logger -import os -import ttnn -from models.tt_transformers.tt.common import ( - encode_prompt_hf, - sample_host, - PagedAttentionConfig, -) -from models.tt_transformers.tt.model_config import DecodersPrecision - -from models.experimental.gemma3_1b.tt.model import Gemma3Transformer -from models.utility_functions import ( - comp_pcc, - comp_allclose, -) -from models.utility_functions import skip_for_grayskull, skip_for_blackhole -from models.tt_transformers.tt.model_config import HfModelWrapper -from models.tt_transformers.tt.model_config import ModelArgs - -import re - - -def parse_chat_output(text): - pattern = r"<\|(?Puser|assistant)\|>\s*(?P.*?)(?=<\|(?:user|assistant|end)\|>|$)" - matches = re.finditer(pattern, text, re.DOTALL) - return [(match.group("role"), match.group("message").strip()) for match in matches] - - -def display_chat(logger, conversation): - for role, message in conversation: - if role == "user": - logger.info(f"👤 User: {message}") - elif role == "assistant": - logger.info(f"🤖 Assistant: {message}") - - -@torch.no_grad() -@skip_for_grayskull("Requires wormhole_b0 to run") -@skip_for_blackhole("Failing on DRAM harvested P100a, see #21419") -@pytest.mark.timeout(1800) -@pytest.mark.models_performance_bare_metal -@pytest.mark.parametrize( - "weights, layers", - [ - ("instruct", None), - ], - ids=["full"], -) -@pytest.mark.parametrize( - "paged_attention", - ( - True, - False, - ), - ids=( - "paged_attention", - "default_attention", - ), -) -@pytest.mark.parametrize( - "page_params", - [{"page_block_size": 32, "page_max_num_blocks": 1024}], -) -@pytest.mark.parametrize( - "batch_size", - (1,), -) -@pytest.mark.parametrize( - "max_seq_len", - (256,), # For decode-only unit test, there's no need to run with large sequence lengths -) -@pytest.mark.parametrize( - "optimizations", - [ - lambda model_args: DecodersPrecision.accuracy(model_args.n_layers, model_args.model_name), # poem - ], - ids=["accuracy"], -) -@pytest.mark.parametrize( - "mesh_device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("MESH_DEVICE"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -def test_model_inference( - weights, - layers, - max_seq_len, - batch_size, - paged_attention, - page_params, - optimizations, - mesh_device, - reset_seeds, - request, - parse_chat=True, -): - run_ref_pt = True # Flag to run reference PyTorch model and compare PCC - - dtype = ttnn.bfloat16 - - test_id = request.node.callspec.id - - mode_accuracy = "accuracy" in test_id - instruct = True if weights == "instruct" else False - - model_args = ModelArgs( - mesh_device=mesh_device, - instruct=instruct, - optimizations=optimizations, - max_seq_len=max_seq_len, - max_batch_size=batch_size, - ) - - # Expected PCC for the model - pcc = 0.86 - - # Number of decode iterations to run for the model - iterations = 20 - - if layers is not None: - model_args.n_layers = layers - state_dict = model_args.load_state_dict() - state_dict_prefix = model_args.get_state_dict_prefix("", None) - - prompts = ["Consider the sequence of prime numbers: 2, 3, 5, 7, count till 100"] * model_args.max_batch_size - - tokenizer = model_args.tokenizer - if instruct: - encoded_prompts = encode_prompt_hf(tokenizer=tokenizer, prompt_text=prompts[0]) - else: - encoded_prompts = [model_args.encode_prompt(prompt, instruct=False) for prompt in prompts] - - if run_ref_pt: - reference_transformer_model = 
model_args.reference_transformer(wrap=False) - reference_model = HfModelWrapper(reference_transformer_model, model_args.head_dim) - logger.info("Finished loading reference model.") - - # Embedding on host - embd = model_args.reference_embedding(reference_transformer_model) - else: - # Embedding on host - embd = model_args.reference_embedding() - generation_start_pos = 0 - generation_length = iterations - - page_table_tt = None - paged_attention_config = None - - # Prepare page table for paged attention - if paged_attention: - paged_attention_config = PagedAttentionConfig( - block_size=page_params["page_block_size"], - max_num_blocks=page_params["page_max_num_blocks"], - ) - # Implied shuffling of blocks - permutation = torch.randperm(paged_attention_config.max_num_blocks) - # Page table which maps virtual blocks to physical - reverse_permutation = torch.argsort(permutation) - page_table = reverse_permutation.reshape( - model_args.max_batch_size, paged_attention_config.max_num_blocks // model_args.max_batch_size - ) - page_table_tt = ttnn.from_torch( - page_table, - device=mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, -2) if batch_size > 1 else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - # Load TTNN model - tt_model = Gemma3Transformer( - args=model_args, - mesh_device=mesh_device, - dtype=dtype, - state_dict=state_dict, - weight_cache_path=model_args.weight_cache_path(dtype), - paged_attention_config=paged_attention_config, - ) - - logger.info("Model and caches loaded.") - - if run_ref_pt: - all_tests_pass = True - - seqlen = 1 # Generating one token per user at a time - batch = model_args.max_batch_size - - # Select the first token from the prompts for initial decoding - encoded_prompts_tensor = torch.tensor(encoded_prompts).unsqueeze(0) # [:,0] - pt_decode_input = embd(encoded_prompts_tensor[:, 0]).view(batch, seqlen, -1) - tt_decode_input = pt_decode_input - - # Keep track of generated outputs to print out later - all_outputs = [] - if run_ref_pt: - all_outputs_ref = [] - - # Initial positions - current_pos = torch.tensor([generation_start_pos for _ in range(batch)]) - current_pos_tensor = ttnn.from_torch( - current_pos, - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - for i in range(generation_length): - logger.info(f"[Model] Generating token {i}") - - decode_input = model_args.prepare_residual_tensor_decode( - tt_decode_input, - model_args.model_config["DECODE_RESIDUAL_MEMCFG"], - ) - - # Get cos/sin matrices for the current position of each user - # rot_mats = tt_model.rope_setup.get_rot_mats(current_pos ) # TODO Fix for sliding window attention #TODO Fix for Gemma3 4B - # rot_mats_local = tt_model.rope_setup_local.get_rot_mats(current_pos) #TODO Fix for sliding window attention #TODO Fix for Gemma3 4B - rot_mats_global = tt_model.rope_setup.get_rot_mats(current_pos) # default - rot_mats_local = tt_model.rope_setup_local.get_rot_mats(current_pos) # default - rot_mats = [rot_mats_global, rot_mats_local] - - # Run TT model - tt_out = tt_model( - decode_input, - current_pos_tensor, - rot_mats=rot_mats, # should contain both for slidig window and without it #TODO Fix for Gemma3 4B - mode="decode", - page_table=page_table_tt, - ) - - # Convert ttnn tensor to torch tensor - mesh_composer = 
ttnn.ConcatMesh2dToTensor( - mesh_device, dims=(3, 1) if model_args.is_galaxy else (1, -1), mesh_shape=model_args.cluster_shape - ) - tt_output_torch = ( - ttnn.to_torch(tt_out, mesh_composer=mesh_composer) - .permute(2, 1, 0, 3) - .squeeze(2)[: model_args.max_batch_size, 0:1, : model_args.vocab_size] - ) - - ttnn.deallocate(tt_out) - - if run_ref_pt: # Run reference model - # In this test all users have the same position - ref_output = reference_model(pt_decode_input, current_pos[0]) - - # Increment position - current_pos = torch.tensor([generation_start_pos + i + 1 for _ in range(batch)]) - current_pos_tensor = ttnn.from_torch( - current_pos, - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), - mesh_shape=model_args.cluster_shape, - ), - ) - - # Append the generated token to the list of outputs /prefill - if i in range(len(encoded_prompts)): - # While in "prefill" mode, use the prompt tokens as the output - all_outputs.append(encoded_prompts[i]) # Update list of TT outputs - if run_ref_pt: - all_outputs_ref.append(encoded_prompts[i]) # Update list of ref outputs - - tt_decode_input = embd(encoded_prompts_tensor[:, i]).view(batch, seqlen, -1) - if run_ref_pt: - pt_decode_input = embd(encoded_prompts_tensor[:, i]).view(batch, seqlen, -1) - else: - # Greedy decode (temperature = 0) the generated token and save it to print out later - if run_ref_pt: - # Sample from reference model first - _, pt_out_tok = sample_host(ref_output, temperature=0, top_p=0.8) - pt_decode_input = embd(pt_out_tok) - all_outputs_ref.append(pt_out_tok.squeeze(1).tolist()[0]) - - # Use the same token for TT model (teacher forcing) - tt_decode_input = pt_decode_input - all_outputs.append(pt_out_tok.squeeze(1).tolist()[0]) - else: - # If not running reference model, sample from TT model directly - _, tt_out_tok = sample_host(tt_output_torch, temperature=0, top_p=0.8) - tt_decode_input = embd(tt_out_tok) - all_outputs.append(tt_out_tok.squeeze(1).tolist()[0]) - - # Measure PCC if also running reference model - if run_ref_pt: - passing, pcc_message = comp_pcc(ref_output, tt_output_torch, pcc) - # Decode the output tokens back to text - decoded_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in all_outputs] - logger.info(f"TTNN Decoded Outputs: {decoded_texts}") - decoded_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in all_outputs_ref] - logger.info(f"Torch Decoded Outputs: {decoded_texts}") - - logger.info(comp_allclose(ref_output, tt_output_torch)) - logger.info(f"PCC: {pcc_message}") - - if passing: - logger.info("Model Passed!") - else: - logger.warning("Model Failed!") - if not passing: - all_tests_pass = False - - if parse_chat: - conversation = parse_chat_output(tokenizer.decode(all_outputs).replace("\n", "\\n")) - display_chat(logger, conversation) - - if run_ref_pt: - if all_tests_pass: - logger.info(f"All {generation_length} decode iterations Passed!") - else: - logger.warning("One or more iterations of decode had bad PCC") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" 
diff --git a/models/experimental/gemma3_1b/tests/test_rmsnorm.py b/models/experimental/gemma3_1b/tests/test_rmsnorm.py deleted file mode 100644 index 54168b9c7b94..000000000000 --- a/models/experimental/gemma3_1b/tests/test_rmsnorm.py +++ /dev/null @@ -1,116 +0,0 @@ -""" Test for Gemma-3-1b-it RMSNorm """ -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -from loguru import logger - -import torch -import pytest -import os - -import ttnn -from models.experimental.gemma3_1b.tt.rmsnorm import RMSNorm -from models.tt_transformers.tt.distributed_norm import DistributedNorm - - -from models.utility_functions import comp_allclose, comp_pcc, skip_for_grayskull - -from models.tt_transformers.tt.model_config import ModelArgs - - -@torch.no_grad() -@skip_for_grayskull("Requires wormhole_b0 to run") -@pytest.mark.parametrize( - "device", - [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( - os.environ.get("device"), len(ttnn.get_device_ids()) - ) - ], - indirect=True, -) -@pytest.mark.parametrize( - "seq_len", - (128,), -) -@pytest.mark.parametrize( - "batch_size", - (1,), -) -def test_rmsnorm_inference(seq_len, batch_size, reset_seeds, device): - dtype = ttnn.bfloat16 - mode = "decode" if seq_len <= 32 else "prefill" - - tt_model_args = ModelArgs( - device, - max_batch_size=batch_size, - max_seq_len=128, - ) - - tt_model_args.n_layers = 1 - state_dict = tt_model_args.load_state_dict() - - reference_model = tt_model_args.reference_rms_norm() # Gemma3 RMSNorm - - state_dict_prefix = tt_model_args.get_state_dict_prefix("", 0) - first_layer_prefix = state_dict_prefix + "attention_norm." - - partial_state_dict = { - k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) - } - - reference_model.load_state_dict(partial_state_dict) - - tt_inner_norm = RMSNorm( - device=device, - dim=tt_model_args.dim, - state_dict=state_dict, - state_dict_prefix=state_dict_prefix, - weight_key="attention_norm", - weight_dtype=dtype, - is_distributed=tt_model_args.is_distributed_norm, - sharded_program_config=tt_model_args.get_model_config()["SHARDED_NORM_ATTN_PRGM_CFG"], - sharded_output_config=tt_model_args.get_model_config()["SHARDED_ATTN_INPUT_MEMCFG"], - ) - - # Wrap it in DistributedNorm - tt_model = DistributedNorm(tt_inner_norm, tt_model_args, TG=tt_model_args.is_galaxy) - - input = torch.rand(1, 1, tt_model_args.dim) - - reference_output = reference_model(input) - - # DistributedNorm inputs are fractured across devices and interleaved in DRAM (for prefill) and L1 (for decode) - tt_input = ttnn.from_torch( - input, - device=device, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, -1), mesh_shape=tt_model_args.cluster_shape), - memory_config=( - tt_model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG - ), - ) - - tt_output = tt_model(tt_input, mode=mode) - - # DistributedNorm outputs are replicated across devices - tt_output_torch = ttnn.to_torch( - tt_output, - mesh_composer=ttnn.ConcatMesh2dToTensor( - device, dims=(0, 2) if tt_model_args.is_galaxy else (2, 0), mesh_shape=tt_model_args.cluster_shape - ), - )[:1, :, :] - - passing, pcc_message = comp_pcc(reference_output, tt_output_torch[0]) - - logger.info(comp_allclose(reference_output, tt_output_torch[0])) - logger.info(f"PCC: {pcc_message}") - - if passing: - logger.info("rms_norm Passed!") - else: - logger.warning("rms_norm Failed!") - - assert passing, 
f"rms_norm output does not meet PCC requirement {0.99}." diff --git a/models/experimental/gemma3_1b/tt/attention.py b/models/experimental/gemma3_1b/tt/attention.py deleted file mode 100644 index 338e52dbec82..000000000000 --- a/models/experimental/gemma3_1b/tt/attention.py +++ /dev/null @@ -1,905 +0,0 @@ -""" -This is the attention implementation of the Gemma-3-1b-it - -We have re-used the Attention implementation of the TT-Transformers with few modifications. -This implementation has Changes in Datatype (Bfloat16) that supports the RMSNorm, -Sliding Window support. - -""" - -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import math - -import torch - -import ttnn -from models.common.lightweightmodule import LightweightModule - -from models.experimental.gemma3_1b.tt.rmsnorm import RMSNorm -from models.tt_transformers.tt.ccl import tt_all_gather, tt_all_reduce -from models.tt_transformers.tt.model_config import OpGroup, TensorGroup - - -class Attention(LightweightModule): - def __init__( - self, - mesh_device, - state_dict, - weight_cache_path, - layer_num, - dtype, - transformation_mats, - configuration, - paged_attention_config=None, - use_paged_kv_cache=False, - ): - super().__init__() - self.is_sliding = bool((layer_num + 1) % configuration.sliding_window_pattern) - - self.state_dict = state_dict - self.mesh_device = mesh_device - self.num_devices = configuration.num_devices - self.TG = self.num_devices == 32 - self.hidden_size = configuration.dim - self.n_heads = configuration.n_heads - self.head_dim = configuration.head_dim - self.max_seq_len = configuration.max_seq_len - self.max_batch_size = configuration.max_batch_size - self.n_kv_heads = configuration.n_kv_heads - self.paged_attention_config = paged_attention_config - self.min_kv_prefill_shard_seqlen = configuration.min_kv_prefill_shard_seqlen - self.ccl_dtype = configuration.ccl_dtype - self.num_reduce_scatter_links = configuration.num_reduce_scatter_links - self.num_all_gather_links = configuration.num_all_gather_links - self.MAX_QKV_MM_SEQ_LEN = configuration.MAX_QKV_MM_SEQ_LEN - self.tile_size = configuration.tile_size - self.rms_norm_add_unit_offset = configuration.rms_norm_add_unit_offset - self.num_device_groups = self.num_devices // self.n_kv_heads - self.num_devices_per_group = self.n_kv_heads if self.TG else self.num_devices - self.batch_size_per_device_group = ( - max(self.max_batch_size // self.num_device_groups, 1) if self.TG else self.max_batch_size - ) - - self.n_local_heads = self.n_heads // self.num_devices_per_group - self.n_local_kv_heads = self.n_kv_heads // self.num_devices_per_group - - self.arch_name = configuration.arch_name - # TODO: Fix this once all-gather supports < tile_size - if self.TG: - weight = torch.zeros(1, 32, 8, 32) - for i in range(32): - col = i % 4 # This determines which group of 8 to select - weight[:, i, :, col * 8 : (col + 1) * 8] = torch.eye(8) - - self.slice_mat = ttnn.from_torch( - weight, - dtype=ttnn.bfloat4_b, - layout=ttnn.TILE_LAYOUT, - device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=1), - ) - user_selection_matrix = torch.eye(8, 8) - user_selection_matrix = torch.nn.functional.pad(user_selection_matrix, (0, 24), "constant", 0) # (8, 32) - user_selection_matrix = [user_selection_matrix] * 4 - user_selection_matrix = torch.block_diag(*user_selection_matrix) # (32, 128) - self.user_selection_matrix = ttnn.from_torch( - user_selection_matrix, - dtype=ttnn.bfloat4_b, - layout=ttnn.TILE_LAYOUT, - 
device=self.mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - - self.dtype = dtype - - self.max_seq_len = configuration.max_seq_len - self.grid_size = configuration.max_grid_size - - self.compute_kernel_config_hifi2 = configuration.compute_kernel_config_hifi2 - self.compute_kernel_config_hifi2_fp16 = configuration.compute_kernel_config_hifi2_fp16 - - self.compute_kernel_config_hifi4 = configuration.compute_kernel_config_hifi4 - - self.transformation_mats = transformation_mats - - self.model_config = configuration.get_model_config() - self.ccl_topology = configuration.ccl_topology() - self.is_multichip = configuration.is_multichip - self.activation_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=layer_num, tensor=TensorGroup.ACTIVATION - ) - self.wqkv_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=layer_num, tensor=TensorGroup.WQKV - ) - self.wo_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=layer_num, tensor=TensorGroup.WO - ) - self.kv_cache_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=layer_num, tensor=TensorGroup.KV_CACHE - ) - self.li_qkv_decode_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration - ) - self.sdpa_decode_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration - ) - self.li_o_decode_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration - ) - self.sdpa_prefill_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration - ) - self.li_qkv_prefill_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration - ) - self.li_o_prefill_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=configuration - ) - - layer_name = configuration.get_state_dict_prefix(self.__class__.__name__, layer_num) - if configuration.dummy_weights or (weight_cache_path is None): - cache_name = lambda _: None - else: - cache_name = lambda name: weight_cache_path / (f"{layer_name}.{name}") - - wq_str = f"{layer_name}.wq" - wk_str = f"{layer_name}.wk" - wv_str = f"{layer_name}.wv" - wo_str = f"{layer_name}.wo" - q_norm_str = f"{layer_name}.q_norm" - k_norm_str = f"{layer_name}.k_norm" - - # Initialize bias tensors as None - self.wqkv_bias_decode = None - self.wqkv_bias_prefill = None - - # Create combined QKV bias if present in state dict - if f"{wq_str}.bias" in self.state_dict: - qkv_bias = torch.concat( - [ - torch.concat( - [ - torch.chunk(self.state_dict[f"{wq_str}.bias"], configuration.num_devices)[i], - torch.chunk(self.state_dict[f"{wk_str}.bias"], configuration.num_devices)[i], - torch.chunk(self.state_dict[f"{wv_str}.bias"], configuration.num_devices)[i], - ], - dim=-1, - ) - for i in range(configuration.num_devices) - ], - dim=-1, - ) - # Prefill can use broadcasting on the bias add so wants a 1d tensor - self.wqkv_bias_prefill = ttnn.as_tensor( - qkv_bias, - device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, 
dim=-1), - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - layout=ttnn.TILE_LAYOUT, - cache_file_name=cache_name("wqkv_bias_prefill_sharded"), - ) - # as_tensor returns (32, dim) which is incorrect, this reshape updates the padded size to the correct size - self.wqkv_bias_prefill = ttnn.reshape( - self.wqkv_bias_prefill, - (1, 1, 1, self.wqkv_bias_prefill.shape[-1]), - (1, 1, self.wqkv_bias_prefill.shape[-2], self.wqkv_bias_prefill.shape[-1]), - ) - - # Broadcasting does not seem to be supported inside execute_trace so expand to the whole batch size - # Create a list of bias tensors for each multiple of tile_size up to max_batch_size - self.wqkv_bias_decode = [] - for batch_size in range( - configuration.tile_size, - configuration.tile_padded_batch_rows + configuration.tile_size, - configuration.tile_size, - ): - qkv_bias_decode = qkv_bias.unsqueeze(0).expand(batch_size, -1) - bias_tensor = ttnn.as_tensor( - qkv_bias_decode, - device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), - dtype=ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - layout=ttnn.TILE_LAYOUT, - cache_file_name=cache_name(f"wqkv_bias_decode_sharded_{batch_size}"), - ) - self.wqkv_bias_decode.append(bias_tensor) - - # when splitting the devices, we need to make sure that the number of heads is divisible by the number of devices - assert self.n_heads % self.num_devices_per_group == 0 - assert self.n_kv_heads % self.num_devices_per_group == 0 - assert configuration.qkv_size % self.num_devices_per_group == 0 - assert configuration.dim % self.num_devices_per_group == 0 - - # wqkv: 4096 x 3072 (2 devices): width-sharded on 12 banks, 3072 over 12 banks. - wqkv_mem_config = configuration.create_dram_sharded_mem_config( - configuration.dim, configuration.qkv_size // configuration.num_devices - ) - - qkv_list = [] - for i in range(self.num_devices_per_group): - # Chunk weights - wq_selected = torch.chunk(self.state_dict[f"{wq_str}.weight"], self.num_devices_per_group, dim=0)[i] - wk_selected = torch.chunk(self.state_dict[f"{wk_str}.weight"], self.num_devices_per_group, dim=0)[i] - wv_selected = torch.chunk(self.state_dict[f"{wv_str}.weight"], self.num_devices_per_group, dim=0)[i] - - # Transpose the selected chunks - wq = torch.transpose(wq_selected, -2, -1) - wk = torch.transpose(wk_selected, -2, -1) - wv = torch.transpose(wv_selected, -2, -1) - - qkv = torch.cat([wq, wk, wv], dim=-1) - qkv_list.append(qkv) - - qkv_cat = torch.cat(qkv_list, dim=-1).unsqueeze(0).unsqueeze(0) - - self.wqkv = ttnn.as_tensor( - qkv_cat, - dtype=self.wqkv_dtype, - layout=ttnn.TILE_LAYOUT, - device=self.mesh_device, - memory_config=ttnn.DRAM_MEMORY_CONFIG if self.TG else wqkv_mem_config, - mesh_mapper=ttnn.ShardTensor2dMesh( - self.mesh_device, dims=(3, 2) if self.TG else (2, 3), mesh_shape=configuration.cluster_shape - ), - cache_file_name=cache_name("wqkv_sharded_2d"), - ) - - def norm_reshard(x, norm, mode): - """Hack until RMSNorm supports height-sharded output config""" - if mode == "decode": - mem_cfg = x.memory_config() - x = ttnn.to_memory_config(x, ttnn.L1_MEMORY_CONFIG, dtype=x.dtype) - x = norm(x, mode) - if mode == "decode": - x = ttnn.to_memory_config(x, mem_cfg, dtype=x.dtype) - return x - - if f"{q_norm_str}.weight" in self.state_dict: - fn_q_norm = RMSNorm( - device=self.mesh_device, - dim=self.head_dim, - eps=configuration.norm_eps, - state_dict=self.state_dict, - state_dict_prefix=None, # we already prefix q_norm_str - weight_cache_path=None if configuration.dummy_weights else 
weight_cache_path, - weight_dtype=ttnn.bfloat16, - weight_key=q_norm_str, - add_unit_offset=self.rms_norm_add_unit_offset, - is_distributed=False, - sharded_program_config=None, # FIXME: add height-sharded support. self.model_config["SHARDED_NORM_ATTN_PRGM_CFG"], - sharded_output_config=None, # FIXME: add height-sharded support. self.model_config["CREATE_QKV_DECODE_SHARD"] - ) - self.q_norm = lambda x, mode: norm_reshard(x, fn_q_norm, mode) - else: - self.q_norm = lambda x, mode: x - - if f"{k_norm_str}.weight" in self.state_dict: - fn_k_norm = RMSNorm( - device=self.mesh_device, - dim=self.head_dim, - eps=configuration.norm_eps, - state_dict=self.state_dict, - state_dict_prefix=None, # we already prefix k_norm_str - weight_cache_path=None if configuration.dummy_weights else weight_cache_path, - weight_dtype=ttnn.bfloat16, - weight_key=k_norm_str, - add_unit_offset=self.rms_norm_add_unit_offset, - is_distributed=False, - sharded_program_config=None, # FIXME: add height-sharded support. self.model_config["SHARDED_NORM_ATTN_PRGM_CFG"], - sharded_output_config=None, # FIXME: add height-sharded support. self.model_config["CREATE_QKV_DECODE_SHARD"], - ) - self.k_norm = lambda x, mode: norm_reshard(x, fn_k_norm, mode) - else: - self.k_norm = lambda x, mode: x - - # For ring topology we can use all gather matmul for wo - self.use_fused_all_gather_matmul = self.model_config["USE_FUSED_ALL_GATHER_MATMUL"] - pt_wo = self.state_dict[f"{wo_str}.weight"].transpose(-1, -2).unsqueeze(0).unsqueeze(0) - - wo_mem_config = configuration.create_dram_sharded_mem_config( - (configuration.n_heads * configuration.head_dim) // configuration.num_devices, configuration.dim - ) - - self.wo = ttnn.as_tensor( - pt_wo, - dtype=self.wo_dtype, - layout=ttnn.TILE_LAYOUT, - device=self.mesh_device, - memory_config=ttnn.DRAM_MEMORY_CONFIG if (self.use_fused_all_gather_matmul or self.TG) else wo_mem_config, - mesh_mapper=ttnn.ShardTensor2dMesh( - self.mesh_device, - dims=(2, 3) if (self.use_fused_all_gather_matmul or self.TG) else (3, 2), - mesh_shape=configuration.cluster_shape, - ), - cache_file_name=( - cache_name("wo_width_sharded_2d") if (self.use_fused_all_gather_matmul or self.TG) else cache_name("wo") - ), - ) - if not use_paged_kv_cache: - # vLLM provides its own kv cache - self.init_kv_cache(configuration, weight_cache_path) - - if configuration.query_pre_attn_scalar is not None: - self.scale = configuration.query_pre_attn_scalar**-0.5 - else: - self.scale = self.head_dim**-0.5 - - def init_kv_cache(self, configuration, weight_cache_path): - """ - Generates empty KV cache and pushed to device memory - """ - - if self.paged_attention_config: - cache_k = torch.zeros( - ( - self.paged_attention_config.max_num_blocks, - self.n_local_kv_heads, - self.paged_attention_config.block_size, - self.head_dim, - ) - ) - cache_v = torch.zeros( - ( - self.paged_attention_config.max_num_blocks, - self.n_local_kv_heads, - self.paged_attention_config.block_size, - self.head_dim, - ) - ) - else: - cache_k = torch.zeros( - ( - self.batch_size_per_device_group, - self.n_local_kv_heads, - self.max_seq_len, - self.head_dim, - ) - ) - cache_v = torch.zeros( - ( - self.batch_size_per_device_group, - self.n_local_kv_heads, - self.max_seq_len, - self.head_dim, - ) - ) - - self.layer_past = [ - ttnn.as_tensor( - k_or_v, - dtype=self.kv_cache_dtype, - layout=self.model_config["ATTN_W_LAYOUT_TILE"], - device=self.mesh_device, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - cache_file_name=( - 
f"{weight_cache_path}/kvcache_{k_or_v.shape}" - if weight_cache_path and not configuration.dummy_weights - else None - ), - ) - for k_or_v in [cache_k, cache_v] - ] - - def forward_decode( - self, - x: ttnn.Tensor, - current_pos, - rot_mats=None, - page_table=None, - kv_cache=None, - ) -> ttnn.Tensor: - """ - x: (seq_len, 1, batch, dim) - current_pos: (batch_size), current token position in the sequence for each user - """ - - ### - # QKV matmuls - # Use HiFi2 for DRAM-sharded matmuls as they are otherwise flop-bound. Loses 1 bit of activation precision. - ### - - xqkv_fused_sharded = ttnn.linear( - x, - self.wqkv, - # bias=self.wqkv_bias, - memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, - program_config=self.model_config["XQKV_DECODE_PROGCFG"], - compute_kernel_config=self.li_qkv_decode_compute_kernel_cfg, - dtype=self.ccl_dtype if self.TG else self.activation_dtype or ttnn.bfloat16, - ) - # FIXME: File bug against dram-sharded matmuls with bias - if self.wqkv_bias_decode: - # select the bias tensor based on the number of tiles in the rows - # WARNING: must not change the batch size between compiling and executing a trace - num_tiles = int(math.ceil(xqkv_fused_sharded.shape[-2] / self.tile_size)) - xqkv_fused_sharded = xqkv_fused_sharded + self.wqkv_bias_decode[num_tiles - 1] - - ttnn.deallocate(x) - xqkv_fused = tt_all_reduce( - xqkv_fused_sharded, - self.mesh_device, - cluster_axis=1, - num_reduce_scatter_links=self.num_reduce_scatter_links, - num_all_gather_links=self.num_all_gather_links, - memory_config=self.model_config["QKV_OUT_GATHERED_MEMCFG"](list(self.mesh_device.shape)[1]), - sharded=True, - dtype=self.ccl_dtype, - topology=self.ccl_topology, - ) - - if self.TG: - # TODO: Slice the fused_query_key_value tensor get batch=8 - xqkv_fused = ttnn.matmul( - self.slice_mat, - xqkv_fused, - dtype=ttnn.bfloat16, - memory_config=self.model_config["CREATE_HEAD_INPUT_MEMCFG"], - ) - else: - # bfloat16 is required by nlp_create_qkv_heads_decode - xqkv_fused = ttnn.sharded_to_interleaved(xqkv_fused_sharded, ttnn.L1_MEMORY_CONFIG, ttnn.bfloat16) - - ttnn.deallocate(xqkv_fused_sharded) - - # Reshape such that true unpadded batch is tracked in shape - fqkv_shape = xqkv_fused.shape - xqkv_fused = ttnn.reshape( - xqkv_fused, (1, 1, self.batch_size_per_device_group, fqkv_shape[3]), (1, 1, 32, fqkv_shape[3]) - ) - - ### - # Reshape and rotary embeddings - ### - ( - q_heads_pre_rot_1BQD, - k_heads_pre_rot_1BKD, - v_heads_1BKD, - ) = ttnn.experimental.nlp_create_qkv_heads_decode( - xqkv_fused, - num_heads=self.n_local_heads, - num_kv_heads=self.n_local_kv_heads, - memory_config=self.model_config["CREATE_QKV_DECODE_SHARD"], - ) - - q_heads_pre_rot_1BQD = self.q_norm(q_heads_pre_rot_1BQD, mode="decode") - k_heads_pre_rot_1BKD = self.k_norm(k_heads_pre_rot_1BKD, mode="decode") - - ttnn.deallocate(xqkv_fused) - - # Q Rotary Embeddings - q_heads_1BQD = ttnn.experimental.rotary_embedding_llama( - q_heads_pre_rot_1BQD, rot_mats[0], rot_mats[1], self.transformation_mats["decode"], is_decode_mode=True - ) - - # K Rotary Embeddings - k_heads_1BKD = ttnn.experimental.rotary_embedding_llama( - k_heads_pre_rot_1BKD, rot_mats[0], rot_mats[1], self.transformation_mats["decode"], is_decode_mode=True - ) - - ttnn.deallocate(q_heads_pre_rot_1BQD) - ttnn.deallocate(k_heads_pre_rot_1BKD) - - ### - # KV update - ### - if kv_cache: - keys = kv_cache[0] - values = kv_cache[1] - else: - keys = self.layer_past[0] - values = self.layer_past[1] - # k_heads, [seqlen, n_kv_heads, bsz, head_dim] - # v_heads [seqlen, 
n_kv_heads, bsz, head_dim] - # keys, [max_batch_size, n_kv_heads // configuration.num_devices, max_seq_len, head_dim] - ttnn.experimental.paged_update_cache(keys, k_heads_1BKD, update_idxs_tensor=current_pos, page_table=page_table) - ttnn.experimental.paged_update_cache( - values, v_heads_1BKD, update_idxs_tensor=current_pos, page_table=page_table - ) - - ttnn.deallocate(k_heads_1BKD) - ttnn.deallocate(v_heads_1BKD) - - # NOTE: Varying the batch size will result in slightly different outputs. - # For example, a prompt w/ 1 user vs, the same prompt repeated N times for N users, will produce different outputs - # This is because the SDPA op in decode mode has different number of reductions depending on batch size - # Which leads to slightly different outputs from attention (due to accumulated errors) - q_heads_1BQD = ttnn.to_memory_config(q_heads_1BQD, ttnn.DRAM_MEMORY_CONFIG) - if page_table: - attn_output_1G4D = ttnn.transformer.paged_scaled_dot_product_attention_decode( - q_heads_1BQD, - keys, - values, - cur_pos_tensor=current_pos, - page_table_tensor=page_table, - scale=self.scale, - program_config=self.model_config["SDPA_DECODE_PROGCFG"], - compute_kernel_config=self.sdpa_decode_compute_kernel_cfg, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) - else: - attn_output_1G4D = ttnn.transformer.scaled_dot_product_attention_decode( - q_heads_1BQD, - keys, - values, - cur_pos_tensor=current_pos, - scale=self.scale, - program_config=self.model_config["SDPA_DECODE_PROGCFG"], - compute_kernel_config=self.sdpa_decode_compute_kernel_cfg, - memory_config=ttnn.DRAM_MEMORY_CONFIG, # FIXME: why not L1 height sharded e.g. SCORES_BATCHED_MM_OUTPUT_MEMCFG? - ) - - ttnn.deallocate(q_heads_1BQD) - - attn_output_11BH = ttnn.to_memory_config( - attn_output_1G4D, - memory_config=self.model_config["SCORES_BATCHED_MM_OUTPUT_MEMCFG"](self.batch_size_per_device_group), - ) - attn_output_cat = ttnn.experimental.nlp_concat_heads_decode( - attn_output_11BH, - num_heads=self.n_local_heads, - ) - ttnn.deallocate(attn_output_11BH) - ttnn.deallocate(attn_output_1G4D) - - if self.use_fused_all_gather_matmul: - attn_output_cat = ttnn.to_memory_config( - attn_output_cat, self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_MEMCFG"] - ) - _, dense_out_sharded, _ = ttnn.experimental.all_gather_matmul( - attn_output_cat, - self.wo, - dim=3, - all_gather_core_grid_offset=(0, 4), - num_links=1, - program_config=self.model_config["ATTN_ALL_GATHER_MATMUL_PROGCFG"], - compute_kernel_config=self.li_o_decode_compute_kernel_cfg, - memory_config_ag=self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_MEMCFG"], - memory_config_mm=self.model_config["DECODE_RESIDUAL_MEMCFG"], - ) - ttnn.deallocate(attn_output_cat) - dense_out_sharded = ttnn.to_memory_config(dense_out_sharded, self.model_config["DECODE_RESIDUAL_MEMCFG"]) - return dense_out_sharded - - else: - attn_output = tt_all_gather( - attn_output_cat, - self.mesh_device, - dim=2, - cluster_axis=1, - num_links=2, - memory_config=self.model_config["GATHER_USERS_MEMCFG"](list(self.mesh_device.shape)[1]), - sharded=True, - # dtype=self.ccl_dtype, # Running bf16 until we have SDPA output bfp8 df; otherwise we have two sharded to interleaved/interleaved to sharded conversions - ) - if self.TG: - attn_output = ttnn.to_memory_config(attn_output, ttnn.L1_MEMORY_CONFIG) - # user_selection_matrix = [1, 1, 32, 128] - # user_selection_matrix @ activation -> [1, 1, 32, 128] * [1, 1, 128, 2048] -> [1, 1, 32, 2048] - attn_output = ttnn.matmul( - self.user_selection_matrix, - attn_output, - 
core_grid=ttnn.CoreGrid(y=4, x=8), - dtype=ttnn.bfloat16, - memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, - ) - - # TODO: Fix this once self.TG supports dram-sharded matmuls - dense_out_sharded = ttnn.matmul( - attn_output, - self.wo, - core_grid=ttnn.CoreGrid(y=4, x=8) if self.TG else None, - program_config=self.model_config["ATTN_OUTPUT_PROGCFG"] if not self.TG else None, - memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, - dtype=ttnn.bfloat8_b if self.TG else ttnn.bfloat16, - compute_kernel_config=self.li_o_decode_compute_kernel_cfg, - ) - - ttnn.deallocate(attn_output_cat) - - # All reduce - dense_out_reduced = tt_all_reduce( - dense_out_sharded, - self.mesh_device, - cluster_axis=0, - num_reduce_scatter_links=self.num_reduce_scatter_links, - num_all_gather_links=self.num_all_gather_links, - dim=0 if (self.TG and self.hidden_size < 8192) else 3, - topology=self.ccl_topology, - memory_config=( - ( - self.model_config["SELF_OUT_REDUCE_SCATTER_MEMCFG"] - if self.hidden_size == 8192 - else self.model_config["SELF_OUT_GATHERED_MEMCFG"](list(self.mesh_device.shape)[0]) - ) - if self.TG - else self.model_config["DECODE_RESIDUAL_MEMCFG"] - ), - sharded=True, - dtype=self.ccl_dtype, - use_composite=True if self.hidden_size == 8192 else False, - ) - - if not self.TG: - dense_out_reduced = ttnn.to_memory_config( - dense_out_reduced, self.model_config["DECODE_RESIDUAL_MEMCFG"] - ) - - return dense_out_reduced - - def forward_prefill( - self, - x_11SH, - rot_mats, - user_id: int = 0, - page_table=None, - chunk_page_table=None, - chunk_start_idx=None, - kv_cache=None, - ): - seq_len = x_11SH.shape[-2] - assert seq_len % 128 == 0 and seq_len > 0, "Seqlen must be divisible by 128" - ### - # QKV matmuls - ### - - # reshaping long sequence to matmul fit on device - if seq_len > self.MAX_QKV_MM_SEQ_LEN: - if seq_len % self.MAX_QKV_MM_SEQ_LEN != 0: - raise ValueError(f"seq_len {seq_len} must be divisible by {self.MAX_QKV_MM_SEQ_LEN}") - x_11SH = ttnn.reshape(x_11SH, [1, seq_len // self.MAX_QKV_MM_SEQ_LEN, self.MAX_QKV_MM_SEQ_LEN, -1]) - - xqkv_fused = ttnn.linear( - x_11SH, - self.wqkv, - dtype=self.ccl_dtype if self.TG else self.activation_dtype or ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - compute_kernel_config=self.li_qkv_prefill_compute_kernel_cfg, - program_config=self.model_config["XQKV_PREFILL_PROGCFG"](seq_len), - ) - - # FIXME: surely ttnn.linear bias should work? 
- if self.wqkv_bias_prefill is not None: - xqkv_fused = xqkv_fused + self.wqkv_bias_prefill - - xqkv_fused = tt_all_reduce( - xqkv_fused, - self.mesh_device, - cluster_axis=1, - num_reduce_scatter_links=self.num_reduce_scatter_links, - num_all_gather_links=self.num_all_gather_links, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - dtype=self.ccl_dtype, - ) - - if seq_len > self.MAX_QKV_MM_SEQ_LEN: - xqkv_fused = ttnn.reshape(xqkv_fused, [1, 1, seq_len, -1]) - - ttnn.deallocate(x_11SH) - - # split qkv into heads - ( - q_heads_1QSD_pre_rot, - k_heads_1KSD_pre_rot, - v_heads_1VSD, - ) = ttnn.experimental.nlp_create_qkv_heads( - xqkv_fused, - num_heads=self.n_local_heads, - num_kv_heads=self.n_local_kv_heads, - transpose_k_heads=False, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) - - q_heads_1QSD_pre_rot = self.q_norm(q_heads_1QSD_pre_rot, mode="prefill") - k_heads_1KSD_pre_rot = self.k_norm(k_heads_1KSD_pre_rot, mode="prefill") - - ttnn.deallocate(xqkv_fused) - - ### - # Rotary embeddings - ### - - if q_heads_1QSD_pre_rot.dtype != ttnn.bfloat16: # Rotary embeddings require bfloat16 inputs - q_heads_1QSD_pre_rot = ttnn.typecast(q_heads_1QSD_pre_rot, dtype=ttnn.bfloat16) - - q_heads_1QSD = ttnn.experimental.rotary_embedding_llama( - q_heads_1QSD_pre_rot, - rot_mats[0], - rot_mats[1], - self.transformation_mats["prefill"], - is_decode_mode=False, - ) - ttnn.deallocate(q_heads_1QSD_pre_rot) - - if k_heads_1KSD_pre_rot.dtype != ttnn.bfloat16: # Rotary embeddings require bfloat16 inputs - k_heads_1KSD_pre_rot = ttnn.typecast(k_heads_1KSD_pre_rot, dtype=ttnn.bfloat16) - - k_heads_1KSD = ttnn.experimental.rotary_embedding_llama( - k_heads_1KSD_pre_rot, - rot_mats[0], - rot_mats[1], - self.transformation_mats["prefill"], - is_decode_mode=False, - ) - ttnn.deallocate(k_heads_1KSD_pre_rot) - - # Fill KV-Cache - if kv_cache: - keys_BKSD, values_BKSD = kv_cache[0], kv_cache[1] - else: - keys_BKSD, values_BKSD = self.layer_past[0], self.layer_past[1] - k_heads_1KSD_8b = ttnn.typecast(k_heads_1KSD, dtype=keys_BKSD.dtype) - ttnn.deallocate(k_heads_1KSD) - - # sharding k_fill to deal with update_cache memory limitation - if seq_len >= self.min_kv_prefill_shard_seqlen and not self.TG and not page_table: - k_fill = ttnn.interleaved_to_sharded(k_heads_1KSD_8b, self.model_config["KV_PREFILL_MEM_CFG"](seq_len)) - else: - k_fill = k_heads_1KSD_8b - - v_heads_1VSD_8b = ttnn.typecast(v_heads_1VSD, dtype=values_BKSD.dtype) - - ttnn.deallocate(v_heads_1VSD) - - # sharding v_fill to deal with update_cache memory limitation - if seq_len >= self.min_kv_prefill_shard_seqlen and not self.TG and not page_table: - v_fill = ttnn.interleaved_to_sharded(v_heads_1VSD_8b, self.model_config["KV_PREFILL_MEM_CFG"](seq_len)) - else: - v_fill = v_heads_1VSD_8b - - if self.TG: - k_fill = self.prefill_prepare_tensor_for_kv_cache(k_fill, user_id) - v_fill = self.prefill_prepare_tensor_for_kv_cache(v_fill, user_id) - if page_table: - # In the case that the tokens have been padded along the seq len dimension, we need to fill the cache with the unpadded k/v values. - # Assume that the page table does not have padding, so we can use it to get the unpadded page len. - block_size = keys_BKSD.shape[2] - # If chunked prefill, use chunk_page_table if given, otherwise use page_table. 
- fill_page_table = chunk_page_table if chunk_page_table is not None else page_table - - page_len = fill_page_table.shape[1] * block_size - k_fill_sliced = k_fill[:, :, :page_len, :] if page_len < k_fill.shape[2] else k_fill - v_fill_sliced = v_fill[:, :, :page_len, :] if page_len < v_fill.shape[2] else v_fill - ttnn.experimental.paged_fill_cache(keys_BKSD, k_fill_sliced, fill_page_table, batch_idx=user_id) - ttnn.experimental.paged_fill_cache(values_BKSD, v_fill_sliced, fill_page_table, batch_idx=user_id) - else: - ttnn.fill_cache( - keys_BKSD, - k_fill, - user_id % self.batch_size_per_device_group, - ) - ttnn.fill_cache( - values_BKSD, - v_fill, - user_id % self.batch_size_per_device_group, - ) - - if seq_len >= self.min_kv_prefill_shard_seqlen and not self.TG and not page_table: - ttnn.deallocate(k_fill) - ttnn.deallocate(v_fill) - - # SDPA - q_heads_1QSD_8b = ttnn.typecast(q_heads_1QSD, dtype=self.activation_dtype or ttnn.bfloat16) - ttnn.deallocate(q_heads_1QSD) - - if chunk_start_idx is not None: - attn_output_84SD = ttnn.transformer.chunked_scaled_dot_product_attention( - q_heads_1QSD_8b, - keys_BKSD, - values_BKSD, - page_table, - chunk_start_idx, - compute_kernel_config=self.sdpa_prefill_compute_kernel_cfg, - program_config=self.model_config["SDPA_PROGCFG"](seq_len), - ) - else: - attn_output_84SD = ttnn.transformer.scaled_dot_product_attention( - q_heads_1QSD_8b, - k_heads_1KSD_8b, - v_heads_1VSD_8b, - is_causal=True, - scale=self.scale, - compute_kernel_config=self.sdpa_prefill_compute_kernel_cfg, - program_config=self.model_config["SDPA_PROGCFG"](seq_len), - ) - - # deallocate keys and values - ttnn.deallocate(q_heads_1QSD_8b) - ttnn.deallocate(k_heads_1KSD_8b) - ttnn.deallocate(v_heads_1VSD_8b) - - attn_output_1QSD = ttnn.reshape(attn_output_84SD, [1, self.n_local_heads, -1, self.head_dim]) - - ### - # Output matmul - ### - attn_output_11SH = ttnn.experimental.nlp_concat_heads( - attn_output_1QSD, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) - ttnn.deallocate(attn_output_1QSD) - # reshaping long sequence to matmul fit on device - if seq_len > 1024: - attn_output_11SH = ttnn.reshape(attn_output_11SH, [1, seq_len // 1024, 1024, -1]) - - # Non fused All Gather Matmul - if self.use_fused_all_gather_matmul: # is true for Ring topology - attn_output_11SH = ttnn.all_gather( - attn_output_11SH, - dim=3, - num_links=1, - topology=self.ccl_topology, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) - - output_11SH = ttnn.linear( - attn_output_11SH, - self.wo, - compute_kernel_config=self.li_o_prefill_compute_kernel_cfg, - dtype=self.activation_dtype or ttnn.bfloat16, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - program_config=self.model_config["WO_PREFILL_PROGCFG"](seq_len), - ) - - if seq_len > 1024: - output_11SH = ttnn.reshape(output_11SH, [1, 1, seq_len, -1]) - ttnn.deallocate(attn_output_11SH) - - # Reduce-scatter - if not self.use_fused_all_gather_matmul: - output_11SH = tt_all_reduce( - output_11SH, - self.mesh_device, - cluster_axis=0, - dim=0 if self.TG else 3, - num_reduce_scatter_links=self.num_reduce_scatter_links, - num_all_gather_links=self.num_all_gather_links, - topology=self.ccl_topology, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - dtype=self.ccl_dtype, - ) - - return output_11SH - - def forward( - self, - x, - current_pos, - rot_mats=None, - user_id=0, - mode="decode", - page_table=None, - chunk_page_table=None, - chunk_start_idx=None, - kv_cache=None, - ): - if mode == "prefill": - return self.forward_prefill( - x, - rot_mats, - user_id, - page_table=page_table, - 
chunk_page_table=chunk_page_table, - chunk_start_idx=chunk_start_idx, - kv_cache=kv_cache, - ) - else: - return self.forward_decode(x, current_pos, rot_mats, page_table=page_table, kv_cache=kv_cache) - - def prefill_prepare_tensor_for_kv_cache(self, key_or_value_layer, user_id): - tensor_copy = ttnn.clone(key_or_value_layer) - # key_or_value_layer.deallocate(True) - # Get all tensors from multi-device tensor - tensors = ttnn.get_device_tensors(tensor_copy) - # Get only tensors from specific column chips - # Get every 4th tensor starting from user_id // 8 - single_column_tensors = tensors[user_id // self.batch_size_per_device_group :: 4] - # Create multi-device tensor - multi_device_tensor = ttnn.combine_device_tensors(single_column_tensors) - - return multi_device_tensor diff --git a/models/experimental/gemma3_1b/tt/decoder.py b/models/experimental/gemma3_1b/tt/decoder.py deleted file mode 100644 index 2cf46643c599..000000000000 --- a/models/experimental/gemma3_1b/tt/decoder.py +++ /dev/null @@ -1,226 +0,0 @@ -""" - -This is the Decoder block for the Gemma 3-1b-it model -We couldn't use the existing implementation in TT-Transformers because the usage of submodules is different - -In Gemma-3-1b-it, The decoder Block has Additional pre_feedforward_layernorm and post_feedforward_layernorm, -And the logic of implementation is different from the existing implementation in TT-Transformers. - -""" - -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import ttnn - -from models.common.lightweightmodule import LightweightModule -from models.tt_transformers.tt.distributed_norm import DistributedNorm -from models.experimental.gemma3_1b.tt.rmsnorm import RMSNorm - -from models.experimental.gemma3_1b.tt.attention import Attention - -from models.experimental.gemma3_1b.tt.mlp import MLP -from models.tt_transformers.tt.model_config import TensorGroup - - -class TransformerBlock(LightweightModule): - def __init__( - self, - args, - mesh_device, - dtype, - state_dict, - layer_num, - weight_cache_path, - transformation_mats, - paged_attention_config=None, - use_paged_kv_cache=False, - ): - super().__init__() - - self.state_dict = state_dict - self.mesh_device = mesh_device - - self.args = args - self.hidden_size = args.dim - self.n_heads = args.n_heads - self.head_dim = self.hidden_size // self.n_heads - self.max_seq_len = args.max_seq_len - self.dim = args.dim - self.max_batch_size = args.max_batch_size - self.n_kv_heads = args.n_kv_heads - self.current = 0 - self.model_config = args.get_model_config() - - self.layer_num = layer_num - - self.attention = Attention( - mesh_device=mesh_device, - state_dict=state_dict, - weight_cache_path=weight_cache_path, - layer_num=layer_num, - dtype=dtype, - transformation_mats=transformation_mats, - configuration=args, - paged_attention_config=paged_attention_config, - use_paged_kv_cache=use_paged_kv_cache, - ) - - self.feed_forward = MLP( - mesh_device=mesh_device, - args=args, - state_dict=state_dict, - weight_cache_path=weight_cache_path, - layer_num=layer_num, - dtype=dtype, - model_config=self.model_config, - ) - - self.attention_norm = DistributedNorm( # input_layernorm - RMSNorm( - device=mesh_device, - dim=args.dim, - eps=args.norm_eps, - state_dict=state_dict, - state_dict_prefix=args.get_state_dict_prefix("", layer_num), - weight_cache_path=None if args.dummy_weights else weight_cache_path, - weight_dtype=ttnn.bfloat16, - weight_key="attention_norm", - is_distributed=self.args.is_distributed_norm, - 
sharded_program_config=self.model_config["SHARDED_NORM_ATTN_PRGM_CFG"], - sharded_output_config=self.model_config["SHARDED_ATTN_INPUT_MEMCFG"], - ccl_topology=self.args.ccl_topology(), - ), - args, - TG=args.is_galaxy, - ) - - self.ff_norm = DistributedNorm( # post_attention_layernorm - RMSNorm( - device=mesh_device, - dim=args.dim, - eps=args.norm_eps, - state_dict=state_dict, - state_dict_prefix=args.get_state_dict_prefix("", layer_num), - weight_cache_path=None if args.dummy_weights else weight_cache_path, - weight_dtype=ttnn.bfloat16, - weight_key="ffn_norm", - is_distributed=self.args.is_distributed_norm, - sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"], - sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"], - ccl_topology=self.args.ccl_topology(), - ), - args, - TG=args.is_galaxy, - ) - - self.pre_ff_norm = DistributedNorm( # pre_feedforward_layernorm - RMSNorm( - device=mesh_device, - dim=args.dim, - eps=args.norm_eps, - state_dict=state_dict, - state_dict_prefix=args.get_state_dict_prefix("", layer_num), - weight_cache_path=None if args.dummy_weights else weight_cache_path, - weight_dtype=ttnn.bfloat16, - weight_key="pre_feedforward_layernorm", - is_distributed=self.args.is_distributed_norm, - sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"], - sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"], - ccl_topology=self.args.ccl_topology(), - ), - args, - TG=args.is_galaxy, - ) - - self.post_ff_norm = DistributedNorm( # post_feedforward_layernorm - RMSNorm( - device=mesh_device, - dim=args.dim, - eps=args.norm_eps, - state_dict=state_dict, - state_dict_prefix=args.get_state_dict_prefix("", layer_num), - weight_cache_path=None if args.dummy_weights else weight_cache_path, - weight_dtype=ttnn.bfloat16, - weight_key="post_feedforward_layernorm", - is_distributed=self.args.is_distributed_norm, - sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"], - sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"], - ccl_topology=self.args.ccl_topology(), - ), - args, - TG=args.is_galaxy, - ) - - def forward( - self, - hidden_states: ttnn.Tensor, - current_pos, - rot_mats=None, - user_id=0, - mode="decode", - page_table=None, - chunk_page_table=None, - chunk_start_idx=None, - kv_cache=None, - ): - TG = self.args.is_galaxy - skip_mem_cfg = self.model_config["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG - - assert ( - hidden_states.memory_config() == skip_mem_cfg - ), f"decoder input memcfg mismatch: {hidden_states.memory_config()} != {skip_mem_cfg}" - residual = hidden_states - - attn_in = self.attention_norm(hidden_states, mode) - - if self.attention.is_sliding: - position_embeddings = rot_mats[1] - else: - position_embeddings = rot_mats[0] - - attn_out = self.attention.forward( - attn_in, - current_pos, - position_embeddings, - user_id, - mode, - page_table=page_table, - chunk_page_table=chunk_page_table, - chunk_start_idx=chunk_start_idx, - kv_cache=kv_cache, - ) - - hidden_states = self.ff_norm(attn_out, mode) - - ttnn.deallocate(attn_out) - ttnn.deallocate(attn_in) - - hidden_states = ttnn.add(hidden_states, residual, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16) - - residual = hidden_states - - hidden_states = self.pre_ff_norm(hidden_states, mode) - - if TG and mode == "decode": - hidden_states = ttnn.to_memory_config(hidden_states, memory_config=self.model_config["MLP_ACT_MEMCFG"]) - - hidden_states = self.feed_forward.forward(hidden_states, mode) - - 
activation_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=self.layer_num, tensor=TensorGroup.ACTIVATION - ) - - hidden_states = self.post_ff_norm(hidden_states, mode) - - hidden_states = ttnn.add( - hidden_states, - residual, - memory_config=skip_mem_cfg, - dtype=self.args.ccl_dtype - if TG and not self.args.is_distributed_norm(mode) - else activation_dtype or ttnn.bfloat16, - ) - - return hidden_states diff --git a/models/experimental/gemma3_1b/tt/lm_head.py b/models/experimental/gemma3_1b/tt/lm_head.py deleted file mode 100644 index 5a62229111c4..000000000000 --- a/models/experimental/gemma3_1b/tt/lm_head.py +++ /dev/null @@ -1,169 +0,0 @@ -""" -This is the implementation of lm_head of the Gemma-3-1b-it. - -We have re-used the lm_head implementation of the TT-Transformers library along with few modifications. -This implementation has changes in Memory Configurations (DRAM Memory Config) and Data Type (bfloat16). -""" - -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import math - -import torch - -import ttnn -from models.common.lightweightmodule import LightweightModule -from models.tt_transformers.tt.ccl import tt_all_reduce - - -class LMHead(LightweightModule): - def __init__( - self, - args, - mesh_device, - dtype, - state_dict, - state_dict_prefix, - weight_cache_path, - max_columns_per_device, # too many columns per device lead to L1 OOM - ): - super().__init__() - self.args = args - self.mesh_device = mesh_device - self.dtype = dtype - self.vocab_size = args.vocab_size - self.padded_vocab_size = args.padded_vocab_size - self.num_devices = args.num_devices - - size_per_device = self.vocab_size // self.num_devices - - if args.is_galaxy: - size_per_device = self.padded_vocab_size // self.num_devices - num_splits = math.ceil(size_per_device / max_columns_per_device) - - split_sizes = [min(size_per_device, max_columns_per_device)] * (num_splits - 1) - split_sizes.append(size_per_device - sum(split_sizes)) # remaining columns - - # Split the output weights - torch_output_weights = state_dict[f"{state_dict_prefix}output.weight"].permute(1, 0) - - self.output_weights = [] - if args.is_galaxy: - cache_file_name = ( - None if args.dummy_weights else weight_cache_path / f"output_lm_head_{num_splits}_split_shard_0" - ) - padded_lm_head = torch.zeros(1, 1, args.dim, self.padded_vocab_size) - padded_lm_head[:, :, :, : self.vocab_size] = torch_output_weights - - memory_config = ( - ttnn.DRAM_MEMORY_CONFIG - if args.dim == 2048 - else args.create_dram_sharded_mem_config(k=args.dim // 4, n=self.padded_vocab_size // 8) - ) - self.output_weights.append( # (2k, 16k) 128* 1024 - ttnn.as_tensor( - padded_lm_head, - device=mesh_device, - mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=(3, 2), mesh_shape=args.cluster_shape), - layout=ttnn.TILE_LAYOUT, - dtype=dtype, - memory_config=memory_config, - cache_file_name=cache_file_name, - ) - ) - else: - for i, split_size in enumerate(split_sizes): - # Create a list to store the split tensors for each device - device_splits = [] - for device in range(self.num_devices): - start = device * size_per_device + sum(split_sizes[:i]) - end = start + split_size - device_splits.append(torch_output_weights[:, start:end]) - - # Concatenate the splits from all devices - combined_split = torch.cat(device_splits, dim=-1) - - cache_file_name = ( - None - if args.dummy_weights - else weight_cache_path / f"output_lm_head_{num_splits}_split_shard_{i}_{combined_split.shape[-1]}" - ) - 
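# For reference, the deleted Gemma-3 decoder forward above mirrors the HF Gemma3DecoderLayer
# ordering: the attention output passes through ff_norm (post_attention_layernorm) before the
# first residual add, and the MLP is wrapped by pre/post feedforward layernorms before the
# second add. A minimal PyTorch sketch of that ordering only (hypothetical module names,
# TT sharding/CCL details omitted; nn.RMSNorm needs torch >= 2.4):
import torch
import torch.nn as nn

class GemmaStyleDecoderSketch(nn.Module):
    def __init__(self, dim: int, attn: nn.Module, mlp: nn.Module, eps: float = 1e-6):
        super().__init__()
        self.attn = attn
        self.mlp = mlp
        self.input_layernorm = nn.RMSNorm(dim, eps=eps)
        self.post_attention_layernorm = nn.RMSNorm(dim, eps=eps)
        self.pre_feedforward_layernorm = nn.RMSNorm(dim, eps=eps)
        self.post_feedforward_layernorm = nn.RMSNorm(dim, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        h = self.attn(self.input_layernorm(x))
        h = self.post_attention_layernorm(h)
        x = residual + h                        # first residual add

        residual = x
        h = self.pre_feedforward_layernorm(x)
        h = self.mlp(h)
        h = self.post_feedforward_layernorm(h)
        return residual + h                     # second residual add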
memory_config = args.create_dram_sharded_mem_config( - k=args.dim, n=math.ceil(combined_split.shape[-1] / self.num_devices) - ) - self.output_weights.append( - ttnn.as_tensor( - combined_split, - device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), - layout=ttnn.TILE_LAYOUT, - dtype=dtype, - memory_config=memory_config, - cache_file_name=cache_file_name, - ) - ) - - self.compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi4, - fp32_dest_acc_en=True, - packer_l1_acc=True, - dst_full_sync_en=False, - ) - - if args.is_galaxy: - self.program_configs = [ - ( - None - if args.dim == 2048 - else args.dram_matmul_config( - args.tile_padded_batch_rows, # (8k, 128k) -> (2k, 16k) - args.dim // 4, - 16 * 1024, - args.lm_head_core_grid.num_cores, - ) - ) - ] - - else: - self.program_configs = [ - args.dram_matmul_config( - args.tile_padded_batch_rows, - args.dim, - split_size, - args.lm_head_core_grid.num_cores, - ) - for split_size in split_sizes - ] - - def forward(self, x: ttnn.Tensor): - outputs = [] - for weight, pc in zip(self.output_weights, self.program_configs): - output = ttnn.linear( - x, - weight, - compute_kernel_config=self.compute_kernel_config, - program_config=pc, - memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, - dtype=ttnn.bfloat16, - ) - outputs.append(ttnn.sharded_to_interleaved(output, memory_config=ttnn.DRAM_MEMORY_CONFIG)) - - # Concatenate the outputs - output = ttnn.concat(outputs, dim=-1, memory_config=ttnn.DRAM_MEMORY_CONFIG) - - output = tt_all_reduce( - output, - mesh_device=self.mesh_device, - cluster_axis=1, - dim=3 if self.args.is_galaxy else 0, - num_reduce_scatter_links=self.args.num_reduce_scatter_links, - num_all_gather_links=self.args.num_all_gather_links, - memory_config=ttnn.L1_MEMORY_CONFIG, - dtype=self.args.ccl_dtype, - sharded=False, - use_composite=True, - ) - - return output diff --git a/models/experimental/gemma3_1b/tt/mlp.py b/models/experimental/gemma3_1b/tt/mlp.py deleted file mode 100644 index 2abd227fb2ca..000000000000 --- a/models/experimental/gemma3_1b/tt/mlp.py +++ /dev/null @@ -1,264 +0,0 @@ -""" -This is the implementation of MLP (feed-forward) submodule of Gemma-3-1b-it. - -We have re-used the MLP implementation of the TT-Transformers library with few modifications. -This implementation has changes in Data Type (bfloat16). -""" - - -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import torch - -import ttnn -from models.common.lightweightmodule import LightweightModule -from models.tt_transformers.tt.ccl import tt_all_reduce -from models.tt_transformers.tt.common import pad_to_size -from models.tt_transformers.tt.model_config import OpGroup, TensorGroup - - -class MLP(LightweightModule): - def __init__( - self, mesh_device, args, state_dict, weight_cache_path, layer_num, dtype, model_config, state_dict_prefix=None - ): - super().__init__() - - self.state_dict = state_dict - self.mesh_device = mesh_device - self.args = args - self.dim = args.dim - self.model_config = model_config - self.layer_num = layer_num - state_dict_prefix = state_dict_prefix or args.get_state_dict_prefix(self.__class__.__name__, layer_num) - torch_weight = lambda name: torch.transpose(self.state_dict[f"{state_dict_prefix}.{name}.weight"], -2, -1) - pad_hidden_dim = lambda tensor, dim: pad_to_size(tensor, dim=dim, size=args.hidden_dim) - # If pading was applied (e.g. 
via env var), add the unpadded hidden dim to the cache name to avoid loading incorrect weights - hidden_dim_string = f".hidden_dim_{args.hidden_dim}" if args.hidden_dim != args.unpadded_hidden_dim else "" - - if args.dummy_weights: - cache_name = lambda _: None - else: - cache_name = lambda name: weight_cache_path / f"{state_dict_prefix}.{name}{hidden_dim_string}" - - w1_w3_mem_config = args.create_dram_sharded_mem_config(args.dim, args.hidden_dim // args.num_devices) - w2_mem_config = args.create_dram_sharded_mem_config(args.hidden_dim // args.num_devices, args.dim) - - # TODO Clean up this code. With sharding, we load the normal weights and then shard them - as_sharded_tensor = lambda name, type, dims: ttnn.as_tensor( - pad_hidden_dim( - torch_weight(name[:2]), dims[0] if args.is_galaxy else dims[-1] - ), # Grab only the wX part of the name - dtype=type, - device=self.mesh_device, - mesh_mapper=ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=args.cluster_shape), - layout=ttnn.TILE_LAYOUT, - memory_config=( - ttnn.DRAM_MEMORY_CONFIG if args.is_galaxy else w2_mem_config if "w2" in name else w1_w3_mem_config - ), - cache_file_name=cache_name(name), - ) - - # Sharded weights - w1_dims = (-1, -2) if args.is_galaxy else (-2, -1) - w2_dims = (-2, -1) if args.is_galaxy else (-1, -2) - - layer_num = max(layer_num, 0) # cross_block uses the configutation of the first decoder - - ff1_3_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=layer_num, tensor=TensorGroup.FF1_FF3 - ) - ff2_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=layer_num, tensor=TensorGroup.FF2 - ) - - self.w1 = as_sharded_tensor( - "w1_sharded", ff1_3_dtype, dims=w1_dims - ) # bfp4 normally ok here but sub .99 pcc for llama 3.1 weights - self.w2 = as_sharded_tensor("w2_sharded", ff2_dtype, dims=w2_dims) - self.w3 = as_sharded_tensor("w3_sharded", ff1_3_dtype, dims=w1_dims) - - # Default activation is SILU - self.activation_type = ( - args.mlp_activation_type if hasattr(args, "mlp_activation_type") else ttnn.UnaryOpType.SILU - ) - - def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: - """ - w1 -> gate_proj - w2 -> down_proj - w3 -> up_proj - HF reference: self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - """ - seq_len = x.shape[-2] - TG = self.args.is_galaxy - layer_num = max(self.layer_num, 0) # cross_block uses the configutation of the first decoder - activation_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=layer_num, tensor=TensorGroup.ACTIVATION - ) - li_ff1_3_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=self.args - ) - - if mode == "decode": # Sharded config - if TG: # TODO: Fix this when TG supports DRAM sharded matmuls - pc_1 = self.model_config["FF1_3_TG_PROGCFG"] if self.dim >= 4096 else None - pc_2 = self.model_config["FF2_TG_PROGCFG"] if self.dim >= 4096 else None - pc_3 = self.model_config["FF1_3_TG_PROGCFG"] if self.dim >= 4096 else None - else: - pc_1 = self.model_config["DECODE_MLP_W1_W3_PRG_CONFIG"] - pc_2 = self.model_config["DECODE_MLP_W2_PRG_CONFIG"] - pc_3 = self.model_config["DECODE_MLP_W1_W3_PRG_CONFIG"] - else: # Update the program configs based for prefill - if seq_len >= self.args.prefill_len_cutoff: # 512 if Blackhole, 1024 if Wormhole - # Reshape input to to fit on device and parallelize computation - x = ttnn.reshape(x, [1, seq_len // self.args.prefill_len_cutoff, 
self.args.prefill_len_cutoff, -1]) - pc_1 = self.model_config["PREFILL_MLP_W1_W3_PRG_CONFIG"](seq_len) - pc_2 = self.model_config["PREFILL_MLP_W2_PRG_CONFIG"](seq_len) - pc_3 = self.model_config["PREFILL_MLP_W1_W3_PRG_CONFIG"](seq_len) - - # In decode mode (seqlen <= 32) do DRAM sharded matmuls - # These use HiFi2; this drops 1 bit of the activations but would be FLOP-bound on 12 cores with HiFi4 - memory_config = ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG - w1_out = ttnn.linear( - x, - self.w1, - dtype=ttnn.bfloat8_b if TG else activation_dtype or ttnn.bfloat16, - core_grid=None, # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_1 else None, - compute_kernel_config=li_ff1_3_compute_kernel_cfg, - program_config=pc_1, - memory_config=memory_config, - ) - - w3_out = ttnn.linear( - x, - self.w3, - dtype=ttnn.bfloat8_b if TG else activation_dtype or ttnn.bfloat16, - core_grid=None, # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_3 else None, - compute_kernel_config=li_ff1_3_compute_kernel_cfg, - program_config=pc_3, - memory_config=memory_config, - ) - ttnn.deallocate(x) - - if TG: - # if mode == "decode" and self.dim!=8192: - # w1_out = ttnn.to_memory_config(w1_out, ttnn.DRAM_MEMORY_CONFIG) - # w3_out = ttnn.to_memory_config(w3_out, ttnn.DRAM_MEMORY_CONFIG) - if self.dim == 8192 or mode == "prefill": - input_mem_cfg = w1_out.memory_config() - w1_out = ttnn.reduce_scatter( - w1_out, - dim=3, - math_op=ttnn.ReduceType.Sum, - num_links=self.args.num_reduce_scatter_links, - cluster_axis=1, - mesh_device=self.mesh_device, - topology=ttnn.Topology.Linear, - memory_config=self.model_config["FF1_OUT_REDUCE_SCATTER_MEMCFG"] if mode == "decode" else None, - ) - w3_out = ttnn.reduce_scatter( - w3_out, - dim=3, - math_op=ttnn.ReduceType.Sum, - num_links=1, - cluster_axis=1, - mesh_device=self.mesh_device, - topology=ttnn.Topology.Linear, - memory_config=self.model_config["FF1_OUT_REDUCE_SCATTER_MEMCFG"] if mode == "decode" else None, - ) - else: - w1_out = tt_all_reduce( - w1_out, - self.mesh_device, - cluster_axis=1, - num_all_gather_links=2, - sharded=True if mode == "decode" else False, - topology=self.args.ccl_topology(), - memory_config=self.model_config["FF1_OUT_GATHERED_MEMCFG"] if mode == "decode" else None, - ) - w3_out = tt_all_reduce( - w3_out, - self.mesh_device, - cluster_axis=1, - num_all_gather_links=2, - sharded=True if mode == "decode" else False, - topology=self.args.ccl_topology(), - memory_config=self.model_config["FF1_OUT_GATHERED_MEMCFG"] if mode == "decode" else None, - ) - - w2_in = ttnn.mul( - w1_out, - w3_out, - input_tensor_a_activations=[self.activation_type], - dtype=activation_dtype or ttnn.bfloat16, - memory_config=w1_out.memory_config(), - ) - - if mode == "decode" and not TG: - # w2 may use a different core grid, this is a no-op if they already match - w2_in = ttnn.to_memory_config(w2_in, self.model_config["SHARDED_MLP2_INPUT_MEMCFG"]) - - ttnn.deallocate(w3_out) - ttnn.deallocate(w1_out) - - if TG and (self.dim == 8192 or mode == "prefill"): - w2_in = ttnn.all_gather( - w2_in, - 3, - num_links=2, - cluster_axis=1, - mesh_device=self.mesh_device, - topology=ttnn.Topology.Linear, - memory_config=input_mem_cfg, - ) - if mode == "decode": - w2_in = ttnn.to_memory_config(w2_in, ttnn.L1_MEMORY_CONFIG) - - li_ff2_compute_kernel_cfg = self.model_config["DECODERS_OPTIMIZATIONS"].get_math_fidelity( - decoder_id=layer_num, op=OpGroup.ACCURACY, configuration=self.args - ) - w2_out = ttnn.linear( - w2_in, - self.w2, - 
compute_kernel_config=li_ff2_compute_kernel_cfg, - dtype=self.args.ccl_dtype if TG else activation_dtype or ttnn.bfloat16, - program_config=pc_2, - memory_config=memory_config, - core_grid=None, # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_2 else None, - ) - ttnn.deallocate(w2_in) - # if mode == "decode" and not TG: - # w2_out = ttnn.sharded_to_interleaved(w2_out, ttnn.DRAM_MEMORY_CONFIG) - w2_out_reduced = tt_all_reduce( - w2_out, - self.mesh_device, - cluster_axis=0, - dim=0 if (TG and self.dim < 8192) else 3, - num_reduce_scatter_links=self.args.num_reduce_scatter_links, - num_all_gather_links=self.args.num_all_gather_links, - sharded=(mode == "decode"), - memory_config=( - (self.model_config["FF2_OUT_REDUCE_SCATTER_MEMCFG"] if TG else w2_out.memory_config()) - if mode == "decode" - else ttnn.DRAM_MEMORY_CONFIG - ), - dtype=self.args.ccl_dtype, - use_composite=True if self.dim == 8192 else False, - topology=self.args.ccl_topology(), - ) - - # Ensure dim 0 and 1 are 1 - original_shape = w2_out_reduced.shape - w2_out_reduced = ttnn.reshape( - w2_out_reduced, (1, 1, original_shape[-4] * original_shape[-3] * original_shape[-2], original_shape[-1]) - ) - if mode == "decode": - w2_out_reduced = ttnn.to_memory_config( - w2_out_reduced, - self.model_config["SHARDED_ATTN_INPUT_MEMCFG"] if TG else self.model_config["DECODE_RESIDUAL_MEMCFG"], - ) - - # ttnn.deallocate(w2_out) - return w2_out_reduced diff --git a/models/experimental/gemma3_1b/tt/model.py b/models/experimental/gemma3_1b/tt/model.py deleted file mode 100644 index d59adb37de3c..000000000000 --- a/models/experimental/gemma3_1b/tt/model.py +++ /dev/null @@ -1,432 +0,0 @@ -""" - -This is the end-to-end implementation of the Gemma-3-1b-it model. - -""" - -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
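# The MLP above follows the gated form noted in its docstring,
# down_proj(act_fn(gate_proj(x)) * up_proj(x)): w1/w3 are the gate/up projections and w2 the
# down projection. A minimal single-device PyTorch sketch of just that dataflow (activation
# left as a parameter, since the TT code defaults to SiLU but allows a per-model override;
# all sharding/CCL handling is omitted and the example dims are arbitrary):
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMLPSketch(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, act_fn=F.silu):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate_proj
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up_proj
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down_proj
        self.act_fn = act_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.act_fn(self.w1(x)) * self.w3(x))

# e.g. GatedMLPSketch(dim=2048, hidden_dim=8192)(torch.randn(1, 32, 2048))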
- -# SPDX-License-Identifier: Apache-2.0 - -import ttnn -from tqdm import tqdm -import torch - -from models.experimental.gemma3_1b.tt.rmsnorm import RMSNorm - -from models.common.lightweightmodule import LightweightModule -from models.tt_transformers.tt.embedding import Embedding -from models.tt_transformers.tt.rope import RotarySetup - -from models.experimental.gemma3_1b.tt.decoder import TransformerBlock -from models.tt_transformers.tt.distributed_norm import DistributedNorm -from models.experimental.gemma3_1b.tt.lm_head import LMHead -from models.tt_transformers.tt.model_config import TensorGroup -from models.tt_transformers.tt.common import copy_host_to_device - - -class Gemma3Transformer(LightweightModule): - def __init__( - self, - args, - dtype, - mesh_device, - state_dict, - weight_cache_path, - paged_attention_config=None, - use_paged_kv_cache=False, - ): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - assert self.vocab_size > 0 - self.n_layers = args.n_layers - self.mesh_device = mesh_device - self.dtype = dtype - self.model_config = args.get_model_config() - self.grid_size = self.args.max_grid_size - state_dict_prefix = args.get_state_dict_prefix("", None) - - self.embd = Embedding( - mesh_device=mesh_device, - args=args, - weight_cache_path=args.weight_cache_path(dtype), - state_dict=state_dict, - dtype=ttnn.bfloat16, # Row major layout requires bfloat16 - ) - - self.rope_setup = RotarySetup( - mesh_device, - args.max_batch_size, - args.head_dim, - args.max_seq_len, - args.rope_theta, - args.rope_scaling_factor, - args.orig_context_len, - ) - - self.rope_setup_local = RotarySetup( - mesh_device, - args.max_batch_size, - args.head_dim, - args.max_seq_len, - 10000, # Rope theta local - None, # Rope Scaling Factor - args.orig_context_len, - ) - - self.trans_mats_dict = self.rope_setup.get_both_trans_mats() - - self.layers = [ - TransformerBlock( - args=args, - mesh_device=mesh_device, - dtype=dtype, - state_dict=state_dict, - weight_cache_path=weight_cache_path, - layer_num=i, - transformation_mats=self.trans_mats_dict, - paged_attention_config=paged_attention_config, - use_paged_kv_cache=use_paged_kv_cache, - ) - for i in tqdm(range(self.n_layers)) - ] - self.norm = DistributedNorm( - RMSNorm( - device=mesh_device, - dim=args.dim, - eps=args.norm_eps, - state_dict=state_dict, - state_dict_prefix=args.get_state_dict_prefix("", None), - weight_cache_path=None if args.dummy_weights else weight_cache_path, - weight_dtype=ttnn.bfloat16, - weight_key="norm", - add_unit_offset=True, - is_distributed=self.args.is_distributed_norm, - sharded_program_config=self.model_config["SHARDED_NORM_LM_HEAD_PRGM_CFG"], - sharded_output_config=self.model_config["LM_HEAD_INPUT_MEMCFG"], - ccl_topology=self.args.ccl_topology(), - ), - args, - args.is_galaxy, - ) - - self.lm_head = LMHead( - args=args, - mesh_device=mesh_device, - dtype=dtype, - state_dict=state_dict, - state_dict_prefix=state_dict_prefix, - weight_cache_path=weight_cache_path, - max_columns_per_device=self.args.max_columns_per_device_lm_head, - ) - - self.embed_scale = args.dim**0.5 - - def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_page_table=None): - """ - Inputs are torch tensors or python types. This function returns ttnn - tensors on device. 
- TODO: Debate whether this function is responsible for padding - """ - tokens = tokens.reshape(1, 1, 1, -1) - S = tokens.shape[-1] - tokens = ttnn.from_torch( - tokens, - device=self.mesh_device, - dtype=ttnn.uint32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - tokens_embd = self.embd(tokens) - tokens_embd = ttnn.multiply( - tokens_embd, self.embed_scale - ) # TODO In UT, Without Multiply we got passing with better PCC, Lets debug this in pipeline - - tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) - - # Slice the rot mats to the prefill seqlen - assert ( - self.rope_setup.cos_matrix.shape[2] >= start_pos + S - ), f"Padded prefill end idx {start_pos + S} exceeds max seq len {self.rope_setup.cos_matrix.shape[2]}" - - tt_rot_mats_prefill_global = [ - self.rope_setup.cos_matrix[:, :, start_pos : start_pos + S, :], - self.rope_setup.sin_matrix[:, :, start_pos : start_pos + S, :], - ] - - tt_rot_mats_prefill_local = [ - self.rope_setup_local.cos_matrix[:, :, start_pos : start_pos + S, :], - self.rope_setup_local.sin_matrix[:, :, start_pos : start_pos + S, :], - ] - - if page_table is not None: - tt_page_table = ttnn.from_torch( - page_table, - device=self.mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - else: - tt_page_table = None - - if chunk_page_table is not None: - tt_chunk_page_table = ttnn.from_torch( - chunk_page_table, - device=self.mesh_device, - dtype=ttnn.int32, - layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - else: - tt_chunk_page_table = None - - return tokens_embd, [tt_rot_mats_prefill_global, tt_rot_mats_prefill_local], tt_page_table, tt_chunk_page_table - - def prepare_inputs_decode(self, *inputs): - """ - Inputs are torch tensors or python types. This function returns ttnn - tensors on device. - Its implementation can take advantage of a few other functions which the - model must implement. - """ - host_inputs = self.prepare_decode_inputs_host(*inputs) - device_inputs = copy_host_to_device(host_inputs, mesh_device=self.mesh_device) # Helper function - transformed_device_inputs = self.transform_decode_inputs_device(*device_inputs) - return transformed_device_inputs - - def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None): - """ - Inputs are torch tensors or python types. Outputs are ttnn tensors on host. 
- NOTE: Tokens and current_pos are padded to batch - """ - B = tokens.shape[0] - assert current_pos.shape[0] == B, "Batch size mismatch" - assert B == self.args.max_batch_size, "Batch size must be equal to max_batch_size" - - # Necessary padding to be full tile sized when on device - tokens = torch.nn.functional.pad(tokens.view(-1), (0, 32 - len(tokens)), "constant", 0) - tokens = ttnn.from_torch( - tokens, - device=None, - dtype=ttnn.uint32, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - ) - tokens = ttnn.unsqueeze_to_4D(tokens) - - rot_current_pos = torch.maximum( - current_pos, torch.tensor(0, dtype=torch.int64) - ) # Ensure position indices are non-negative - rope_idxs = self.rope_setup.get_rot_idxs(rot_current_pos, on_host=True) - current_pos_tt = ttnn.from_torch( - current_pos, - device=None, - dtype=ttnn.int32, - mesh_mapper=ttnn.ShardTensor2dMesh( - self.mesh_device, - dims=(None, 0) if (self.args.is_galaxy and B > 1) else (None, None), - mesh_shape=self.args.cluster_shape, - ), - ) - - if page_table is not None: - page_table = ttnn.from_torch( - page_table, - device=None, - dtype=ttnn.int32, - mesh_mapper=ttnn.ShardTensor2dMesh( - self.mesh_device, - dims=(None, -2) if (self.args.is_galaxy and B > 1) else (None, None), - mesh_shape=self.args.cluster_shape, - ), - ) - return tokens, current_pos_tt, rope_idxs, page_table - - def transform_decode_inputs_device(self, tokens, current_pos, rope_idxs, page_table=None): - """ - Inputs are ttnn tensors on device. This function applies any on-device - transformations which should happen before forward decode. - For example: tilize, reshape, shard. - Return transformed device tensors - - Get rope sin/cos - Embed tokens - """ - tt_rot_mats = self.rope_setup.get_rot_mats(rope_idxs) - tt_rot_mats_local = self.rope_setup_local.get_rot_mats(rope_idxs) - - tt_tokens = self.embd(tokens) - tt_tokens = ttnn.multiply(tt_tokens, self.embed_scale) - - tt_tokens = ttnn.unsqueeze_to_4D(tt_tokens) - tt_tokens = ttnn.to_memory_config( - tt_tokens, - self.args.model_config["DECODE_RESIDUAL_MEMCFG"], - ) - return tt_tokens, current_pos, [tt_rot_mats, tt_rot_mats_local], page_table - - def process_output_prefill(self, tt_out, last_token_idx): - """ - Input is ttnn device tensor of logits. Output is torch logits tensor. - NOTE: In this model, prefill always uses get_last_token - """ - logits = ttnn.to_torch( - tt_out, - mesh_composer=ttnn.ConcatMesh2dToTensor( - self.mesh_device, dims=(3, 1) if self.args.is_galaxy else (1, -1), mesh_shape=self.args.cluster_shape - ), - )[0, 0, last_token_idx, : self.vocab_size] - return logits - - def process_output_decode(self, tt_out, B, S=1, is_tokens=False): - """ - Input is ttnn device tensor of logits if is_tokens=False, otherwise tokens. Output is the corresponding torch tensor. 
- """ - if is_tokens: - tt_out = ttnn.to_torch( - tt_out, # tt_out.cpu(blocking=True, cq_id=1), - mesh_composer=ttnn.ConcatMesh2dToTensor( - self.mesh_device, - dims=(3, 1) if self.args.is_galaxy else (1, -1), - mesh_shape=self.args.cluster_shape, - ), - )[0, 0, :B, 0] - return tt_out - - if self.args.num_devices > 1: - tt_out = ttnn.to_torch(ttnn.get_device_tensors(tt_out)[0]).float() - else: - tt_out = ttnn.to_torch(tt_out).float() - tt_out = tt_out[:, :, :B, : self.vocab_size].view(B, S, -1) - return tt_out - - def ttnn_prefill_forward( - self, - x, - rot_mats, - user_id, - page_table=None, - chunk_page_table=None, - chunk_start_idx=None, - get_last_token=-1, - kv_cache=None, - ): - """ - This method will take device tensors and any other args to run forward. - It returns ttnn device tensors. - """ - return self.forward( - x, - current_pos=None, - rot_mats=rot_mats, - user_id=user_id, - mode="prefill", - page_table=page_table, - chunk_page_table=chunk_page_table, - chunk_start_idx=chunk_start_idx, - get_last_token=get_last_token, - kv_cache=kv_cache, - ) - - def ttnn_decode_forward( - self, - x, - current_pos, - rot_mats, - page_table=None, - kv_cache=None, - argmax_on_device=False, - ): - """ - This method will take device tensors and any other args to run forward. - It returns ttnn device tensors. - """ - - tt_logits = self.forward( - x, - current_pos, - rot_mats=rot_mats, - mode="decode", - page_table=page_table, - kv_cache=kv_cache, - ) - - # Gather the output across all devices and untilize the tensor (for argmax) - if self.args.num_devices > 1: - if self.args.is_galaxy: - tt_logits = ttnn.all_gather( - tt_logits, - dim=3, - num_links=2, - cluster_axis=0, - mesh_device=self.mesh_device, - topology=self.args.ccl_topology(), - ) - else: - tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, topology=self.args.ccl_topology()) - tt_logits = ttnn.untilize(tt_logits, use_multicore=True) - - if argmax_on_device: - tt_logits = ttnn.argmax(tt_logits, dim=3, keepdim=True, use_multicore=True) - else: - # Send output logits to DRAM so L1 is not reserved for ttnn tracing and can be used by subsequent operations - if not self.args.is_galaxy: - tt_logits = ttnn.to_memory_config(tt_logits, ttnn.DRAM_MEMORY_CONFIG) - - return tt_logits - - def forward( - self, - x: ttnn.Tensor, - current_pos, - rot_mats=None, - user_id=0, - mode="decode", - page_table=None, - chunk_page_table=None, - chunk_start_idx=None, - get_last_token=-1, - kv_cache=None, - ): - for i, layer in enumerate(self.layers): - # No-op if callers already provide the right memory config - activation_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( - decoder_id=i, tensor=TensorGroup.ACTIVATION - ) - if mode == "decode" and not self.args.is_galaxy: - x = ttnn.to_memory_config(x, self.model_config["DECODE_RESIDUAL_MEMCFG"], activation_dtype) - elif activation_dtype is not None and x.dtype != activation_dtype: - x = ttnn.typecast(x, activation_dtype) - - x = layer( - x, - current_pos, - rot_mats, - user_id, - mode, - page_table, - chunk_page_table=chunk_page_table, - chunk_start_idx=chunk_start_idx, - kv_cache=kv_cache[i] if kv_cache is not None else None, - ) - - if mode == "prefill" and get_last_token == -1: - return x - - # Slicing the tensor to the nearest ceiling/floor multiples of 32 for the prefill_len, to get the last token - if get_last_token != -1: - x = ttnn.slice(x, (0, 0, get_last_token, 0), (1, 1, get_last_token + 32, x.shape[-1])) - - # Output norm - x = self.norm(x, mode=mode) - - if mode == 
"prefill" and self.model_config["LM_HEAD_INPUT_MEMCFG"].is_sharded(): - x = ttnn.interleaved_to_sharded(x, self.model_config["LM_HEAD_INPUT_MEMCFG"]) - - x = self.lm_head(x) - - if mode == "prefill": - x = ttnn.to_layout(x, layout=ttnn.ROW_MAJOR_LAYOUT) - x = ttnn.to_memory_config(x, memory_config=ttnn.DRAM_MEMORY_CONFIG) - return x diff --git a/models/experimental/gemma3_1b/tt/rmsnorm.py b/models/experimental/gemma3_1b/tt/rmsnorm.py deleted file mode 100644 index 2c3f9dabce6d..000000000000 --- a/models/experimental/gemma3_1b/tt/rmsnorm.py +++ /dev/null @@ -1,177 +0,0 @@ -""" -This is the modified version of the RMSNorm for Gemma-3-1b-it model. - -We have modified the RMSNorm implementation equivalent to RMSNorm in Gemma-3-1b-it. -We have handled the unit offset addition in the RMSNorm implementation directly into the TTNN Weights - -""" - -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import ttnn -from models.common.lightweightmodule import LightweightModule - -TILE = 32 -SHARD_HEIGHT = TILE # Current ttnn.rms_norm implementation requires shard height to be a single tile - - -class RMSNorm(LightweightModule): - """ - RMSNorm supporting replication over a MeshDevice and sharding within devices. - - This class implements a Root Mean Square Normalization (RMSNorm) that can be - distributed across multiple devices and cores. If the `device` parameter is a - MeshDevice, the weights and computations are replicated across all devices in - the mesh. Expects an interleaved input tensor, can optionally output a sharded tensor. - - Args: - device: The device or MeshDevice on which to perform the computations. - state_dict: The state dictionary containing the model parameters. - dim: Input dimension (e.g. model hidden dimension size). - layer_num: The layer number to determine the weight key in the state dictionary. - weight_key: The key for retrieving the weight from the state dictionary. - weight_cache_path: Optional path for caching the tilized weights. - weight_memory_config: Configuration for the weight memory, default is DRAM_MEMORY_CONFIG. - weight_dtype: The data type for the tensors, bfp8_b hits >0.999 PCC in the models we tested. - model_config: Optional configuration dictionary for the model. - eps (float): Small value to avoid division by zero in normalization, default is 1e-05. - - If model_config is provided, it must specify SHARDED_NORM_INPUT_MEMCFG, SHARDED_NORM_PRGM_CFG - and SHARDED_NORM_OUTPUT_MEMCFG. If not provided, default configurations will be generated. 
- """ - - def __init__( - self, - device, - dim, - state_dict, - weight_key, - layer_num=None, - state_dict_prefix=None, - weight_cache_path=None, - weight_memory_config=ttnn.DRAM_MEMORY_CONFIG, - weight_dtype=ttnn.bfloat16, - is_distributed=None, - eps: float = 1e-06, - add_unit_offset=True, - sharded_program_config=None, - sharded_output_config=None, - output_mem_config=None, - ccl_topology=ttnn.Topology.Ring, - ): - super().__init__() - self.eps = eps - self.is_distributed = is_distributed - self.ccl_topology = ccl_topology - - if state_dict_prefix: - weight_name = f"{state_dict_prefix}{weight_key}.weight" - else: - if layer_num is None: - weight_name = f"{weight_key}.weight" - else: - weight_name = f"layers.{layer_num}.{weight_key}.weight" - - torch_weight = ( - state_dict[weight_name].unsqueeze(0).view(1, 1, dim).reshape([1, 1, dim // SHARD_HEIGHT, SHARD_HEIGHT]) - ) - - # # Add offset before caching - cache_name = None if weight_cache_path is None else weight_cache_path / weight_name - - # Compatibility with models that don't use mesh devices (e.g. single-chip Mistral-7b) - is_mesh_device = device.__class__.__name__ == "MeshDevice" - - self.weight = ttnn.as_tensor( - torch_weight, - device=device, - dtype=weight_dtype, - layout=ttnn.TILE_LAYOUT, - memory_config=weight_memory_config, - cache_file_name=cache_name, - mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, - ) - if add_unit_offset: - self.weight = ttnn.add(self.weight, 1.0) - - if self.is_distributed: - self.weight_distributed = ttnn.as_tensor( - torch_weight, - device=device, - dtype=weight_dtype, - layout=ttnn.TILE_LAYOUT, - memory_config=weight_memory_config, - cache_file_name=cache_name, - mesh_mapper=ttnn.ShardTensor2dMesh(device, dims=(None, 2), mesh_shape=list(device.shape)) - if is_mesh_device - else None, - ) - if add_unit_offset: - self.weight_distributed = ttnn.add(self.weight_distributed, 1.0) # Add offset to distributed weight - - self.sharded_output_config = sharded_output_config - self.sharded_program_config = sharded_program_config - self.output_mem_config = output_mem_config - - self.compute_kernel_config_hifi2 = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi2, - math_approx_mode=False, - fp32_dest_acc_en=True, - packer_l1_acc=True, - ) - - def forward(self, x: ttnn.Tensor, mode, in_sharded=False, out_sharded=False) -> ttnn.Tensor: - # If input is sharded do sharded RMSNorm and optionally return sharded output - program_config = self.sharded_program_config if in_sharded else None - memory_config = self.sharded_output_config if out_sharded else None - distributed = self.is_distributed and self.is_distributed(mode) - norm = self._distributed_rmsnorm - weight = self.weight_distributed if distributed else self.weight - - if in_sharded: - assert not distributed, "Distributed RMSNorm does not support sharded inputs" - else: - assert not out_sharded, "Non-sharded version of RMSNorm cannot output a sharded tensor" - - x = norm( - x, - epsilon=self.eps, - weight=weight, - program_config=program_config, - memory_config=memory_config, - compute_kernel_config=self.compute_kernel_config_hifi2, - ) - - if in_sharded and not out_sharded: - return ttnn.sharded_to_interleaved(x) - else: - return x - - def _distributed_rmsnorm( - self, inp, epsilon=1e-6, weight=None, program_config=None, memory_config=None, compute_kernel_config=None - ): - """ - TODO: We are using Primitive RMSNorm. - This will be replaced once the ttnn.rms_norm atol issue is fixed. 
- issue: https://github.com/tenstorrent/tt-metal/issues/25883 - """ - inp = ttnn.sharded_to_interleaved(inp) - - xnorm = ttnn.pow(inp, 2) - - xnorm = ttnn.mean(xnorm, dim=-1, keepdim=True) - - xnorm = ttnn.rsqrt(xnorm + epsilon) - - xnorm = ttnn.multiply(inp, xnorm) - - weight = ttnn.reshape(weight, [1, 1, 1, -1]) - - output = ttnn.multiply(xnorm, weight) - - if memory_config is not None: - output = ttnn.to_memory_config(output, memory_config) - - return output diff --git a/models/tt_transformers/tests/test_decoder.py b/models/tt_transformers/tests/test_decoder.py index bb61c937f89f..3918677d9c73 100644 --- a/models/tt_transformers/tests/test_decoder.py +++ b/models/tt_transformers/tests/test_decoder.py @@ -173,11 +173,24 @@ def test_decoder_inference( # Get cos/sin matrices for the current position of each user rot_mats = rope_setup.get_rot_mats(current_pos) + if model_args.rope_local_theta is not None: + rope_setup_local = RotarySetup( + mesh_device, + model_args.max_batch_size, + model_args.head_dim, + model_args.max_seq_len, + model_args.rope_local_theta, + model_args.rope_scaling, + ) + rot_mats_local = rope_setup_local.get_rot_mats(current_pos) + else: + rot_mats_local = None + # Run TT model tt_out = tt_model( decode_input, current_pos_tensor, - rot_mats=rot_mats, + rot_mats=[rot_mats, rot_mats_local], mode="decode", page_table=page_table_tt, ) @@ -191,7 +204,7 @@ def test_decoder_inference( freqs_cis_i = freqs_cis[current_pos[0], :].unsqueeze(0) # Reference model - ref_output = reference_model(pt_decode_input, current_pos[0], freqs_cis_i, mask=None) + ref_output = reference_model(pt_decode_input.to(dtype=torch.bfloat16), current_pos[0], freqs_cis_i, mask=None) passing, pcc_message = comp_pcc(ref_output, tt_output_torch) diff --git a/models/tt_transformers/tests/test_decoder_prefill.py b/models/tt_transformers/tests/test_decoder_prefill.py index ca63f294b2d2..62457bd79486 100644 --- a/models/tt_transformers/tests/test_decoder_prefill.py +++ b/models/tt_transformers/tests/test_decoder_prefill.py @@ -93,6 +93,17 @@ def test_decoder_inference( theta=model_args.rope_theta, rope_scaling=model_args.rope_scaling, ) + + if model_args.rope_local_theta is not None: + rot_mats_local = get_rot_mats( + head_dim=model_args.head_dim, + device=mesh_device, + seq_len=max_seq_len, + theta=model_args.rope_local_theta, + rope_scaling=None, + ) + else: + rot_mats_local = None transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim) transformation_mats_prefill = ttnn.as_tensor( transformation_mat_torch, @@ -166,9 +177,13 @@ def test_decoder_inference( # Reference model attn_mask = torch.full((max_seq_len, max_seq_len), torch.finfo(torch.float32).min) attn_mask_torch = torch.triu(attn_mask, diagonal=1) - ref_output = reference_model(pt_decode_input, positions[0], freqs_cis_i, mask=attn_mask_torch) + ref_output = reference_model( + pt_decode_input.to(dtype=torch.bfloat16), positions[0], freqs_cis_i, mask=attn_mask_torch + ) # Run TT model - tt_out = tt_model(decode_input, None, rot_mats, user_id=0, mode="prefill", page_table=page_table_tt) + tt_out = tt_model( + decode_input, None, [rot_mats, rot_mats_local], user_id=0, mode="prefill", page_table=page_table_tt + ) tt_out = ttnn.to_torch( tt_out, mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(1, 3), mesh_shape=model_args.cluster_shape), diff --git a/models/tt_transformers/tests/test_embedding.py b/models/tt_transformers/tests/test_embedding.py index f6408a397bcd..278d42ef0894 100644 --- 
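# Functionally, the deleted Gemma-3 RMSNorm above computes
# x * rsqrt(mean(x^2, dim=-1) + eps) * (1 + w): the "+1" unit offset is baked into the cached
# TTNN weight at load time instead of being applied on every call, and _distributed_rmsnorm
# spells the op out as pow/mean/rsqrt/multiply primitives. A plain-PyTorch reference of the
# same math (illustrative only):
import torch

def gemma_rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # weight is the raw checkpoint weight of shape [dim]; the unit offset is applied here
    xnorm = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return xnorm * (1.0 + weight)

x = torch.randn(1, 1, 32, 64)
w = torch.zeros(64)                            # zero checkpoint weight -> effective scale 1.0
out = gemma_rmsnorm_reference(x, w)
print(out.pow(2).mean(dim=-1).sqrt().mean())   # per-row RMS of the output is ~1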
a/models/tt_transformers/tests/test_embedding.py +++ b/models/tt_transformers/tests/test_embedding.py @@ -58,6 +58,7 @@ def test_embedding(max_seq_len, batch_size, mesh_device, reset_seeds, ensure_gc) prompts = ["Joy"] * 32 pt_input = torch.tensor([model_args.encode_prompt(prompt, instruct=False) for prompt in prompts]) + embed_scale = model_args.embed_scale reference_output = reference_emb(pt_input) logger.info(f"reference_output: {reference_output.shape}") @@ -68,7 +69,7 @@ def test_embedding(max_seq_len, batch_size, mesh_device, reset_seeds, ensure_gc) dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, ) - tt_output = tt_emb(tt_input) + tt_output = tt_emb(tt_input, embed_scale) tt_output_torch = ttnn.to_torch( tt_output, mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=(0, -1), mesh_shape=model_args.cluster_shape), diff --git a/models/tt_transformers/tt/attention.py b/models/tt_transformers/tt/attention.py index 47ba6a7d95fd..1cb2e5703111 100644 --- a/models/tt_transformers/tt/attention.py +++ b/models/tt_transformers/tt/attention.py @@ -27,6 +27,7 @@ def __init__( use_paged_kv_cache=False, ): super().__init__() + self.is_sliding = configuration.is_sliding[layer_num] self.state_dict = state_dict self.mesh_device = mesh_device diff --git a/models/tt_transformers/tt/common.py b/models/tt_transformers/tt/common.py index 17655ea39770..5eebf47ce735 100644 --- a/models/tt_transformers/tt/common.py +++ b/models/tt_transformers/tt/common.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 import math -import os import re from enum import Enum from typing import Optional @@ -217,8 +216,9 @@ def encode_prompt_hf(tokenizer, prompt_text, system_prompt_text=None): return tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True) -def compute_llama3_parameters(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): - """Llama-3.x specific scaling for rotary embeddings.""" +def apply_scaling(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): + # FIXME: Llama-3.x specific scaling - we need to support yarn for Qwen2.5 models + # Values obtained from grid search low_freq_factor = 1 high_freq_factor = 4 @@ -238,24 +238,6 @@ def compute_llama3_parameters(freqs: torch.Tensor, scale_factor: float, orig_con return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) -def compute_default_parameters(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): - """Default scaling for rotary embeddings.""" - return freqs - - -def apply_scaling(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): - # FIXME: Llama-3.x specific scaling - we need to support yarn for Qwen2.5 models - - hf_model_env = os.getenv("HF_MODEL") - - if hf_model_env == "google/gemma-3-1b-it": - freqs = compute_default_parameters(freqs, scale_factor, orig_context_len) - elif "LLAMA_DIR" in os.environ or (hf_model_env and "llama" in hf_model_env.lower()): - freqs = compute_llama3_parameters(freqs, scale_factor, orig_context_len) - - return freqs - - def precompute_freqs(dim: int, end: int, theta, scale_factor, orig_context_len): """ Precompute the frequency tensor for sine and cosine values with given dimensions. 
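# precompute_freqs above builds the standard RoPE tables: per-pair inverse frequencies
# theta ** (-2i / dim), an outer product with the positions, then cos/sin (the Llama-3
# scale_factor branch in apply_scaling rescales freqs first). A minimal sketch of the
# unscaled case, which is also what the Gemma-3 local-attention setup in this PR uses
# (theta = 10000, rope_scaling = None):
import torch

def precompute_freqs_sketch(dim: int, end: int, theta: float = 10000.0):
    # Returns cos/sin tables of shape [end, dim // 2]; no scale_factor handling.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # [dim // 2]
    positions = torch.arange(end).float()
    angles = torch.outer(positions, freqs)                            # [end, dim // 2]
    return torch.cos(angles), torch.sin(angles)

cos, sin = precompute_freqs_sketch(dim=256, end=1024)
print(cos.shape, sin.shape)  # torch.Size([1024, 128]) torch.Size([1024, 128])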
@@ -603,11 +585,7 @@ def create_tt_model( state_dict=None, num_layers=None, ): - if "HF_MODEL" in os.environ and "gemma-3" in os.environ["HF_MODEL"].lower(): - from models.experimental.gemma3_1b.tt.model import Gemma3Transformer as Transformer - else: - from models.tt_transformers.tt.model import Transformer - + from models.tt_transformers.tt.model import Transformer from models.tt_transformers.tt.model_config import ModelArgs tt_model_args = ModelArgs( diff --git a/models/tt_transformers/tt/decoder.py b/models/tt_transformers/tt/decoder.py index 24e95a709b8a..d719a4ee5b59 100644 --- a/models/tt_transformers/tt/decoder.py +++ b/models/tt_transformers/tt/decoder.py @@ -102,6 +102,53 @@ def __init__( args, TG=args.is_galaxy, ) + if f"layers.{layer_num}.pre_feedforward_layernorm.weight" in self.state_dict: + self.pre_ff_norm = DistributedNorm( # pre_feedforward_layernorm + RMSNorm( + device=mesh_device, + dim=args.dim, + eps=args.norm_eps, + state_dict=state_dict, + add_unit_offset=self.args.rms_norm_add_unit_offset, + state_dict_prefix=args.get_state_dict_prefix("", layer_num), + weight_cache_path=None if args.dummy_weights else weight_cache_path, + weight_dtype=ttnn.bfloat16, + weight_key="pre_feedforward_layernorm", + is_distributed=self.args.is_distributed_norm, + sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"], + sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"], + ccl_topology=self.args.ccl_topology(), + ), + args, + TG=args.is_galaxy, + ) + else: + # If pre_feedforward_layernorm is not in state_dict, we do not use it + self.pre_ff_norm = None + + if f"layers.{layer_num}.post_feedforward_layernorm.weight" in self.state_dict: + self.post_ff_norm = DistributedNorm( # post_feedforward_layernorm + RMSNorm( + device=mesh_device, + dim=args.dim, + eps=args.norm_eps, + add_unit_offset=self.args.rms_norm_add_unit_offset, + state_dict=state_dict, + state_dict_prefix=args.get_state_dict_prefix("", layer_num), + weight_cache_path=None if args.dummy_weights else weight_cache_path, + weight_dtype=ttnn.bfloat16, + weight_key="post_feedforward_layernorm", + is_distributed=self.args.is_distributed_norm, + sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"], + sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"], + ccl_topology=self.args.ccl_topology(), + ), + args, + TG=args.is_galaxy, + ) + else: + # If post_feedforward_layernorm is not in state_dict, we do not use it + self.post_ff_norm = None def forward( self, @@ -116,6 +163,7 @@ def forward( kv_cache=None, ) -> ttnn.Tensor: TG = self.args.is_galaxy + residual = x # x is fractured across devices and interleaved in DRAM (for prefill) and sharded in L1 (for decode) skip_mem_cfg = self.model_config["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG assert ( @@ -124,10 +172,15 @@ def forward( # Norms take fractured inputs and output replicated across devices attn_in = self.attention_norm(x, mode) # Attention takes replicated inputs and produces fractured outputs + if self.attention.is_sliding: + position_embeddings = rot_mats[1] + else: + position_embeddings = rot_mats[0] + attn_out = self.attention.forward( attn_in, current_pos, - rot_mats, + position_embeddings, user_id, mode, page_table=page_table, @@ -135,25 +188,37 @@ def forward( chunk_start_idx=chunk_start_idx, kv_cache=kv_cache, ) - # Here x and attn_out are both fractured across devices - h = ttnn.add(x, attn_out, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16 if TG else None) - 
ttnn.deallocate(attn_out) + if self.pre_ff_norm is None: + attn_out = ttnn.add(x, attn_out, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16 if TG else None) + + residual = attn_out + + hidden_states = self.ff_norm(attn_out, mode) + if self.pre_ff_norm is not None: + hidden_states = ttnn.add(hidden_states, residual, memory_config=skip_mem_cfg, dtype=ttnn.bfloat16) + + residual = hidden_states + + hidden_states = self.pre_ff_norm(hidden_states, mode) + if mode == "prefill": x.deallocate(True) - # Norms take fractured inputs and output replicated across devices - ff_in = self.ff_norm(h, mode) + # ttnn.deallocate(attn_out) + if TG and mode == "decode": - ff_in = ttnn.to_memory_config(ff_in, memory_config=self.model_config["MLP_ACT_MEMCFG"]) + hidden_states = ttnn.to_memory_config(hidden_states, memory_config=self.model_config["MLP_ACT_MEMCFG"]) # MLP takes replicated inputs and produces fractured outputs - ff_out = self.feed_forward.forward(ff_in, mode) - # ff_out and h are both fractured across devices + hidden_states = self.feed_forward.forward(hidden_states, mode) activation_dtype = self.model_config["DECODERS_OPTIMIZATIONS"].get_tensor_dtype( decoder_id=self.layer_num, tensor=TensorGroup.ACTIVATION ) + if self.post_ff_norm is not None: + hidden_states = self.post_ff_norm(hidden_states, mode) + out = ttnn.add( - h, - ff_out, + residual, + hidden_states, memory_config=skip_mem_cfg, dtype=self.args.ccl_dtype if TG and not self.args.is_distributed_norm(mode) diff --git a/models/tt_transformers/tt/embedding.py b/models/tt_transformers/tt/embedding.py index c1420ad22f68..344392d8237e 100644 --- a/models/tt_transformers/tt/embedding.py +++ b/models/tt_transformers/tt/embedding.py @@ -33,6 +33,7 @@ def __init__( cache_file_name=cache_name, ) - def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: + def forward(self, x: ttnn.Tensor, embed_scale: float = 1.0) -> ttnn.Tensor: x = ttnn.embedding(x, self.weights, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG) + x = ttnn.multiply(x, embed_scale) return x diff --git a/models/tt_transformers/tt/lm_head.py b/models/tt_transformers/tt/lm_head.py index 3be020957904..c540343a4a2c 100644 --- a/models/tt_transformers/tt/lm_head.py +++ b/models/tt_transformers/tt/lm_head.py @@ -31,6 +31,7 @@ def __init__( self.num_devices = args.num_devices size_per_device = self.vocab_size // self.num_devices + self.model_config = args.get_model_config() if args.is_galaxy: size_per_device = self.padded_vocab_size // self.num_devices @@ -138,12 +139,14 @@ def forward(self, x: ttnn.Tensor): compute_kernel_config=self.compute_kernel_config, program_config=pc, memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, - dtype=ttnn.bfloat8_b, + dtype=self.args.lm_head_dtype or ttnn.bfloat8_b, + ) + outputs.append( + ttnn.sharded_to_interleaved(output, memory_config=self.model_config["LM_HEAD_OUTPUT_MEMCFG"]) ) - outputs.append(ttnn.sharded_to_interleaved(output, memory_config=ttnn.L1_MEMORY_CONFIG)) # Concatenate the outputs - output = ttnn.concat(outputs, dim=-1, memory_config=ttnn.L1_MEMORY_CONFIG) + output = ttnn.concat(outputs, dim=-1, memory_config=self.model_config["LM_HEAD_OUTPUT_MEMCFG"]) output = tt_all_reduce( output, diff --git a/models/tt_transformers/tt/mlp.py b/models/tt_transformers/tt/mlp.py index 9893ec2440e4..ec9fe66d7506 100644 --- a/models/tt_transformers/tt/mlp.py +++ b/models/tt_transformers/tt/mlp.py @@ -72,7 +72,9 @@ def __init__( self.w3 = as_sharded_tensor("w3_sharded", ff1_3_dtype, dims=w1_dims) # Default activation is SILU - self.activation_type = 
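# --- Illustrative sketch (not part of the patch) ---------------------------------
# Plain-PyTorch outline of the residual/norm ordering the rewritten decoder forward
# above targets. When the checkpoint provides pre/post feedforward layernorms
# (Gemma-3), the block uses "sandwich" norms: the attention output is normalised
# before the residual add and the MLP is wrapped by pre-/post-feedforward norms.
# When they are absent (pre_ff_norm is None), the flow reduces to the original
# Llama-style pre-norm block. All callables below are stand-ins for the TT modules.
def decoder_block(x, attn, mlp, attn_norm, ff_norm, pre_ff_norm=None, post_ff_norm=None):
    attn_out = attn(attn_norm(x))
    if pre_ff_norm is None:
        h = x + attn_out              # Llama: residual add first ...
        return h + mlp(ff_norm(h))    # ... then a single pre-MLP norm
    h = x + ff_norm(attn_out)         # Gemma-3: norm the attention output, then add residual
    return h + post_ff_norm(mlp(pre_ff_norm(h)))
# --- end of sketch; patch continues below ------------------------------------------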
self.args.mlp_activation_type + self.activation_type = ( + args.mlp_activation_type if hasattr(args, "mlp_activation_type") else ttnn.UnaryOpType.SILU + ) def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: """ diff --git a/models/tt_transformers/tt/model.py b/models/tt_transformers/tt/model.py index 591c915085e6..8270ced9e033 100644 --- a/models/tt_transformers/tt/model.py +++ b/models/tt_transformers/tt/model.py @@ -58,6 +58,18 @@ def __init__( rope_theta=args.rope_theta, rope_scaling=args.rope_scaling, ) + if args.rope_local_theta is not None: + self.rope_setup_local = ActualRopeSetupClass( + device=mesh_device, + batch_size=args.max_batch_size, + head_dim=args.head_dim, + max_seq_len=args.max_seq_len, + rope_theta=args.rope_local_theta, + rope_scaling=None, + ) + else: + self.rope_setup_local = None + self.trans_mats_dict = self.rope_setup.get_both_trans_mats() self.layers = [ @@ -105,6 +117,8 @@ def __init__( max_columns_per_device=self.args.max_columns_per_device_lm_head, ) + self.embed_scale = args.embed_scale + def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_page_table=None): """ Inputs are torch tensors or python types. This function returns ttnn @@ -122,7 +136,8 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag layout=ttnn.ROW_MAJOR_LAYOUT, mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) - tokens_embd = self.embd(tokens) + tokens_embd = self.embd(tokens, self.embed_scale) + tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) # Slice the rot mats to the prefill seqlen @@ -133,6 +148,13 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag self.rope_setup.cos_matrix[:, :, start_pos : start_pos + S, :], self.rope_setup.sin_matrix[:, :, start_pos : start_pos + S, :], ] + if self.rope_setup_local is not None: + tt_rot_mats_prefill_local = [ + self.rope_setup_local.cos_matrix[:, :, start_pos : start_pos + S, :], + self.rope_setup_local.sin_matrix[:, :, start_pos : start_pos + S, :], + ] + else: + tt_rot_mats_prefill_local = None if page_table is not None: tt_page_table = ttnn.from_torch( @@ -156,7 +178,7 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag else: tt_chunk_page_table = None - return tokens_embd, tt_rot_mats_prefill, tt_page_table, tt_chunk_page_table + return tokens_embd, [tt_rot_mats_prefill, tt_rot_mats_prefill_local], tt_page_table, tt_chunk_page_table def prepare_inputs_decode(self, *inputs): """ @@ -228,13 +250,18 @@ def transform_decode_inputs_device(self, tokens, current_pos, rope_idxs, page_ta Embed tokens """ tt_rot_mats = self.rope_setup.get_rot_mats(rope_idxs) - tt_tokens = self.embd(tokens) + if self.rope_setup_local is not None: + tt_rot_mats_local = self.rope_setup_local.get_rot_mats(rope_idxs) + else: + tt_rot_mats_local = None + tt_tokens = self.embd(tokens, self.embed_scale) + tt_tokens = ttnn.unsqueeze_to_4D(tt_tokens) tt_tokens = ttnn.to_memory_config( tt_tokens, self.args.model_config["DECODE_RESIDUAL_MEMCFG"], ) - return tt_tokens, current_pos, tt_rot_mats, page_table + return tt_tokens, current_pos, [tt_rot_mats, tt_rot_mats_local], page_table def concat_device_output(self, tt_out): """ diff --git a/models/tt_transformers/tt/model_config.py b/models/tt_transformers/tt/model_config.py index 659f729f28c6..e5ca6c1da460 100644 --- a/models/tt_transformers/tt/model_config.py +++ b/models/tt_transformers/tt/model_config.py @@ -99,7 +99,11 @@ def accuracy(cls, model_name): } ) else: - if 
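# --- Illustrative sketch (not part of the patch) ---------------------------------
# Why a second rope setup is built from rope_local_theta: Gemma-3's sliding-window
# (local-attention) layers use a different RoPE base frequency than its full-attention
# layers, so two cos/sin tables are precomputed and passed around as [global, local].
# The helper below and the example head_dim / theta values are illustrative only and
# are not the repo's rope setup implementation.
import torch


def rope_tables(head_dim: int, max_seq_len: int, theta: float):
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_seq_len).float()
    freqs = torch.outer(positions, inv_freq)
    return torch.cos(freqs), torch.sin(freqs)


cos_global, sin_global = rope_tables(head_dim=256, max_seq_len=4096, theta=1_000_000.0)  # from rope_theta
cos_local, sin_local = rope_tables(head_dim=256, max_seq_len=4096, theta=10_000.0)  # from rope_local_base_freq
# A sliding layer (attention.is_sliding) then picks rot_mats[1] (local); other layers use rot_mats[0].
# --- end of sketch; patch continues below ------------------------------------------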
base_model_name.startswith("Llama-3") or base_model_name.startswith("Mistral-7B"): + if ( + base_model_name.startswith("Llama-3") + or base_model_name.startswith("Mistral-7B") + or base_model_name.startswith("gemma-3-1b") + ): logger.info( f"Llama 3 and Mistral 7B models test insensitive to attention precision, using BFP8 attention and kv-cache with FP16 MLP accumulation even in accuracy mode" ) @@ -143,7 +147,7 @@ def performance(cls, model_name): All models use bfp4 in FF1 and FF3 MLPs in this configuration """ base_model_name = get_base_model_name(model_name) - if base_model_name == "Qwen2.5-7B": + if base_model_name in ["Qwen2.5-7B", "gemma-3-1b"]: logger.info( f"Model {model_name} is degraded under standard high-performance settings, using BF16 attention and BFP8 MLP" ) @@ -235,7 +239,7 @@ def _default_settings(self): # Attention TensorGroup.WQKV: PrecisionSetting.BFP8, TensorGroup.WO: PrecisionSetting.BFP8, - TensorGroup.KV_CACHE: PrecisionSetting.BF16, # Upgraded from BFP8 to prevent accumulation errors + TensorGroup.KV_CACHE: PrecisionSetting.BFP8, # Activation across whole model TensorGroup.ACTIVATION: None, # this signals that original dtype should be used }, @@ -245,7 +249,7 @@ def _default_settings(self): OpGroup.LI_FF2: MathFidelitySetting.HIFI2_FP16, # Attention operators -- linear and scaled_dot_product_attention, in decode and prefill modes OpGroup.LI_QKV_DECODE: MathFidelitySetting.HIFI2, - OpGroup.SDPA_DECODE: MathFidelitySetting.HIFI4, # Upgraded from HIFI2 for better precision + OpGroup.SDPA_DECODE: MathFidelitySetting.HIFI2, OpGroup.LI_O_DECODE: MathFidelitySetting.HIFI2, OpGroup.LI_QKV_PREFILL: MathFidelitySetting.HIFI2, OpGroup.SDPA_PREFILL: MathFidelitySetting.HIFI4, @@ -1261,6 +1265,10 @@ def _get_xattn_kv_prefill_mem_cfg(seq_len): ), ) + self.model_config["LM_HEAD_OUTPUT_MEMCFG"] = ( + ttnn.DRAM_MEMORY_CONFIG if self.model_name == "gemma-3-1b-it" else ttnn.L1_MEMORY_CONFIG + ) + self.lm_head_dtype = ttnn.bfloat16 if self.model_name == "gemma-3-1b-it" else None self.set_tg_attention_config() self.is_multichip = self.num_devices > 1 @@ -1404,16 +1412,15 @@ def _get_hidden_activation_type(self, config): def _set_model_specific_params(self): # Gemma3 specific params self.rms_norm_add_unit_offset = "gemma-3" in self.base_model_name.lower() + self.embed_scale = 1.0 if not "gemma-3" in self.base_model_name.lower() else self.dim ** 0.5 def _set_params_from_dict(self, config, is_hf=False): # Try to get text_config, if it doesn't exist everything is text config eos_token_id = config.get("eos_token_id", None) - self.eos_token_id = ( - None if isinstance(eos_token_id, int) else eos_token_id - ) # Gemma like models can have a list of eos token ids + self.eos_token_id = None if isinstance(eos_token_id, int) else eos_token_id - self.sliding_window_pattern = config.get("sliding_window_pattern", 1) + sliding_window_pattern = config.get("sliding_window_pattern", None) text_config = config.get("text_config", config) @@ -1427,6 +1434,11 @@ def _set_params_from_dict(self, config, is_hf=False): self.vocab_size = text_config["vocab_size"] self.padded_vocab_size = 128 * 1024 if self.is_galaxy else None self.head_dim = text_config.get("head_dim", self.dim // self.n_heads) or self.dim // self.n_heads + self.rope_local_theta = text_config.get("rope_local_base_freq", None) + self.is_sliding = [ + False if sliding_window_pattern is None else bool((layer_num + 1) % sliding_window_pattern) + for layer_num in range(self.n_layers) + ] if is_hf: self.max_context_len = 
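# --- Illustrative sketch (not part of the patch) ---------------------------------
# The per-layer is_sliding flags computed above: with a sliding_window_pattern of N,
# every N-th layer is a full (global) attention layer and the others use local
# sliding-window attention; without a pattern, every layer is global. Gemma-3 configs
# use a pattern of 6 (assumed here for the example). Gemma-3 also scales token
# embeddings by sqrt(hidden_size), which is what embed_scale = dim ** 0.5 provides.
def sliding_layers(n_layers: int, sliding_window_pattern=None):
    return [
        False if sliding_window_pattern is None else bool((layer_num + 1) % sliding_window_pattern)
        for layer_num in range(n_layers)
    ]


print(sliding_layers(12, 6))
# [True, True, True, True, True, False, True, True, True, True, True, False]
print(sliding_layers(4))
# [False, False, False, False]
# --- end of sketch; patch continues below ------------------------------------------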
text_config.get("max_position_embeddings") else: @@ -2249,7 +2261,8 @@ def reference_decoder(self): rotary_emb = model.model.rotary_emb if "gemma-3" in self.model_name: - wrapper = HfGemmaDecoderWrapper(layer, self.head_dim, rotary_emb) + rotary_emb_local = model.model.rotary_emb_local + wrapper = HfGemmaDecoderWrapper(layer, self.head_dim, rotary_emb, rotary_emb_local) else: wrapper = HfDecoderWrapper(layer, self.head_dim, rotary_emb) @@ -2268,7 +2281,6 @@ def reference_attention(self): "MistralAttention", "Gemma3Attention", ) - wrapper = HfAttentionWrapper( layer, self.head_dim, model.model.rotary_emb if use_position_embeddings else None ) @@ -2461,30 +2473,27 @@ def load_state_dict(self, state_dict): class HfGemmaDecoderWrapper: - def __init__(self, decoder, head_dim, rotary_emb): + def __init__(self, decoder, head_dim, rotary_emb, rotary_emb_local): from transformers import DynamicCache self.decoder = decoder self.head_dim = head_dim self.rotary_emb = rotary_emb + self.rotary_emb_local = rotary_emb_local self.past_key_values = DynamicCache() def forward(self, x, start_pos, freqs_cis_i, mask=None): position_ids = torch.tensor([list(range(start_pos, start_pos + x.shape[1]))] * x.shape[0]) - # TODO: Generalize for other HF models - model_name_env = os.getenv("HF_MODEL") - if model_name_env is not None and "mistral" in model_name_env.lower(): - position_embeddings = self.rotary_emb(x, x.shape[1]) - else: - position_embeddings = self.rotary_emb(x, position_ids) + position_embeddings_global = self.rotary_emb(x, position_ids) + position_embeddings_local = self.rotary_emb_local(x, position_ids) if mask is not None: while len(mask.shape) < 4: mask = mask.unsqueeze(0) result = self.decoder.forward( x, - position_embeddings_global=position_embeddings, - position_embeddings_local=position_embeddings, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, past_key_value=self.past_key_values, use_cache=True, position_ids=position_ids,
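# --- Illustrative sketch (not part of the patch) ---------------------------------
# How the reference wrapper above drives a HF Gemma-3 decoder layer for comparison:
# the global and local rotary-embedding modules are evaluated on the same position_ids
# and handed to the layer as separate position embeddings. `layer`, `rotary_emb` and
# `rotary_emb_local` are taken from the HF reference model; everything else here is an
# illustrative stand-in, not the wrapper's exact code.
import torch


def run_reference_layer(layer, rotary_emb, rotary_emb_local, past_key_values, x, start_pos):
    batch, seq_len = x.shape[0], x.shape[1]
    position_ids = torch.arange(start_pos, start_pos + seq_len).unsqueeze(0).repeat(batch, 1)
    pos_emb_global = rotary_emb(x, position_ids)       # (cos, sin) built from rope_theta
    pos_emb_local = rotary_emb_local(x, position_ids)  # (cos, sin) built from rope_local_base_freq
    return layer(
        x,
        position_embeddings_global=pos_emb_global,
        position_embeddings_local=pos_emb_local,
        past_key_value=past_key_values,
        use_cache=True,
        position_ids=position_ids,
    )
# --- end of sketch; patch continues below ------------------------------------------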