[MoE][PoC] model code #730
base: gh/tianyu-l/24/base
Conversation
[ghstack-poisoned]
            multiple_of=model_args.multiple_of,
            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
        )
        self.enable_moe = model_args.enable_moe
Minor nit: I think this is expressed better as self.is_dense_model and self.is_moe_model. 'Enable' is not the same thing as 'is', so it's cleaner to express the state of being via 'is'.
In addition, the later checks read more cleanly as self.is_dense_model / is_dense_layer than as 'not self.enable_moe', which may not really mean "dense model/layer" later on if we start enabling other architectures.
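As a concrete illustration, a minimal sketch of how the constructor could expose the proposed flags. The is_dense_model / is_moe_model names are the reviewer's suggestion, not attributes that exist in the PR; model_args.enable_moe is the flag the PR adds.

    import torch.nn as nn


    class TransformerBlock(nn.Module):
        # Toy fragment, not the PR's actual class; only the flag handling is shown.
        def __init__(self, model_args):
            super().__init__()
            # Derive explicit "is_*" attributes from the raw enable flag so that
            # later checks read as statements about what the block is.
            self.is_moe_model = model_args.enable_moe
            self.is_dense_model = not model_args.enable_moe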
@@ -321,14 +371,20 @@ def forward(
        """
        h = x + self.attention(self.attention_norm(x), freqs_cis)
        out = h + self.feed_forward(self.ffn_norm(h))
        if not self.enable_moe:
W.r.t. the above, this is what I mean about implying it's dense without really making it clear: is_dense_model is more precise.
I modified it to:

    if self.is_dense_model:
        out = h + self.feed_forward(self.ffn_norm(h))
    elif self.is_moe_model:
        out = h + self.moe(self.ffn_norm(h))
    else:
        raise NotImplementedError("unknown model type")

in my copy of your PR, mostly to show how it reads with the is_blah_model approach.
        return out

    def init_weights(self):
        for norm in (self.attention_norm, self.ffn_norm):
            norm.reset_parameters()
        self.attention.init_weights(self.weight_init_std)
        self.feed_forward.init_weights(self.weight_init_std)
        if not self.enable_moe:
same as above
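Spelling that out, a sketch of how the init_weights check could read with the suggested flag. This is illustrative only: whether the PR guards feed_forward.init_weights or adds a separate moe.init_weights call is not visible in the truncated diff above, so the moe branch here is an assumption.

    def init_weights(self):
        # Sketch of the reviewer's naming suggestion applied to init_weights.
        # Everything except is_moe_model and the assumed moe.init_weights call
        # comes from the diff above.
        for norm in (self.attention_norm, self.ffn_norm):
            norm.reset_parameters()
        self.attention.init_weights(self.weight_init_std)
        if self.is_moe_model:
            self.moe.init_weights(self.weight_init_std)  # assumed MoE counterpart
        else:
            self.feed_forward.init_weights(self.weight_init_std)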
Stack from ghstack (oldest at bottom):
The expert-choice MoE layer is inspired by torchtune: pytorch/torchtune#1902
Not including token-choice MoE for now.
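For context on what "expert choice" means here: instead of each token picking its top experts (token choice), each expert picks a fixed number of tokens from the router scores, which balances expert load by construction. Below is a minimal, self-contained sketch of that routing pattern in PyTorch; it is not the PR's or torchtune's actual implementation, and all names (ExpertChoiceMoE, capacity_factor, w1/w2, etc.) are illustrative assumptions.

    import torch
    import torch.nn.functional as F
    from torch import nn


    class ExpertChoiceMoE(nn.Module):
        """Illustrative expert-choice MoE layer; names and shapes are assumptions."""

        def __init__(self, dim: int, hidden_dim: int, num_experts: int,
                     capacity_factor: float = 1.0):
            super().__init__()
            self.num_experts = num_experts
            self.capacity_factor = capacity_factor
            self.router = nn.Linear(dim, num_experts, bias=False)
            # Batched expert weights: one two-layer FFN per expert.
            self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
            self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
            nn.init.trunc_normal_(self.w1, std=0.02)
            nn.init.trunc_normal_(self.w2, std=0.02)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            bsz, seqlen, dim = x.shape
            tokens = x.reshape(-1, dim)                        # (T, dim)
            scores = F.softmax(self.router(tokens), dim=-1)    # (T, E) token-expert affinities
            # Each expert selects its own top-`capacity` tokens (expert choice),
            # so per-expert load is balanced by construction.
            capacity = max(1, int(self.capacity_factor * tokens.shape[0] / self.num_experts))
            expert_weights, token_idx = scores.t().topk(capacity, dim=-1)  # both (E, C)
            selected = tokens[token_idx]                        # (E, C, dim)
            hidden = F.silu(torch.bmm(selected, self.w1))       # (E, C, hidden_dim)
            expert_out = torch.bmm(hidden, self.w2)             # (E, C, dim)
            # Scatter expert outputs back to token positions, weighted by affinity.
            out = torch.zeros_like(tokens)
            out.index_add_(
                0,
                token_idx.reshape(-1),
                (expert_weights.unsqueeze(-1) * expert_out).reshape(-1, dim),
            )
            return out.reshape(bsz, seqlen, dim)

Because every expert takes a fixed number of tokens, there is no need for an auxiliary load-balancing loss; the trade-off is that some tokens may be selected by no expert and then pass through only via the residual connection.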