Debug Meanflow #41

@shahbuland

Meanflow currently isn't working. The trainer is implemented and the config/data pipeline is set up.
See the branch in #34.
A training run is typically launched with: python -m train --config_path configs/feats_c128/cod_128x_feats_meanflow.yml
The trainer is functional and simple enough that the most likely place for errors is owl_vaes/models/meanflow.py.
We suspect the error is in the JVP, but we can't completely rule out the underlying model components or the trainer. The preferred fix is a change to the loss computation in the meanflow module alone, but further code changes are OK if they make meanflow work.

You can see an example run with the current trainer here:
https://wandb.ai/shahbuland/new_vaes/runs/pylp2x2t

Hints:

  1. I've been told the expected behaviour is for the loss to start around 2.0 and then decrease gently, which is not what it's doing right now.
  2. When I disabled the t != r case (forcing r = t), the loss curve behaved as expected but samples looked like noise (see https://wandb.ai/shahbuland/new_vaes/runs/mawybstp). That might just be because r is defaulting to 0, which confuses the model, so the bad samples might not be a useful piece of information. It does suggest that whatever the error is, it's localized to the t != r branch, where the JVP is used. Numbers blowing up usually points to precision errors, which is another reason my mind wanders to the JVP.
  3. The JVP leaves a lot of room for instability or other issues because it doesn't work with flash attention. The only way I got it to work was by casting to FP32 and then using the math backend for attention. I've been told the FP32 cast was unnecessary, which suggests that whatever is breaking the code might be related to why FP32 was needed. It might be worthwhile to disable the FP32 cast and use that as a starting point for debugging.
  4. There are flash attention kernels for JVP backward, but integrating them could be extra work: see "JVP support for FlashAttention", Dao-AILab/flash-attention#1672.
  5. There's some code that prints the min/mean/max of various values during the forward pass. Target u values are roughly in [-20, 20] when r = t, but more like [-500, 500] when r != t, and dudt is also suspiciously large, around [-1000, 1000].
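
Since hints 2 and 5 both point at the JVP, one cheap check is to validate the target construction on a toy problem where the true average velocity is known. The sketch below assumes the usual MeanFlow target, u_tgt = v - (t - r) * du/dt, where du/dt is the JVP of u(z, r, t) along the tangent (v, 0, 1). It is not the code in owl_vaes/models/meanflow.py: it is framework-free, uses dual numbers as a stand-in for forward-mode AD, and every name in it (Dual, A, velocity, avg_velocity, jvp_avg_velocity) is hypothetical. It also shows how much the target moves if the r and t tangents are swapped, which is one plausible way to end up with oversized dudt values.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """A value plus a tangent: forward-mode AD, which is what a JVP computes."""
    val: float
    tan: float = 0.0

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.val + other.val, self.tan + other.tan)
    __radd__ = __add__

    def __sub__(self, other):
        other = _lift(other)
        return Dual(self.val - other.val, self.tan - other.tan)

    def __rsub__(self, other):
        return _lift(other) - self

    def __mul__(self, other):
        other = _lift(other)
        return Dual(self.val * other.val,
                    self.tan * other.val + self.val * other.tan)
    __rmul__ = __mul__

    def __truediv__(self, other):
        other = _lift(other)
        return Dual(self.val / other.val,
                    (self.tan * other.val - self.val * other.tan) / other.val ** 2)


def _lift(x):
    """Promote plain numbers to Dual with a zero tangent."""
    return x if isinstance(x, Dual) else Dual(float(x))


def dexp(x):
    """exp for dual numbers."""
    e = math.exp(x.val)
    return Dual(e, e * x.tan)


A = 1.5  # hypothetical toy ODE: dz/dt = A * z, so z_t = z_0 * exp(A * t)


def velocity(z, t):
    """Instantaneous velocity v(z_t, t) of the toy ODE."""
    return A * z


def avg_velocity(z, r, t):
    """Exact average velocity (z_t - z_r) / (t - r) given z_t; needs r != t."""
    s = t - r
    return z * (1.0 - dexp(-A * s)) / s


def jvp_avg_velocity(z, r, t, tangents):
    """Return (u, du) for inputs (z, r, t) and tangents (dz, dr, dt)."""
    dz, dr, dt = tangents
    out = avg_velocity(Dual(z, dz), Dual(r, dr), Dual(t, dt))
    return out.val, out.tan


z, r, t = 0.8, 0.2, 0.9
v = velocity(z, t)

# Assumed-correct tangent: v for z, 0 for r, 1 for t.
u, dudt = jvp_avg_velocity(z, r, t, (v, 0.0, 1.0))
target = v - (t - r) * dudt

# Same computation with the r and t tangents swapped by mistake.
_, dudt_swapped = jvp_avg_velocity(z, r, t, (v, 1.0, 0.0))
target_swapped = v - (t - r) * dudt_swapped

print("true u:", u)
print("target (v, 0, 1):", target)
print("target (v, 1, 0):", target_swapped)
```

With the (v, 0, 1) tangent the target reproduces the true average velocity up to floating-point error; with the swapped tangent it does not. The same comparison could be run against the real module by replacing avg_velocity with the model and the dual numbers with the framework's JVP, keeping the argument order of (z, r, t) in mind.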

Note:
For dataset access, send me an email and I can get you Tigris access for the specific dataset. The model checkpoint is available from the same Tigris storage under model-checkpoints/vaes/rgbdepthflow/cod_128x_feats_160k_ema.pt or model-checkpoints/rgbdepthflow/cod_128x_feats_distillenc.pt

Expected Behaviour:
The trainer should run with the loss going down, and samples should start taking shape after 1k-2k steps. Feel free to use a smaller model to get it running locally, since "loss goes down" is easy to see after even 100 steps. If you are doing this for the bounty, we will verify by running your updated code with the original config and checking samples after ~5k steps. Here is a functional diffusion decoder training run for reference: https://wandb.ai/shahbuland/new_vaes/runs/7zf5dmtx
