Skip to content

Commit 9948f26

Browse files
yasinzaiiCan-Zhaoericspod
authored
Fixed bug in KL_loss calculation for VAE validation step during training (#2000)
Fixes # . 1. Bug in KL_loss calculation for VAE validation step during training ### Description The KL_loss is included in the Validation loss calculation during VAE-training validation step. However, the correct z_mu, z_sigma are not passed. Either this val_epoch_losses["kl_loss"] should always be zero if that was the intension or the correct z_mu, z_sigma should be passed. ### Checks <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Avoid including large-size files in the PR. - [x] Clean up long text outputs from code cells in the notebook. - [x] For security purposes, please check the contents and remove any sensitive info such as user names and private key. --------- Signed-off-by: Muhammad Nabi Yasinzai <[email protected]> Co-authored-by: Can Zhao <[email protected]> Co-authored-by: Eric Kerfoot <[email protected]>
1 parent 77ccd31 commit 9948f26

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

generation/maisi/maisi_train_vae_tutorial.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,7 @@
692692
},
693693
{
694694
"cell_type": "code",
695-
"execution_count": 14,
695+
"execution_count": null,
696696
"id": "4c251a32-390f-46dd-a613-75b12a7884c1",
697697
"metadata": {
698698
"scrolled": true
@@ -850,7 +850,7 @@
850850
" with torch.no_grad():\n",
851851
" with autocast(\"cuda\", enabled=args.amp):\n",
852852
" images = batch[\"image\"]\n",
853-
" reconstruction, _, _ = dynamic_infer(val_inferer, autoencoder, images)\n",
853+
" reconstruction, z_mu, z_sigma = dynamic_infer(val_inferer, autoencoder, images)\n",
854854
" reconstruction = reconstruction.to(device)\n",
855855
" val_epoch_losses[\"recons_loss\"] += intensity_loss(reconstruction, images.to(device)).item()\n",
856856
" val_epoch_losses[\"kl_loss\"] += KL_loss(z_mu, z_sigma).item()\n",

0 commit comments

Comments
 (0)