diff --git a/src/pytti/ImageGuide.py b/src/pytti/ImageGuide.py index 5aff2b9..63ed79b 100644 --- a/src/pytti/ImageGuide.py +++ b/src/pytti/ImageGuide.py @@ -307,7 +307,10 @@ def train( total_loss_mb /= gradient_accumulation_steps # total_loss_mb.backward() - total_loss_mb.backward(retain_graph=True) + if gradient_accumulation_steps - mb_i > 1: + total_loss_mb.backward(retain_graph=True) + else: + total_loss_mb.backward() # total_loss += total_loss_mb # this is causing it to break # total_loss = total_loss_mb