Commit 6bc2d34
Debug Embedding Collection to check for NaNs in backward (#3519)
Summary:
Pull Request resolved: #3519
This diff adds support to the Embedding classes for detecting NaNs in backward. It adds the following:
- `DebugEmbeddingCollection`
- `DebugEmbeddingCollectionClass`
Currently, it checks whether gradients contain a NaN during backward, so the NaN is caught before the backward pass runs through `EmbeddingCollection`. It works by wrapping all the tensors (inside KeyedJaggedTensor) with an autograd function. This autograd function performs identity during forward but checks for NaNs during backward. The same applies to `EmbeddingBagCollection`.
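The identity-forward, check-in-backward pattern described above can be sketched roughly as follows. This is a minimal illustration and not the code in this diff: the names `NanCheckIdentity`, `run_backward`, and `table_0` are hypothetical, and a plain dense tensor stands in for the tensors inside a KeyedJaggedTensor.

```python
import torch


class NanCheckIdentity(torch.autograd.Function):
    """Hypothetical wrapper: identity in forward, NaN check in backward."""

    @staticmethod
    def forward(ctx, x, name):
        # Remember a label so the error message can point at the wrapped tensor.
        ctx.name = name
        # Return a view so the output is a distinct tensor with the same values.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Fail before the gradient is propagated further upstream.
        if torch.isnan(grad_output).any():
            raise RuntimeError(f"NaN detected in gradient flowing into '{ctx.name}'")
        # One gradient per forward input; the string label gets None.
        return grad_output, None


def run_backward(poison):
    # Stand-in for an embedding weight (assumption for illustration only).
    weight = torch.ones(3, requires_grad=True)
    out = NanCheckIdentity.apply(weight * 2.0, "table_0")
    # Multiplying by a NaN downstream makes the incoming gradient contain NaN.
    scale = torch.tensor([1.0, float("nan"), 1.0]) if poison else torch.ones(3)
    (out * scale).sum().backward()
    return weight.grad
```

Under these assumptions, `run_backward(False)` completes normally, while `run_backward(True)` raises from inside the wrapper's backward, before the NaN gradient reaches `weight`.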
This diff adds 3 tests alongside the debug embedding classes:
- `test_embedding`
- `test_embedding_bag`
- `test_model` (a reference DLRM model that uses `DebugEmbeddingCollectionClass`). The test adds NaN to the logits, which `DebugEmbeddingCollectionClass` then catches before we can do backward.
Addresses the issue which was previously seen in S542457 https://docs.google.com/presentation/d/1soiz7UxALa_hsgCnOEw_OL4yg8oK4z-VNEIVMB7_v7U/edit?slide=id.g37205c3166e_1_135#slide=id.g37205c3166e_1_135
Reviewed By: jeffkbkim
Differential Revision: D86233629
fbshipit-source-id: 4f620d84c90c01c045cc4b69e1c5564ed2839ff3
1 parent 8f3ed1a commit 6bc2d34
File tree
2 files changed (+629, -0 lines changed)
- torchrec
- distributed/tests
- modules