File tree Expand file tree Collapse file tree 1 file changed +2
-5
lines changed
sota-implementations/gail Expand file tree Collapse file tree 1 file changed +2
-5
lines changed Original file line number Diff line number Diff line change @@ -43,7 +43,7 @@ def main(cfg: "DictConfig"): # noqa: F821
4343 )
4444
4545 # Create logger
46- exp_name = generate_exp_name ("Gail-offline " , cfg .logger .exp_name )
46+ exp_name = generate_exp_name ("Gail" , cfg .logger .exp_name )
4747 logger = None
4848 if cfg .logger .backend :
4949 logger = get_logger (
@@ -170,11 +170,8 @@ def main(cfg: "DictConfig"): # noqa: F821
170170 with torch .no_grad ():
171171 data = discriminator (data )
172172 d_rewards = - torch .log (1 - data ["d_logits" ] + 1e-8 )
173- d_rewards = torch .log (data ["d_logits" ] + 1e-8 ) - torch .log (
174- 1 - data ["d_logits" ] + 1e-8
175- )
176173
177- # set d_rewards to tensordict
174+ # Set discriminator rewards to tensordict
178175 data .set (("next" , "reward" ), d_rewards )
179176
180177 # Get training rewards and episode lengths
You can’t perform that action at this time.
0 commit comments