Add standalone parallel CNN functions for MNIST #20
base: main
Conversation
torch.tensor(calc_in["grads_in"][offset:offset + grad_size].reshape(grad_shape),
             device=param.grad.device)
)
offset += grad_size
If we need to average above, then we need to do it here as well.
param.grad /= group_size  # SH averaging?
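A minimal sketch of the unpack-and-average step being discussed, assuming calc_in["grads_in"] is a flat array of concatenated gradients summed across the group and group_size is the number of contributing processes; the helper name apply_summed_grads is hypothetical.

import torch


def apply_summed_grads(model, calc_in, group_size):
    # calc_in["grads_in"] is assumed to be a 1-D array of concatenated
    # gradients, summed across the group.
    flat = calc_in["grads_in"]
    offset = 0
    for param in model.parameters():
        grad_shape = param.shape
        grad_size = param.numel()
        chunk = flat[offset:offset + grad_size].reshape(grad_shape)
        param.grad = torch.tensor(chunk, device=param.device, dtype=param.dtype)
        param.grad /= group_size  # average here as well, mirroring the line above
        offset += grad_size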
# SH TODO - consider inheriting dist (torch.distributed)
# and overriding dist.all_reduce() to do libE stuff
torch.distributed is a module, not a class, so it cannot be inherited from, but we could extract these lines into a provided function.
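A hedged sketch of that alternative: a provided function with an all_reduce-like shape that flattens the gradient tensors, hands the flat buffer to a caller-supplied exchange function (standing in for the libE round trip), and copies the reduced result back. The names libE_all_reduce and exchange_fn are assumptions for illustration.

import torch


def libE_all_reduce(tensors, exchange_fn):
    # `tensors` are typically the param.grad buffers; `exchange_fn` is a
    # placeholder for the libEnsemble send/receive round trip and is assumed
    # to return a reduced flat array of the same length.
    flat = torch.cat([t.detach().reshape(-1) for t in tensors])
    reduced = torch.as_tensor(exchange_fn(flat.cpu().numpy()), dtype=flat.dtype)
    offset = 0
    for t in tensors:
        n = t.numel()
        t.copy_(reduced[offset:offset + n].reshape(t.shape).to(t.device))
        offset += n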
loss = nn.CrossEntropyLoss()(output, target)
loss.backward()

# Synchronize gradients across processes
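For reference, the usual torch.distributed form of that synchronization step looks like the sketch below; whether this PR's standalone version uses exactly this shape is an assumption.

import torch.distributed as dist


def sync_gradients(model, average=True):
    # Sum each parameter's gradient across ranks; optionally average,
    # which is the point under discussion above.
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            if average:
                param.grad /= world_size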
If we do the optimizer step on the gen, the gen does:
optimizer.zero_grad()
gradients.sum()
optimizer.step()
This is for comparison with libE.
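A hedged sketch of that gen-side update, assuming the gen receives one flat gradient array per worker, sums (or averages) them, loads the result into its own model copy, and takes a single optimizer step; gen_optimizer_step and worker_grads are illustrative names.

import numpy as np
import torch


def gen_optimizer_step(model, optimizer, worker_grads, average=True):
    # worker_grads is assumed to be a list of equal-length 1-D numpy arrays,
    # one flat gradient vector per worker.
    optimizer.zero_grad()
    summed = np.sum(worker_grads, axis=0)
    if average:
        summed = summed / len(worker_grads)
    offset = 0
    for param in model.parameters():
        n = param.numel()
        chunk = summed[offset:offset + n].reshape(param.shape)
        param.grad = torch.as_tensor(chunk, device=param.device, dtype=param.dtype)
        offset += n
    optimizer.step()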
Works in parallel on Perlmutter and uses GPUs on one node, but does not work across nodes.
Main loop