From 0dfb7566e6221f33fe1f2e0e3295a015e962878b Mon Sep 17 00:00:00 2001
From: masahi
Date: Tue, 27 Feb 2024 16:26:58 +0900
Subject: [PATCH] Fix QKV weight sharding for gemma (#222)

Fix gemma multi-gpu
---
 mlc_llm/core.py                           |  5 +----
 mlc_llm/relax_model/commons.py            |  2 +-
 mlc_llm/relax_model/llama.py              | 18 ++++++++----------
 mlc_llm/relax_model/llama_batched_vllm.py | 11 ++---------
 4 files changed, 12 insertions(+), 24 deletions(-)

diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index 29e33fd6fc..662b79510f 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -646,10 +646,7 @@ def mod_transform_before_build(
         num_key_value_heads = config.get_num_key_value_heads()
         num_query_heads = config.num_attention_heads // args.num_shards
         hidden_size = config.hidden_size // args.num_shards
-        if hasattr(config, "head_dim"):
-            head_dim = config.head_dim
-        else:
-            head_dim = hidden_size // num_query_heads
+        head_dim = config.get_head_dim()
         # pylint: disable=no-value-for-parameter
         mod = fuse_split_rotary_embedding(
             num_query_heads,
diff --git a/mlc_llm/relax_model/commons.py b/mlc_llm/relax_model/commons.py
index 01ce81297f..b6da00b0b7 100644
--- a/mlc_llm/relax_model/commons.py
+++ b/mlc_llm/relax_model/commons.py
@@ -32,7 +32,7 @@ def create_metadata_func(
 def _get_shard_strategies(
     model_config, num_shards: int, param_shape_is_already_sharded: bool
 ) -> Dict[str, tvm.tir.PrimFunc]:
-    head_dim = model_config.hidden_size // model_config.num_attention_heads
+    head_dim = model_config.get_head_dim()
     q_heads = model_config.num_attention_heads
     kv_heads = model_config.get_num_key_value_heads()
 
diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py
index 957ff192c3..ecfaa8172f 100644
--- a/mlc_llm/relax_model/llama.py
+++ b/mlc_llm/relax_model/llama.py
@@ -76,6 +76,9 @@ def get_num_key_value_heads(self):
 
         return self.num_key_value_heads
 
+    def get_head_dim(self):
+        return self.hidden_size // self.num_attention_heads
+
 
 class MixtralConfig(LlamaConfig):
     num_experts_per_tok: int
@@ -110,6 +113,9 @@ def __init__(
         super().__init__(**kwargs)
         self.head_dim = kwargs["head_dim"]
 
+    def get_head_dim(self):
+        return self.head_dim
+
 
 class Linear(nn.Module):
     def __init__(self, in_features, out_features, dtype: str, bias=True):
@@ -329,12 +335,7 @@ def __init__(self, config: LlamaConfig):
         self.hidden_size = config.hidden_size
         self.num_key_value_heads = config.get_num_key_value_heads() // config.num_shards
         self.num_query_heads = config.num_attention_heads // self.num_shards
-
-        if hasattr(config, "head_dim"):
-            self.head_dim = config.head_dim
-        else:
-            self.head_dim = config.hidden_size // config.num_attention_heads
-
+        self.head_dim = config.get_head_dim()
         self.position_embedding_base = config.position_embedding_base
         self.combine_matmul = config.combine_matmul
 
@@ -1394,10 +1395,7 @@ def quantize(experts, relax_pname):
             assert relax_pname.endswith("scales")
             return qscale
 
-    if hasattr(config, "head_dim"):
-        head_dim = config.head_dim
-    else:
-        head_dim = config.hidden_size // config.num_attention_heads
+    head_dim = config.get_head_dim()
 
     def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
         # Expected to enter this function only for the combined linear matmul weights.
diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py
index da6217a70b..330f23f00b 100644
--- a/mlc_llm/relax_model/llama_batched_vllm.py
+++ b/mlc_llm/relax_model/llama_batched_vllm.py
@@ -569,11 +569,7 @@ def __init__(
         self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False)
 
         ############ Rotary embedding constants ############
-        if hasattr(config, "head_dim"):
-            head_dim = config.head_dim
-        else:
-            assert config.hidden_size % config.num_attention_heads == 0
-            head_dim = config.hidden_size // config.num_attention_heads
+        head_dim = config.get_head_dim()
 
         # Set the cached sin/cos to the maximum of 2048 and max seq len.
         # This will be eliminated further with online rotary embedding calculation.
@@ -722,10 +718,7 @@ def get_inputs(
 
     num_key_value_heads = config.get_num_key_value_heads() // config.num_shards
 
-    if hasattr(config, "head_dim"):
-        head_size = config.head_dim
-    else:
-        head_size = config.hidden_size // config.num_attention_heads
+    head_size = config.get_head_dim()
 
     if kv_type == KVCacheType.VLLM:
         block_size = VllmAttention.block_size
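Note (illustrative, not part of the patch): a minimal standalone sketch of the get_head_dim() dispatch this diff consolidates the scattered hasattr checks into. The constructor signatures and config values below are assumptions chosen for illustration (roughly Gemma-7B-like), not the exact classes in mlc_llm/relax_model/llama.py.

class LlamaConfig:
    def __init__(self, hidden_size, num_attention_heads):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads

    def get_head_dim(self):
        # Default: derive the head dimension from the hidden size.
        return self.hidden_size // self.num_attention_heads


class GemmaConfig(LlamaConfig):
    def __init__(self, head_dim, **kwargs):
        super().__init__(**kwargs)
        self.head_dim = head_dim

    def get_head_dim(self):
        # Gemma declares head_dim explicitly; it need not equal
        # hidden_size // num_attention_heads, which is why every call
        # site now goes through this single method.
        return self.head_dim


llama = LlamaConfig(hidden_size=4096, num_attention_heads=32)
gemma = GemmaConfig(hidden_size=3072, num_attention_heads=16, head_dim=256)
print(llama.get_head_dim())  # 128
print(gemma.get_head_dim())  # 256; hidden_size // num_attention_heads would give 192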