From 1d6ca3de7840fee99b6702cd280bc4b53187d0f9 Mon Sep 17 00:00:00 2001 From: David Marx Date: Fri, 1 Apr 2022 19:04:50 -0700 Subject: [PATCH] backward not retain graph for last grad acc step --- src/pytti/ImageGuide.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/pytti/ImageGuide.py b/src/pytti/ImageGuide.py index 5aff2b9..63ed79b 100644 --- a/src/pytti/ImageGuide.py +++ b/src/pytti/ImageGuide.py @@ -307,7 +307,10 @@ def train( total_loss_mb /= gradient_accumulation_steps # total_loss_mb.backward() - total_loss_mb.backward(retain_graph=True) + if gradient_accumulation_steps - mb_i > 1: + total_loss_mb.backward(retain_graph=True) + else: + total_loss_mb.backward() # total_loss += total_loss_mb # this is causing it to break # total_loss = total_loss_mb