[Fix][Model] Fix tensor parallelism of internlm2 model (#3058)
This PR fixes a bug that causes incorrect output when running internlm2
with tensor parallelism.
MasterJH5574 authored Dec 6, 2024
1 parent d23d6f5 commit 8477ac0
Showing 1 changed file with 6 additions and 6 deletions.
python/mlc_llm/model/internlm2/internlm2_model.py
@@ -68,17 +68,17 @@ def __post_init__(self):
             logger.info(
                 "%s defaults to %d",
                 bold("prefill_chunk_size"),
-                min(self.context_window_size, 8192),
+                min(self.context_window_size, 2048),
             )
-            self.prefill_chunk_size = min(self.context_window_size, 8192)
+            self.prefill_chunk_size = min(self.context_window_size, 2048)
         elif self.prefill_chunk_size > self.context_window_size:
             logger.info(
                 "Overriding %s from %d to %d",
                 bold("prefill_chunk_size"),
                 self.prefill_chunk_size,
                 min(self.context_window_size, 2048),
             )
-            self.prefill_chunk_size = min(self.context_window_size, 8192)
+            self.prefill_chunk_size = min(self.context_window_size, 2048)
 
 
 # pylint: disable=invalid-name,missing-docstring
@@ -178,11 +178,11 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
         residual = hidden_states
         hidden_states = self.attention_norm(hidden_states)
         hidden_states = self.attention(hidden_states, paged_kv_cache, layer_id)
-        hidden_states = self._apply_residual(residual, residual=hidden_states)
+        hidden_states = self._apply_residual(hidden_states, residual=residual)
         residual = hidden_states
         hidden_states = self.ffn_norm(hidden_states)
         hidden_states = self.feed_forward(hidden_states)
-        hidden_states = self._apply_residual(residual, residual=hidden_states)
+        hidden_states = self._apply_residual(hidden_states, residual=residual)
         return hidden_states
 
     def _apply_residual(self, out, residual):
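Why swapping the arguments only breaks tensor-parallel runs: in mlc_llm model definitions, `_apply_residual` all-reduces the shard-local layer output across tensor-parallel workers before adding the residual stream back. The body of the internlm2 helper is not shown in this diff; the following is a minimal sketch assuming it matches the pattern used by sibling models in the repository (e.g. the llama definition), using `op.ccl_allreduce` from `tvm.relax.frontend.nn`:

```python
from tvm.relax.frontend.nn import Tensor, op


def _apply_residual(self, out: Tensor, residual: Tensor) -> Tensor:
    """Add the residual stream back onto a layer's output.

    Illustrative sketch, not copied from internlm2_model.py. Under tensor
    parallelism each shard holds only a partial sum of the attention/FFN
    projection, so `out` must be all-reduced across shards before the
    (replicated) residual is added.
    """
    if self.tensor_parallel_shards > 1:
        return op.ccl_allreduce(out, "sum") + residual
    return out + residual
```

Under this reading, the old call `_apply_residual(residual, residual=hidden_states)` all-reduced the replicated residual stream (scaling it by the number of shards) while adding the partial layer output unreduced. On a single GPU the all-reduce branch is skipped and addition commutes, so both argument orders give the same result, which is why the incorrect output appeared only with tensor parallelism.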
