
Replace scipy's truncnorm.rvs with nn.init.trunc_normal_ to make model load faster #178

Open

rizsp wants to merge 2 commits into aqlaboratory:main from rizsp:feature/truncated-init-on-gpu

Conversation

rizsp (Contributor) commented Apr 13, 2026

Summary

Via line-by-line timing I noticed that the load_state_dict call inside _warn_on_missing_version_tensor_in_load_statedict() takes around 27 seconds on my machine on initial model load. Drilling further into the profile pointed to trunc_normal_init_, which initializes truncated normal random variables for Kaiming/Lecun initialization using scipy.stats.truncnorm.rvs.

I adjusted the code so that it:

  • Caches the truncnorm.std call, which is always made with the same arguments (an alternative would be a hard-coded magic number; that would also eliminate the scipy import, see code comment)
  • Replaces scipy.stats.truncnorm.rvs with nn.init.trunc_normal_, which should be much faster since it samples directly on the GPU instead of generating values on the CPU (expensive CDF calls) and then moving the random number array from CPU to GPU

On my machine this reduces the model load from 27 seconds to around 3 seconds (loading the file itself takes only a second).

Changes

  • Changes trunc_normal_init() to use nn.init.trunc_normal_ instead
  • Removes unused _prod

Related Issues

Testing

I didn't spot any related test covering this function?

Other Notes

rizsp (Contributor, Author) commented Apr 14, 2026

Sorry, I noticed that some tests covering lecun_normal_init are failing; I'll have to take a look.

jandom (Collaborator) commented Apr 14, 2026

Excellent contribution @rizsp – cherry on the cake: if you want, you could codify this with pytest-benchmark, i.e. test the time it takes to load a model. It'll catch any future regressions.

rizsp (Contributor, Author) commented Apr 14, 2026

Okay, the tests were failing (failed to open libnvrtc-builtins.so.13.0) because the previous way of initializing random numbers on the CPU with scipy didn't trigger the JIT. Once that part moved to the GPU, the JIT kicked in and my incomplete CUDA setup failed. After exporting CUTLASS and LD_LIBRARY_PATH for the right CUDA version, the tests passed again.

I can take a look into implementing a pytest benchmark.

rizsp (Contributor, Author) commented Apr 14, 2026

@jandom I added a test to check the timing for the state dict loading and set it to fail for timings greater than 10 seconds (excluding the time to load from disk). It fails on main (takes 21s) and succeeds on this PR (takes 2s).
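
For reference, the core of such a timing guard can be as simple as a wall-clock check around the load call (a sketch; the helper name and the budget_s parameter are made up here, not the test as committed):

```python
import time


def assert_within_budget(fn, budget_s=10.0):
    # Fail if fn() takes longer than budget_s wall-clock seconds.
    start = time.perf_counter()
    result = fn()
    elapsed = time.perf_counter() - start
    assert elapsed < budget_s, f"took {elapsed:.1f}s, budget was {budget_s:.0f}s"
    return result, elapsed
```

In the actual test, fn would wrap only the state-dict loading, so disk I/O stays excluded from the budget.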

jandom (Collaborator) commented Apr 14, 2026

This is excellent, thank you for resolving this so quickly (lightning speed!). My last worry would be "are we doing something to the weights by applying this procedure that we weren't doing before". Obviously we can't snapshot the entire 2.1 GB checkpoint with method A vs method B, but maybe there is something reasonable we can do here to confirm that we're not breaking anything and that the weights are indeed loaded the same, just faster.

rizsp (Contributor, Author) commented Apr 14, 2026

Since this is really just a tiny change to the code (swapping the random initialization function for the Kaiming/Lecun inits), it should be easy to reason about.

This patch replaces scipy's truncnorm.rvs with nn.init.trunc_normal_. You actually already use methods from nn.init in initialization.py, notably nn.init.kaiming_normal_ and nn.init.xavier_uniform_. I assume lecun_normal_init and he_normal_init were implemented from scratch with scipy because nn.init does not provide truncated versions of those.

As long as one can show that nn.init.trunc_normal_ produces the same distribution as truncnorm.rvs, the behavior should be the same. Maybe generating samples with both methods and comparing them with a two-sample Kolmogorov-Smirnov test is enough?

Here's just a sketch:

import math

import torch
from scipy.stats import ks_2samp, truncnorm
from torch import nn

std = math.sqrt(1.0 / 512) / 0.8796

# new nn.init method
torch_tensor = torch.empty(1024, 512)
nn.init.trunc_normal_(torch_tensor, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std)
torch_samples = torch_tensor.flatten().numpy()

# old scipy method (scipy's a/b are in units of scale, hence -2 and 2)
scipy_samples = truncnorm.rvs(a=-2, b=2, loc=0, scale=std, size=torch_samples.shape)

statistic, pvalue = ks_2samp(scipy_samples, torch_samples)

The pvalue in this snippet is 0.98, which is consistent with the null hypothesis that the two samples come from the same distribution.
Or visually, as a plot:

[plot comparing the scipy and nn.init sample distributions]
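
As an aside, the 0.8796 constant used above is presumably just the standard deviation of a standard normal truncated to [-2, 2], i.e. exactly what the cached truncnorm.std call computes:

```python
from scipy.stats import truncnorm

# Std of N(0, 1) truncated to [-2, 2]; this is the correction factor
# (~0.87962566) that the cached truncnorm.std call returns.
correction = truncnorm.std(a=-2.0, b=2.0, loc=0.0, scale=1.0)
print(round(correction, 4))  # 0.8796
```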

jandom (Collaborator) commented Apr 16, 2026

As long as one can show that nn.init.trunc_normal_ produces the same distribution as truncnorm.rvs, the behavior should be the same. Maybe generating samples with both methods and comparing them with a two-sample Kolmogorov-Smirnov test is enough?

This looks sane enough to me, thanks – I've tagged Christina for review, since she's more familiar with this

jandom added the safe-to-test label (internal-only label used to indicate PRs that are ready for automated CI testing) on Apr 16, 2026
jandom (Collaborator) commented Apr 16, 2026

@rizsp this is currently failing in tests – can you reproduce locally?

rizsp (Contributor, Author) commented Apr 16, 2026

Sorry, you're right. If I'm reading the logs correctly, this is because TriangleAttention is instantiated, which contains a Linear layer initialized with the Lecun init. Since that init is now done with nn.init instead of rvs, the random numbers differ. The ndarrays_regression.check then compares against a pre-generated array created with the previous rvs init, probably stored in test_data.

I'm not sure what the right procedure is, but if you're confident the init replacement is fine, someone on your side would probably just need to regenerate the array in test_data using pytest --force-regen.

