
Loss Calculation Issue with Incomplete Batches in Autoencoder Training #3

Open
Malengu opened this issue Dec 21, 2024 · 4 comments

Malengu commented Dec 21, 2024

Hi ExplainingAI,

Thank you for sharing these amazing projects; they have been incredibly helpful and inspiring. I noticed a potential issue related to gradient accumulation when training autoencoders. Specifically, the calculated losses at each epoch are accurate when total_batch_size = autoencoder_batch_size * autoencoder_acc_steps divides the dataset evenly. However, when it does not (e.g., on GPUs that could accommodate larger batch sizes) and the last batch contains fewer samples, the calculated losses seem to be incorrect. Does this implementation assume that total_batch_size always fits perfectly?

I would greatly appreciate your guidance on this matter. Thank you in advance for your time!

Malengu closed this as completed Dec 30, 2024
Malengu reopened this Dec 30, 2024
Repository owner deleted a comment from tusharkumar91 Dec 30, 2024
@explainingai-code
Owner

Hello @Malengu ,
I am not sure that I understand the precise calculation error that you are mentioning here.
Assuming batch_size=4, acc_steps=2, and a dataset of size 9, could you please be a bit more specific about where the error occurs?
And does the error occur ONLY in the last batch (if it contains fewer than batch_size examples)?

@Malengu
Author

Malengu commented Jan 3, 2025

I am trying to point out the following:
The way I understand it, our GPU has limited memory, hence these hyperparameter settings:
autoencoder_batch_size = 3
autoencoder_acc_steps = 2

Our desired batch size (total batch size) is actually:
desired_batch_size = autoencoder_batch_size * autoencoder_acc_steps = 3 * 2 = 6

Next, we move to the loss calculations, which should match the case where we could afford the desired batch size. However, we are using autoencoder_batch_size and autoencoder_acc_steps instead.

Suppose our dataset contains these images:
X = [x1, x2, x3, ..., x13, x14, x15]
We only have a single epoch.

Case 1: [desired batch size with drop_last=False]
Let l1, l2, ..., l14, l15 be the MSE loss for each sample x1, x2, ..., x14, x15 respectively.
Let L1, L2, ..., LN be the batch losses in our single epoch.

for one epoch:
L1 = 1/6 * [l1 + l2 + ... + l6]
L2 = 1/6 * [l7 + l8 + ... + l12]
L3 = 1/3 * [l13 + l14 + l15]

for simplicity, let:
Y = [l1 + l2 + ... + l6], Z = [l7 + l8 + ... + l12], Q = [l13 + l14 + l15]
Epoch loss: 1/3 * [L1 + L2 + L3] = 1/3 * (1/6Y + 1/6Z + 1/3Q) = 1/18(Y + Z + 2Q)

Case 2: [our case (gradient accumulation)]
The same notation applies below.

Note: here I am not scaling by (1/autoencoder_acc_steps); I am showing the losses you are appending to the losses list.

for one epoch:
L1 = 1/3 * [l1 + l2 + l3]
L2 = 1/3 * [l4 + l5 + l6]
L3 = 1/3 * [l7 + l8 + l9]
L4 = 1/3 * [l10 + l11 + l12]
L5 = 1/3 * [l13 + l14 + l15]

Since you are appending each batch loss without autoencoder_acc_steps scaling:
Epoch loss = 1/5 * [L1 + L2 + ... + L5] = 1/5 * 1/3 * (Y + Z + Q) = 1/15 * (Y + Z + Q)

Observations

  1. The logged losses will be off: Case 1 and Case 2 should match, but they do not.
  2. Another issue is that the scaling factor for the last, smaller batch is 1/3 in Case 1, whereas each of L1, L2, ..., L5 in Case 2 is effectively scaled by 1/6 once the 1/autoencoder_acc_steps factor is applied. This would not be an issue if every batch in Case 1 had the desired size.
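To make the mismatch concrete, here is a small self-contained sketch (the per-sample losses are just random numbers, not values from the repo) that computes the logged epoch loss both ways for 15 samples, a micro-batch size of 3, and a desired batch size of 6:

    import numpy as np

    rng = np.random.default_rng(0)
    per_sample = rng.random(15)       # stand-ins for the per-sample MSE values l1 ... l15

    # Case 1: desired batch size 6 with drop_last=False -> batches of 6, 6, 3
    case1_batches = [per_sample[0:6], per_sample[6:12], per_sample[12:15]]
    case1_epoch = np.mean([b.mean() for b in case1_batches])   # equals (1/18)(Y + Z + 2Q)

    # Case 2: micro-batches of 3, batch losses logged without acc_steps scaling
    case2_batches = [per_sample[i:i + 3] for i in range(0, 15, 3)]
    case2_epoch = np.mean([b.mean() for b in case2_batches])   # equals (1/15)(Y + Z + Q)

    print(case1_epoch, case2_epoch)   # the two logged epoch losses differ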

@explainingai-code
Owner

Thank you @Malengu for the detailed explanation of the issue.
Regarding the logged values being without scaling, that's actually intentional, primarily because I wanted the per-step logged loss values of different runs to be comparable to each other (in magnitude), no matter what acc_steps is used for each run. That's why the logged value is just the mean MSE for the batch, whether acc_steps is 1 or > 1.
Obviously, if you would prefer to have the loss values logged with scaling, you can just multiply the logged values accordingly.
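For instance (a minimal illustration only; the names below are stand-ins, not necessarily the exact identifiers in the repo):

    acc_steps = 2                              # illustrative accumulation steps
    losses = []
    batch_mean_mse = 0.3                       # stand-in for the mean MSE of one micro-batch

    losses.append(batch_mean_mse)              # unscaled: what is currently logged
    losses.append(batch_mean_mse / acc_steps)  # scaled: matches the value used for the accumulated gradient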

Regarding the second aspect, that the scaling factor of the last batch would be smaller: yes, I think the repo (unintentionally) assumes the effective batch size fits perfectly.
Though this should not cause any significant issues, it is incorrect for the implementation to make this assumption.
Unfortunately, at the moment I don't have the time to fix this in the cleanest manner and ensure that it works for all cases. Sorry about that.

However, in case you want to fix it for your own training, I think the steps mentioned below should make things better. The changes are for DiT training, but the same applies to VAE training as well.

  1. Add drop_last=True to the dataloader:

    data_loader = DataLoader(im_dataset,
                             batch_size=train_config['dit_batch_size'],
                             shuffle=True,
                             drop_last=True)

  2. Move the step_count initialization outside the epoch loop, so that it is initialized to zero only once at the start of training:

    step_count = 0

  3. Remove the optimizer.step() and optimizer.zero_grad() calls at the end of each epoch:

    optimizer.step()
    optimizer.zero_grad()

This basically ensures that we always accumulate gradients over exactly acc_steps batches. At the end of an epoch, when we have only accumulated gradients for steps < acc_steps batches, we carry them over and wait until gradients for the remaining (acc_steps - steps) batches have also been accumulated in the next epoch. Only then do we update the parameters and zero out the gradients.
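Putting the three changes together, a minimal sketch of the resulting loop could look like the following (the toy dataset, model, loss, and config values are placeholders for illustration, not the actual objects from the repo):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    # Toy stand-ins for the repo's dataset/model/config
    im_dataset = TensorDataset(torch.randn(15, 8))
    model = nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    train_config = {'dit_batch_size': 3, 'num_epochs': 2}
    acc_steps = 2

    data_loader = DataLoader(im_dataset,
                             batch_size=train_config['dit_batch_size'],
                             shuffle=True,
                             drop_last=True)   # change 1: drop the incomplete last batch

    optimizer.zero_grad()
    step_count = 0                             # change 2: initialized once, before the epoch loop

    for epoch in range(train_config['num_epochs']):
        for (im,) in data_loader:
            loss = criterion(model(im), im)    # autoencoder-style reconstruction loss
            (loss / acc_steps).backward()      # scale so the accumulated gradient is a proper mean
            step_count += 1
            if step_count % acc_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
        # change 3: no optimizer.step()/zero_grad() here; partially accumulated
        # gradients carry over and finish accumulating in the next epoch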

Do let me know if you see any issues with these changes.

@Malengu
Author

Malengu commented Jan 6, 2025

Thank you for replying. I believe the code implementation is correct, but the issue arose when the effective batch size didn’t align properly. That’s when I noticed the problem. However, I think this wouldn’t affect the optimization since we are primarily interested in the argmin of the loss function.
