Meanflow just isn't working currently. The trainer is implemented and the config/data pipeline is set up.
See branch in #34
Training run would typically be launched with python -m train --config_path configs/feats_c128/cod_128x_feats_meanflow.yml
The trainer is functional and simple enough that the most likely place for errors is owl_vaes/models/meanflow.py.
We suspect the error is in the JVP, but we can't completely rule out the underlying model components or the trainer. The preferred fix is a change to the loss computation in the meanflow module alone, but broader code changes are OK if they make meanflow work.
Hints:
I've been told that the expected behaviour is for the loss to start at around 2.0 and then go down gently, which is not what it's doing right now.
When I tried disabling the t ≠ r case, the loss curve behaved as expected but the samples looked like noise (see https://wandb.ai/shahbuland/new_vaes/runs/mawybstp). That might just be because r is defaulting to 0, which confuses the model, so the bad samples may not be a useful piece of information. It does suggest that whatever the error is, it's localized to the t ≠ r branch, where the JVP is used. Numbers blowing up is normally a sign of precision errors, which is another reason my mind wanders to the JVP.
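For reference, the reason the r = t case sidesteps the JVP can be read off the MeanFlow identity. This is a sketch that assumes meanflow.py follows the standard MeanFlow formulation (not verified against the branch); z_t, v, u_θ and sg (stop-gradient) are notation from that formulation, not names from the code:

```latex
\begin{aligned}
u_{\text{tgt}} &= v(z_t, t) - (t - r)\,\frac{d}{dt} u_\theta(z_t, r, t), \\
\frac{d}{dt} u_\theta(z_t, r, t) &= v(z_t, t)\,\partial_z u_\theta + \partial_t u_\theta
\quad \text{(a JVP with tangents } (v, 0, 1) \text{ for inputs } (z_t, r, t)\text{)}, \\
\mathcal{L} &= \big\lVert u_\theta(z_t, r, t) - \operatorname{sg}(u_{\text{tgt}}) \big\rVert^2 .
\end{aligned}
```

When r = t the factor (t - r) is zero, so the target reduces to v and the JVP output never enters the loss. Under this assumption, the things worth checking in the t ≠ r branch are the order of the tangents relative to the model's argument order, the sign of (t - r), and that the target is detached before the loss.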
The JVP leaves a lot of room for instability or other issues because it doesn't work with flash attention. The only way I got it to work was by casting to FP32 and then using the math backend for attention. I've been told that the FP32 cast was unnecessary, which suggests that whatever is breaking the code might be related to why FP32 was necessary. It might be worthwhile to disable the FP32 cast and use that as a starting point for debugging.
There's some code to print the min/mean/max of different things during the forward pass. The target u values are roughly in [-20, 20] when r = t, but more like [-500, 500] when t ≠ r, and dudt is also suspiciously large, around [-1000, 1000].
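One way to tell a wrong JVP from a genuinely large derivative is to compare the JVP against a finite difference along the same direction. The sketch below does this on a toy network; `ToyU`, `stats` and `meanflow_target` are hypothetical names for illustration, and the target formula assumes the standard MeanFlow formulation rather than whatever meanflow.py currently does:

```python
import torch
from torch.func import jvp

torch.manual_seed(0)


class ToyU(torch.nn.Module):
    """Hypothetical stand-in for the meanflow network u(z, r, t)."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim + 2, 32),
            torch.nn.SiLU(),
            torch.nn.Linear(32, dim),
        )

    def forward(self, z, r, t):
        # Condition on r and t by concatenating them to the input features.
        return self.net(torch.cat([z, r[:, None], t[:, None]], dim=-1))


def stats(name, x):
    """Print and return (min, mean, max) of a tensor."""
    out = (x.min().item(), x.mean().item(), x.max().item())
    print(f"{name}: min={out[0]:.4f} mean={out[1]:.4f} max={out[2]:.4f}")
    return out


def meanflow_target(model, z, v, r, t):
    # Tangents follow the primal order (z, r, t): dz/dt = v, dr/dt = 0, dt/dt = 1.
    tangents = (v, torch.zeros_like(r), torch.ones_like(t))
    u, dudt = jvp(model, (z, r, t), tangents)
    # Stop-gradient on the target, as in the assumed MeanFlow loss.
    target = (v - (t - r)[:, None] * dudt).detach()
    return u, dudt, target


model = ToyU().double()  # float64 so the finite-difference check is tight
z = torch.randn(4, 8, dtype=torch.float64)
v = torch.randn(4, 8, dtype=torch.float64)
t = torch.rand(4, dtype=torch.float64)
r = t * torch.rand(4, dtype=torch.float64)  # guarantees r <= t

u, dudt, target = meanflow_target(model, z, v, r, t)

# Central finite difference of u along the same direction (v, 0, 1).
eps = 1e-5
with torch.no_grad():
    fd = (model(z + eps * v, r, t + eps) - model(z - eps * v, r, t - eps)) / (2 * eps)

max_err = (dudt - fd).abs().max().item()
stats("dudt", dudt)
stats("target", target)
print("max |jvp - finite difference|:", max_err)

# With r = t the JVP term is multiplied by zero and the target is just v.
_, _, target_rt = meanflow_target(model, z, v, t, t)
```

The same comparison could be run against the real model on a small batch. If dudt from the JVP disagrees with the finite difference, the JVP call (tangents, argument order, attention backend) is the likely culprit. If they agree and dudt is still around [-1000, 1000], the large derivative is a property of the network itself, for example of how r and t are embedded. Since |t - r| is at most 1, a dudt of that size would by itself account for targets in the hundreds.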
Note:
For dataset access, send me an email and I can get you Tigris access for the specific dataset. The model checkpoint is available from the same Tigris under model-checkpoints/vaes/rgbdepthflow/cod_128x_feats_160k_ema.pt or model-checkpoints/rgbdepthflow/cod_128x_feats_distillenc.pt
Expected Behaviour:
The trainer should run with the loss going down. Samples should start taking shape after 1k-2k steps. Feel free to run a smaller model to get it running locally, because "loss goes down" is pretty simple to see after even 100 steps. If you are doing this for the bounty, we will verify simply by running your updated code with the original config and checking samples after ~5k steps. Here is a functional diffusion decoder training run for reference: https://wandb.ai/shahbuland/new_vaes/runs/7zf5dmtx
You can see an example run with the current trainer here:
https://wandb.ai/shahbuland/new_vaes/runs/pylp2x2t