Skip to content

Commit

Permalink
update comments
Browse files Browse the repository at this point in the history
  • Loading branch information
BY571 committed Jul 5, 2024
1 parent 2d31f33 commit 79bda13
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions sota-implementations/gail/gail.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create logger
exp_name = generate_exp_name("Gail-offline", cfg.logger.exp_name)
exp_name = generate_exp_name("Gail", cfg.logger.exp_name)
logger = None
if cfg.logger.backend:
logger = get_logger(
Expand Down Expand Up @@ -170,11 +170,8 @@ def main(cfg: "DictConfig"): # noqa: F821
with torch.no_grad():
data = discriminator(data)
d_rewards = -torch.log(1 - data["d_logits"] + 1e-8)
d_rewards = torch.log(data["d_logits"] + 1e-8) - torch.log(
1 - data["d_logits"] + 1e-8
)

# set d_rewards to tensordict
# Set discriminator rewards to tensordict
data.set(("next", "reward"), d_rewards)

# Get training rewards and episode lengths
Expand Down

0 comments on commit 79bda13

Please sign in to comment.