
Question about backprop #20

Open
edwhu opened this issue Oct 28, 2024 · 0 comments

edwhu commented Oct 28, 2024

Hi Nicklas,

I am looking into this codebase and the FOWM codebase, and I have a few questions.

Q1: Backprop through time?

First, in the TDMPC codebase, it looks like you detach all the zs in the latent rollout. Doesn't this stop backprop through time? If I compute a loss on Q(z_3), that loss will not propagate gradients back to z_2 and z_1.

Q1, Q2 = self.model.Q(z, action[t])
z, reward_pred = self.model.next(z, action[t])
with torch.no_grad():
    next_obs = self.aug(next_obses[t])
    next_z = self.model_target.h(next_obs)
    td_target = self._td_target(next_obs, reward[t])
zs.append(z.detach())  # <-- the detach I am asking about

In the FOWM codebase, all the predicted zs are stored in a list without detaching, and Q-values are then predicted for every one of them:
https://github.com/fyhMer/fowm/blob/bf7985876eb4baa827f41567cb6fb47b5da93ed4/src/algorithm/tdmpc.py#L283-L299
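
To make sure I am reading this correctly, here is a small toy example I put together (not code from either repo; dynamics and q_net are just stand-in linear modules) that shows the gradient-flow difference I mean between storing detached and attached latents:

import torch

dynamics = torch.nn.Linear(4, 4)   # stand-in for model.next
q_net = torch.nn.Linear(4, 1)      # stand-in for model.Q

z = torch.randn(1, 4)
z_det, z_att = z, z
zs_detached, zs_attached = [], []
for _ in range(3):
    z_det = dynamics(z_det)
    z_att = dynamics(z_att)
    zs_detached.append(z_det.detach())  # TDMPC-style: graph is cut at the stored copy
    zs_attached.append(z_att)           # FOWM-style: graph is kept

# Q loss on the last stored detached latent: no gradient ever reaches the dynamics
q_net(zs_detached[-1]).sum().backward()
print(dynamics.weight.grad)  # None

# Q loss on the last stored attached latent: gradient flows back through all 3 steps
q_net(zs_attached[-1]).sum().backward()
print(dynamics.weight.grad)  # now populated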

Q2: Why disable / re-enable tracking of Q gradients during update_pi?

Since pi_optim only updates the policy parameters, why do we need to disable and then re-enable the Q-network gradients? Does that somehow make things more efficient?

def update_pi(self, zs):
    """Update policy using a sequence of latent states."""
    self.pi_optim.zero_grad(set_to_none=True)
    self.model.track_q_grad(False)
    # Loss is a weighted sum of Q-values
    pi_loss = 0
    for t, z in enumerate(zs):
        a = self.model.pi(z, self.cfg.min_std)
        Q = torch.min(*self.model.Q(z, a))
        pi_loss += -Q.mean() * (self.cfg.rho ** t)
    pi_loss.backward()
    torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False)
    self.pi_optim.step()
    self.model.track_q_grad(True)
    return pi_loss.item()
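
For context, here is the behavior I suspect the toggle is avoiding, in a small standalone sketch (my own toy modules, not the repo's code; my understanding is that track_q_grad essentially flips requires_grad on the Q networks):

import torch

q_net = torch.nn.Linear(4, 1)   # stand-in for the Q function
policy = torch.nn.Linear(4, 4)  # stand-in for the policy
pi_optim = torch.optim.Adam(policy.parameters())  # only ever steps the policy

z = torch.randn(1, 4)

# Without disabling Q grads: pi_optim.step() only touches the policy, but
# backward() still computes and stores gradients in the Q parameters.
pi_loss = -q_net(policy(z)).mean()
pi_loss.backward()
print(q_net.weight.grad is None)  # False: Q grads were computed anyway

# With requires_grad turned off on Q (roughly what I think track_q_grad(False) does):
q_net.zero_grad(set_to_none=True)
for p in q_net.parameters():
    p.requires_grad_(False)
pi_loss = -q_net(policy(z)).mean()
pi_loss.backward()                # gradients still flow to the policy through Q
pi_optim.step()
print(q_net.weight.grad is None)  # True: nothing accumulates in Q
for p in q_net.parameters():
    p.requires_grad_(True)

My guess is that this just saves the compute and memory for Q gradients that pi_optim would never use, but I want to confirm.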
