[1037] Cleanup engines (torch.nn.modules and forward functions) #1080
base: develop
Conversation
I understand the issue with the other engines that are in |
Before starting on the engines I wanted to start with something smaller. Therefore, in embeddings, I made the forward definition naming explicit instead of setting it via an attribute. This makes it easier to find.
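To illustrate the kind of change described above, here is a minimal torch-free sketch. `Module` is a hypothetical stand-in that mimics how calling a `torch.nn.Module` dispatches to `self.forward`; `EmbedImplicit`, `EmbedExplicit` and the `mode` argument are invented for illustration and are not the classes in embeddings.py.

```python
class Module:
    """Stand-in for torch.nn.Module: calling the object runs self.forward."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class EmbedImplicit(Module):
    """Before: forward is bound at construction time via an attribute."""

    def __init__(self, mode):
        # The forward pass is chosen here, so the class has no "def forward" to search for.
        if mode == "scale":
            self.forward = self._forward_scale
        else:
            self.forward = self._forward_identity

    def _forward_scale(self, x):
        return [2.0 * v for v in x]

    def _forward_identity(self, x):
        return list(x)


class EmbedExplicit(Module):
    """After: forward is an ordinary method that dispatches internally."""

    def __init__(self, mode):
        self.mode = mode

    def forward(self, x):
        if self.mode == "scale":
            return [2.0 * v for v in x]
        return list(x)


# Both variants should compute the same outputs for the same mode.
x = [1.0, 2.0, 3.0]
same_scale = EmbedImplicit("scale")(x) == EmbedExplicit("scale")(x)
same_identity = EmbedImplicit("identity")(x) == EmbedExplicit("identity")(x)
```

The explicit variant keeps the branching in one visible place, which is the "easier to find" property mentioned above; the outputs are intended to be identical.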
…or model accordingly
@yperugachidiaz: For the review, it will be important to have some before-PR/after-PR loss curves. We had some issues before with gradient checkpointing where the code was formally correct (as far as we could tell) but still behaved differently under training.
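One way to make the before/after comparison concrete is to check the two loss curves step by step against a tolerance. The sketch below is an assumption, not the project's tooling: `curves_match`, the tolerance, and the loss values are all made up for illustration.

```python
import math


def curves_match(before, after, rel_tol=1e-3):
    """Return (ok, first_bad_step) for two loss curves logged at the same steps."""
    if len(before) != len(after):
        raise ValueError("curves must cover the same number of steps")
    for step, (a, b) in enumerate(zip(before, after)):
        if not math.isclose(a, b, rel_tol=rel_tol):
            return False, step
    return True, None


# Illustrative values only, not measured results.
loss_before = [1.000, 0.800, 0.650, 0.540]
loss_after_ok = [1.000, 0.8001, 0.6502, 0.5399]
loss_after_drift = [1.000, 0.800, 0.700, 0.620]

ok_result = curves_match(loss_before, loss_after_ok)
drift_result = curves_match(loss_before, loss_after_drift)
```

Reporting the first diverging step helps narrow down where a formally equivalent refactor starts to behave differently. A suitable tolerance depends on how deterministic the training runs are.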
This looks great! I like the changes. I am happy to approve this after we check that convergence is unchanged because, as Christian mentioned, we have had some surprising differences after functionally equivalent refactors!
return out.to(torch.float16)

def forward(self, x_in, centroids):
nice clean up
I think the PR has an incorrect name now: it should be |
Many thanks for the cleanup. Looks good, but could we add a docstring to all forward functions? This will help us a lot going forward. I am happy to help with filling in some details, but maybe you can start with it. Please also make sure to use our conventions (some of the docstrings in engines.py use a different convention).
Description
Clean up of engines.py by implementing the forward pass for classes:
Plus cleanup of embeddings.py by implementing the forward pass for class:
Fixed sharding when modules are called in trainer.py.
Work in progress; issue should not be closed yet.
Upcoming: new issue .... to fix the now broken checkpoints (new modules change the structure).
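One common way to load older checkpoints after a module restructure is to rename the keys of the saved state dict. The sketch below uses plain dicts and is only an assumption about how such a fix could look: `remap_keys` and the key prefixes (`embed.layer.`, `embed.net.layer.`) are hypothetical and do not reflect the actual checkpoint layout.

```python
def remap_keys(state_dict, renames):
    """Return a copy of state_dict with old key prefixes replaced by new ones."""
    remapped = {}
    for key, value in state_dict.items():
        new_key = key
        for old_prefix, new_prefix in renames.items():
            if key.startswith(old_prefix):
                # Swap only the prefix; the rest of the key is kept as-is.
                new_key = new_prefix + key[len(old_prefix):]
                break
        remapped[new_key] = value
    return remapped


# Hypothetical old layout and rename rule.
old_state = {
    "embed.layer.weight": [0.1, 0.2],
    "embed.layer.bias": [0.0],
    "head.weight": [1.0],
}
renames = {"embed.layer.": "embed.net.layer."}
new_state = remap_keys(old_state, renames)
```

Keys that match no rule pass through unchanged, and the original dict is not modified.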
Issue Number
Closes #1037
Is this PR a draft? Mark it as draft.
Checklist before asking for review
./scripts/actions.sh lint
./scripts/actions.sh unit-test
./scripts/actions.sh integration-test
launch-slurm.py --time 60