Skip to content

Commit 79bda13

Browse files
committed
update comments
1 parent 2d31f33 commit 79bda13

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

sota-implementations/gail/gail.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def main(cfg: "DictConfig"): # noqa: F821
4343
)
4444

4545
# Create logger
46-
exp_name = generate_exp_name("Gail-offline", cfg.logger.exp_name)
46+
exp_name = generate_exp_name("Gail", cfg.logger.exp_name)
4747
logger = None
4848
if cfg.logger.backend:
4949
logger = get_logger(
@@ -170,11 +170,8 @@ def main(cfg: "DictConfig"): # noqa: F821
170170
with torch.no_grad():
171171
data = discriminator(data)
172172
d_rewards = -torch.log(1 - data["d_logits"] + 1e-8)
173-
d_rewards = torch.log(data["d_logits"] + 1e-8) - torch.log(
174-
1 - data["d_logits"] + 1e-8
175-
)
176173

177-
# set d_rewards to tensordict
174+
# Set discriminator rewards to tensordict
178175
data.set(("next", "reward"), d_rewards)
179176

180177
# Get training rewards and episode lengths

0 commit comments

Comments
 (0)