Based on other implementations out there, it seems beneficial to normalize each channel of the latent space separately. To this end, we should have some way of specifying per-channel VAE scales in the config as tuples. Everywhere that currently accepts a scalar scale should be modified to also accept a list of floats, one per channel, so that each channel is multiplied by its own factor. This would apply at the same points where the scalar is used now: when scaling latents for training, and when unscaling them before decoding during sampling. It's probably best to make a helper function for this:
import torch

@torch.no_grad()
def channel_scaling(latent, scales):
    # latent is [b,n,c,h,w] or [b,c,h,w]
    # scales is list[float]
    ...
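One possible body for the helper is sketched below. It is a minimal sketch, not a settled design. It assumes three things: the channel axis is third from the end in both layouts listed in the stub's comments, the existing code normalizes by multiplying by the scale, and it undoes that by dividing. `channel_unscaling` is a hypothetical companion name for the sampling side. A plain float still works, so existing scalar configs would keep behaving as before.

```python
import torch
from typing import Sequence, Union

# A config value may be a single float or a tuple/list with one float per channel.
Scale = Union[float, Sequence[float]]


def _scale_tensor(latent: torch.Tensor, scales: Sequence[float]) -> torch.Tensor:
    """Turn per-channel scales into a [c, 1, 1] tensor that broadcasts over latent."""
    if latent.dim() not in (4, 5):
        raise ValueError(f"expected a 4D or 5D latent, got {latent.dim()}D")
    c = latent.shape[-3]  # channel axis: dim 1 of [b,c,h,w], dim 2 of [b,n,c,h,w]
    if len(scales) != c:
        raise ValueError(f"got {len(scales)} scales for {c} latent channels")
    s = torch.as_tensor(scales, dtype=latent.dtype, device=latent.device)
    # Trailing singleton dims let [c, 1, 1] broadcast over h and w (and b, n in front).
    return s.view(c, 1, 1)


@torch.no_grad()
def channel_scaling(latent: torch.Tensor, scales: Scale) -> torch.Tensor:
    """Multiply each latent channel by its scale (training side, after VAE encode)."""
    if isinstance(scales, (int, float)):
        return latent * scales  # old scalar behaviour, unchanged
    return latent * _scale_tensor(latent, scales)


@torch.no_grad()
def channel_unscaling(latent: torch.Tensor, scales: Scale) -> torch.Tensor:
    """Hypothetical inverse: divide each channel by its scale before VAE decode."""
    if isinstance(scales, (int, float)):
        return latent / scales
    return latent / _scale_tensor(latent, scales)
```

During training, `channel_scaling(latent, scales)` would replace the current scalar multiply right after encoding. During sampling, `channel_unscaling(latent, scales)` would run just before the VAE decode. Validating the number of scales against the channel dimension catches a config whose tuple length does not match the VAE instead of silently broadcasting wrong.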