Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions models/tt_transformers/tt/generator_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

import ttnn
from models.common.utility_functions import is_wormhole_b0, nearest_32
from models.demos.gemma3.tt.model_config import ModelArgs as Gemma3ModelArgs
from models.tt_transformers.tt.generator import Generator, create_submeshes
from models.tt_transformers.tt.model import Transformer
from models.tt_transformers.tt.model_config import DecodersPrecision, ModelArgs
Expand Down Expand Up @@ -673,3 +674,95 @@ def decode_forward(self, *args, **kwargs):

def allocate_kv_cache(self, *args, **kwargs):
return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path)


def initialize_vllm_text_transformer_for_gemma(
hf_config,
tt_data_parallel,
mesh_device,
max_batch_size,
max_seq_len,
n_layers=None,
dtype=ttnn.bfloat8_b,
optimizations=DecodersPrecision.performance,
):
submesh_devices = create_submeshes(mesh_device, tt_data_parallel)
# Load model args, weights
model_args = []
for submesh in submesh_devices:
model_args_i = Gemma3ModelArgs(
submesh,
instruct=(
"Instruct" in hf_config._name_or_path or "DeepSeek-R1-Distill-Llama-70B" in hf_config._name_or_path
),
max_batch_size=max_batch_size // tt_data_parallel,
optimizations=lambda model_args: optimizations(model_args.n_layers, model_args.model_name),
max_seq_len=max_seq_len,
)

assert model_args_i.model_name.replace("-", "") in hf_config._name_or_path.replace(
"-", ""
), f"The model specified in vLLM ({hf_config._name_or_path}) does not match the model name ({model_args_i.model_name}) with model weights ({model_args_i.CKPT_DIR})."
if n_layers is not None:
model_args_i.n_layers = n_layers

model_args.append(model_args_i)

state_dict = model_args[0].load_state_dict()

tt_model = []
for i, submesh in enumerate(submesh_devices):
tt_model_i = Transformer(
args=model_args[i],
mesh_device=submesh,
dtype=dtype,
state_dict=state_dict,
weight_cache_path=model_args[i].weight_cache_path(dtype),
use_paged_kv_cache=True,
)
tt_model.append(tt_model_i)

return tt_model, model_args


class Gemma3ForCausalLM(Generator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@classmethod
def initialize_vllm_model(
cls,
hf_config,
mesh_device,
max_batch_size,
n_layers=None,
tt_data_parallel=1,
max_seq_len=131072,
optimizations: str = "performance",
):
tt_model, model_args = initialize_vllm_text_transformer_for_gemma(
hf_config,
tt_data_parallel,
mesh_device,
max_batch_size,
max_seq_len=max_seq_len,
n_layers=n_layers,
dtype=ttnn.bfloat8_b,
optimizations=DecodersPrecision.from_string(optimizations)
if optimizations is not None
else DecodersPrecision.performance,
)
return cls(tt_model, model_args, mesh_device)

@property
def cache_path(self):
return self.model_args[0].model_cache_path

def prefill_forward(self, *args, **kwargs):
return super().prefill_forward_text(*args, **kwargs)

def decode_forward(self, *args, **kwargs):
return super().decode_forward_text(*args, **kwargs)

def allocate_kv_cache(self, *args, **kwargs):
return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path)