Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SLM] GPTJ Multi-GPU support #3070

Merged
merged 3 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions python/mlc_llm/model/gpt_j/gpt_j_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.support import logging
from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
from mlc_llm.support.style import bold

Expand Down Expand Up @@ -57,6 +58,9 @@ def __post_init__(self):
"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is "
"provided in `config.json`."
)
if self.head_dim == 0:
self.head_dim = self.n_embd // self.n_head
assert self.head_dim * self.n_head == self.n_embd
if self.prefill_chunk_size == 0:
logger.info(
"%s defaults to %d",
Expand All @@ -72,7 +76,6 @@ def __post_init__(self):
min(self.context_window_size, 8192),
)
self.prefill_chunk_size = min(self.context_window_size, 8192)
assert self.tensor_parallel_shards == 1, "GPTJ currently does not support sharding."


# pylint: disable=invalid-name,missing-docstring
Expand All @@ -82,7 +85,7 @@ class GPTJAttention(nn.Module): # pylint: disable=too-many-instance-attributes
def __init__(self, config: GPTJConfig):
self.embed_dim = config.n_embd
self.num_heads = config.n_head // config.tensor_parallel_shards
self.head_dim = self.embed_dim // self.num_heads
self.head_dim = config.head_dim
self.max_position_embeddings = config.context_window_size
self.rope_theta = 10000
self.rotary_dim = config.rotary_dim
Expand Down Expand Up @@ -140,14 +143,41 @@ def __init__(self, config: GPTJConfig):
self.attn = GPTJAttention(config)
self.mlp = GPTJMLP(config)

def _set_tp():
def _set(layer, hint):
layer.attrs["shard_strategy"] = hint

hd = config.head_dim
q = self.attn.num_heads * hd
k = self.attn.num_heads * hd
v = self.attn.num_heads * hd
_set(
self.attn.c_attn.weight,
tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]),
)
_set(self.attn.out_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
_set(
self.mlp.fc_in.weight,
tp.ShardSingleDim("_shard_c_fc_weight", dim=0),
)
_set(self.mlp.fc_out.weight, tp.ShardSingleDim("_shard_mlp_c_proj", dim=1))

self.tensor_parallel_shards = config.tensor_parallel_shards
_set_tp()

def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(hidden_states, paged_kv_cache, layer_id)
feed_forward_hidden_states = self.mlp(hidden_states)
hidden_states = attn_output + feed_forward_hidden_states + residual
hidden_states = self._apply_residual(attn_output + feed_forward_hidden_states, residual)
return hidden_states

def _apply_residual(self, out, residual):
if self.tensor_parallel_shards > 1:
return op.ccl_allreduce(out, "sum") + residual
return out + residual


class GPTJModel(nn.Module):
def __init__(self, config: GPTJConfig):
Expand Down
4 changes: 2 additions & 2 deletions python/mlc_llm/model/olmo/olmo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __post_init__(self): # pylint: disable=too-many-branches
raise ValueError(f"'clip_qkv'({self.clip_qkv}) should be non-negative")


class OLMoEebedding(nn.Embedding):
class OLMoEmbedding(nn.Embedding):
"""The embedding module that can be shared with the final lm_head. From Qwen2Embedding."""

def lm_head_forward(self, x: nn.Tensor):
Expand Down Expand Up @@ -248,7 +248,7 @@ def forward( # pylint: disable=missing-function-docstring
class OLMoModel(nn.Module): # pylint: disable=missing-class-docstring
def __init__(self, config: OLMoConfig):
assert config.hidden_size % config.num_attention_heads == 0
self.embed_tokens = OLMoEebedding(config.vocab_size, config.hidden_size)
self.embed_tokens = OLMoEmbedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList(
[OLMoDecoderLayer(config) for _ in range(config.num_hidden_layers)]
)
Expand Down
Loading