FSQ training collapse #226
-
I was wondering if there is a way to improve on the straight-through estimator by using the same tricks as in https://github.com/necla-ml/Diff-JPEG/blob/main/diff_jpeg/rounding.py, i.e. either soft rounding in the backward pass, or what they call polynomial rounding.
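For reference, a minimal sketch of the soft-rounding idea (hard round in the forward pass, gradient of a smooth surrogate in the backward pass). The surrogate and the `alpha` sharpness are illustrative assumptions, not the Diff-JPEG implementation:

```python
import math
import torch

def soft_round(x: torch.Tensor, alpha: float = 5.0) -> torch.Tensor:
    # Smooth approximation of round(); approaches hard rounding as alpha grows.
    m = torch.floor(x) + 0.5
    return m + 0.5 * torch.tanh(alpha * (x - m)) / math.tanh(alpha / 2)

def round_soft_ste(x: torch.Tensor, alpha: float = 5.0) -> torch.Tensor:
    # Forward pass returns the exact round; backward pass uses the gradient
    # of the soft surrogate instead of the straight-through identity.
    s = soft_round(x, alpha)
    return s + (torch.round(x) - s).detach()
```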
-
Right, even with the suggested weight initialization, I find that I get pretty poor codebook utilization, despite what the paper says.
I'm training on the LibriTTS dataset. Once trained, I plot a histogram of codebook usage over the validation set and get the following:
[histogram of codebook usage on the validation set]
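A sketch of how such a usage histogram can be computed; `plot_code_usage`, the flat tensor of FSQ indices `codes`, and `codebook_size` (the product of the FSQ levels) are assumed names, not part of the library:

```python
import torch
import matplotlib.pyplot as plt

def plot_code_usage(codes: torch.Tensor, codebook_size: int) -> float:
    # Count how often each code index appears over the validation set and
    # report the fraction of codes used at least once (the "utilization").
    counts = torch.bincount(codes.flatten(), minlength=codebook_size).float()
    utilization = (counts > 0).float().mean().item()
    plt.bar(range(codebook_size), counts.cpu().numpy())
    plt.xlabel("code index")
    plt.ylabel("count")
    plt.title(f"codebook utilization: {utilization:.1%}")
    plt.show()
    return utilization
```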
-
When training a neural audio codec model, e.g. SoundStream, EnCodec, TS3-Codec, etc., with (Grouped)(Residual)FSQ, training can completely collapse without careful weight initialization in the final projection layer of the encoder.
In all my tests, the loss gets stuck at some value and stays there for 500 epochs, even with learning-rate warmup and all the standard tricks. Looking at the quantized output of FSQ, all values are +/- 1, meaning the encoder outputs were too large and, once tanh-ed, yielded +/- 1. In this situation the gradients go nowhere, so everything gets pushed to the outer boundaries of the FSQ hypercube. The only way I get around that is if the projection layer in FSQ has a very small weight initialization and a zero-initialized bias. That keeps the encoder outputs small early on and allows gradients to flow.
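A tiny repro of that failure mode, with an assumed "too large" activation scale of +/- 10:

```python
import torch

# +/- 10 stands in for oversized encoder outputs; it is an assumed value,
# not one taken from an actual training run.
sign = torch.randint(0, 2, (4, 8)).float() * 2 - 1
z = (sign * 10.0).requires_grad_()                            # saturated pre-tanh activations
bounded = torch.tanh(z)                                       # FSQ-style bound to (-1, 1)
codes = bounded + (torch.round(bounded) - bounded).detach()   # straight-through round
codes.sum().backward()
print(codes.unique())       # tensor([-1., 1.]): every code sits on a corner
print(z.grad.abs().max())   # ~1e-8, because tanh'(z) = 1 - tanh(z)^2 has saturated
```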
So basically for this layer:
`vector-quantize-pytorch/vector_quantize_pytorch/finite_scalar_quantization.py`, line 109 at commit 2b367e5
I'm having to add something like:
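Roughly this kind of re-initialization; the std of 1e-2 is illustrative, and `project_in` as the name of the FSQ input projection is an assumption:

```python
import torch.nn as nn

def small_init_fsq_projection(proj: nn.Linear, std: float = 1e-2) -> None:
    # Small-gain weights and a zero bias keep the pre-tanh activations near
    # zero at the start of training, so codes are not pinned to the corners.
    nn.init.normal_(proj.weight, mean=0.0, std=std)
    if proj.bias is not None:
        nn.init.zeros_(proj.bias)

# Usage (assuming the FSQ input projection is an nn.Linear called `project_in`):
# small_init_fsq_projection(fsq.project_in)
```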
Has anybody observed behaviour like this?