Channel specific normalization #66

@shahbuland

Based on other implementations out there, it seems beneficial to normalize each channel of the latent space separately. To this end, we should have some way of specifying VAE channel scales in the config as tuples. Everywhere that currently accepts a scalar should be modified to also accept a list of floats, one per channel, which is multiplied in channel-wise. This would take effect in two places: right before scaling latents for training, and when unscaling latents before decoding during sampling. It's probably best to make a helper function for this:

import torch

@torch.no_grad()
def channel_scaling(latent, scales):
    # latent: tensor of shape [b, n, c, h, w] or [b, c, h, w]
    # scales: list[float], one scale per channel
    ...
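
One way the stub above could be filled in is sketched below. This is only a sketch under stated assumptions, not the project's implementation: the `inverse` flag is hypothetical (the issue only says the scales are used "when unscaling", so dividing by the scales there is a guess), a bare scalar is treated as a single scale shared by all channels to stay compatible with existing configs, and the channel axis is assumed to be dim 1 for `[b,c,h,w]` and dim 2 for `[b,n,c,h,w]`.

```python
import torch


@torch.no_grad()
def channel_scaling(latent, scales, inverse=False):
    """Scale each channel of `latent` by the matching entry of `scales`.

    latent:  tensor of shape [b, c, h, w] or [b, n, c, h, w]
    scales:  float, or list/tuple of floats (one per channel)
    inverse: hypothetical flag; if True, divide instead of multiply
    """
    # Assumption: a plain scalar keeps working and applies to every channel.
    if isinstance(scales, (int, float)):
        scales = [float(scales)]

    # Assumption: the channel axis is dim 1 for 4D input and dim 2 for 5D input.
    if latent.ndim == 4:
        channel_dim = 1
    elif latent.ndim == 5:
        channel_dim = 2
    else:
        raise ValueError(f"expected a 4D or 5D latent, got {latent.ndim}D")

    n_channels = latent.shape[channel_dim]
    if len(scales) not in (1, n_channels):
        raise ValueError(
            f"got {len(scales)} scales for a latent with {n_channels} channels"
        )

    scale = torch.as_tensor(scales, dtype=latent.dtype, device=latent.device)

    # Reshape to e.g. [1, c, 1, 1] so the scales broadcast along the channel axis.
    shape = [1] * latent.ndim
    shape[channel_dim] = -1
    scale = scale.view(shape)

    return latent / scale if inverse else latent * scale
```

With something like this, the training path would call `channel_scaling(latent, scales)` and the sampling path would call `channel_scaling(latent, scales, inverse=True)` before decoding, again assuming that unscaling means dividing by the same per-channel values.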
