@@ -516,10 +516,11 @@ def _load_from_state_dict(self, state_dict: dict, prefix: str, *args: Any, **kwa
 
 
 class GptNeoxMLP(nn.Module):
-    def __init__(self, config: Config) -> None:
+    def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:
         super().__init__()
-        self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
+        self.intermediate_size = intermediate_size or config.intermediate_size
+        self.fc = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
+        self.proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias)
         self.config = config
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -529,11 +530,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class LLaMAMLP(nn.Module):
-    def __init__(self, config: Config) -> None:
+    def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:
         super().__init__()
-        self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-        self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
+        self.intermediate_size = intermediate_size or config.intermediate_size
+        self.fc_1 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
+        self.fc_2 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
+        self.proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias)
         self.config = config
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -555,7 +557,9 @@ class LLaMAMoE(nn.Module):
     def __init__(self, config: Config) -> None:
         super().__init__()
         self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
-        self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert))
+        self.experts = nn.ModuleList(
+            LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert)
+        )
         self.config = config
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
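
To illustrate what the new keyword enables (a minimal sketch only; the Config construction below is simplified and the values are invented, with just the field names taken from the diff): a standalone GptNeoxMLP or LLaMAMLP still falls back to config.intermediate_size, while LLaMAMoE now builds each expert with the separate config.moe_intermediate_size.

# Sketch only: assumes Config accepts exactly these fields.
config = Config(
    n_embd=4096,
    intermediate_size=14336,
    moe_intermediate_size=4096,
    n_expert=8,
    bias=False,
)

dense_mlp = LLaMAMLP(config)   # no override -> uses config.intermediate_size (14336)
moe = LLaMAMoE(config)         # each expert is built with the override -> 4096

assert dense_mlp.intermediate_size == config.intermediate_size
assert all(expert.intermediate_size == config.moe_intermediate_size for expert in moe.experts)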