Replace scipy's truncnorm.rvs with nn.init.trunc_normal_ to make model load faster #178
rizsp wants to merge 2 commits into aqlaboratory:main
Conversation
Sorry, I noticed some tests are failing which cover this change.
Excellent contribution @rizsp – cherry on the cake: if you want, you could codify this with pytest-benchmark, i.e. test the time it takes to load a model. It'll catch any future regressions.
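A minimal sketch of that idea, assuming pytest-benchmark is installed; load_checkpoint and the path are placeholders, not the repo's actual API:

```python
import torch

def load_checkpoint(path: str) -> dict:
    # Placeholder for the project's real load routine.
    return torch.load(path, map_location="cpu")

def test_checkpoint_load_benchmark(benchmark):
    # The `benchmark` fixture (provided by pytest-benchmark) calls the
    # function repeatedly, records timing statistics, and can flag
    # regressions across runs via --benchmark-compare.
    state_dict = benchmark(load_checkpoint, "params/model.pt")
    assert state_dict is not None
```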
Okay, the tests were failing (failed to open libnvrtc-builtins.so.13.0) because the previous way of initializing random numbers on the CPU with scipy didn't trigger the JIT. Once that part was changed to the GPU, the JIT kicked in and my incomplete CUDA setup failed. After I pointed CUTLASS and LD_LIBRARY_PATH at the right CUDA version, the tests passed again. I can take a look at implementing a pytest benchmark.
@jandom I added a test to check the timing for the state dict loading and set it to fail for timings greater than 10 seconds (excluding the time to load from disk). It fails on main (takes 21s) and succeeds on this PR (takes 2s).
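Presumably the added test looks roughly like this sketch, where build_model and the checkpoint path are hypothetical placeholders:

```python
import time
import torch

def build_model() -> torch.nn.Module:
    # Hypothetical stand-in for the real model constructor.
    return torch.nn.Linear(8, 8)

def test_state_dict_load_time():
    # Disk I/O stays outside the timed region, as described above.
    state_dict = torch.load("checkpoint.pt", map_location="cpu")
    model = build_model()

    start = time.perf_counter()
    model.load_state_dict(state_dict)
    elapsed = time.perf_counter() - start

    # ~21 s on main, ~2 s on this branch.
    assert elapsed < 10.0, f"load_state_dict took {elapsed:.1f}s (limit: 10s)"
```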
This is excellent, thank you for resolving this so quickly (lightning speed!). My last worry would be: are we doing something to the weights by applying this procedure that we weren't doing before? Obviously we can't snapshot the entire checkpoint (2.1 GB) with method A vs method B, but maybe there is something reasonable we can do to confirm that we're not breaking anything and that the weights are indeed loaded the same, just faster.
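One reasonable check (an assumption on my part, not something already in the PR): since the two samplers consume different RNG streams, an element-wise comparison is impossible, but their distributions can be compared statistically. A minimal sketch:

```python
import torch
from scipy.stats import truncnorm

n, std = 1_000_000, 0.02

# Old path: scipy samples on the CPU; a/b are in standard-normal units.
old = truncnorm.rvs(a=-2, b=2, loc=0.0, scale=std, size=n)

# New path: torch samples in place; a/b are absolute bounds, hence the * std.
new = torch.empty(n)
torch.nn.init.trunc_normal_(new, mean=0.0, std=std, a=-2 * std, b=2 * std)
new = new.numpy()

for name, s in (("scipy", old), ("torch", new)):
    print(f"{name}: mean={s.mean():+.2e} std={s.std():.2e} "
          f"min={s.min():+.2e} max={s.max():+.2e}")
# Both should report mean ~0, matching std, and support within [-2*std, 2*std].
```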
This looks sane enough to me, thanks – I've tagged Christina for review, since she's more familiar with this.
@rizsp this is currently failing in tests – can you reproduce locally?
Sorry, you're right. If I'm reading the logs correctly, this is because TriangleAttention is instantiated, which contains a Linear layer that is initialized with lecun. Since the lecun init is now done using nn.init instead of rvs, the random numbers are now different. The ndarrays_regression.check compares against a pre-initialized array created with the previous rvs init. I'm not sure what the right procedure is, but if you're sure that the init replacement is fine, someone on your side would probably just need to update the stored array.
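For reference, a sketch of how such a check typically looks with pytest-regressions (illustrative only, not the repo's actual test). If the fixtures come from pytest-regressions, rerunning the affected tests with the --force-regen flag rewrites the stored arrays to match the new nn.init-based draws:

```python
import torch

def test_lecun_init_regression(ndarrays_regression):
    # ndarrays_regression is provided by pytest-regressions; .check() compares
    # against (or, with --force-regen, rewrites) a stored .npz fixture.
    torch.manual_seed(42)
    weights = torch.empty(16, 16)
    torch.nn.init.trunc_normal_(weights, mean=0.0, std=1.0, a=-2.0, b=2.0)
    ndarrays_regression.check({"weights": weights.numpy()})
```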

Summary
Via line-by-line timing I noticed that the load_state_dict inside _warn_on_missing_version_tensor_in_load_statedict() takes up around 27 seconds on my machine on initial model load. Further clicking around in the profile pointed to trunc_normal_init_, which initializes truncated random variables according to kaiming or lecun using scipy.stats.truncnorm.rvs. I adjusted the code so that it:
- caches the truncnorm.std call, which always uses the same arguments (an alternative would be a magic number, which would also eliminate the scipy import; see code comment)
- replaces scipy.stats.truncnorm.rvs with nn.init.trunc_normal_ (sketched below), which should be much faster since it initializes directly on the GPU instead of having to generate values on the CPU (expensive CDF calls) and then move the random number array from CPU to GPU

On my machine this reduces the model load from 27 seconds to around 3 seconds (the file loading itself only takes a second).
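A minimal sketch of the replacement described above, assuming an OpenFold-style trunc_normal_init_; treat the names and details as illustrative rather than the exact diff:

```python
import math
import torch
from torch import nn

# Std of a standard normal truncated to [-2, 2]; hard-coding it is the
# "magic number" alternative: it removes both the repeated truncnorm.std
# call and the scipy import.
TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978

def trunc_normal_init_(weights: torch.Tensor, scale: float = 1.0) -> None:
    # Lecun-style scaling: variance proportional to 1 / fan_in.
    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
    std = math.sqrt(scale / max(1, fan_in)) / TRUNCATED_NORMAL_STDDEV_FACTOR
    # nn.init.trunc_normal_ samples in place on the tensor's own device,
    # avoiding scipy's CPU-side CDF calls and a CPU-to-GPU copy.
    nn.init.trunc_normal_(weights, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std)
```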
Changes
- Changed trunc_normal_init() to use nn.init.trunc_normal_ instead
- Removed the now-unused _prod helper

Related Issues
Testing
I didn't spot any related test covering this function?
Other Notes