File tree Expand file tree Collapse file tree 1 file changed +2
-5
lines changed
sota-implementations/gail Expand file tree Collapse file tree 1 file changed +2
-5
lines changed Original file line number Diff line number Diff line change @@ -43,7 +43,7 @@ def main(cfg: "DictConfig"): # noqa: F821
43
43
)
44
44
45
45
# Create logger
46
- exp_name = generate_exp_name ("Gail-offline " , cfg .logger .exp_name )
46
+ exp_name = generate_exp_name ("Gail" , cfg .logger .exp_name )
47
47
logger = None
48
48
if cfg .logger .backend :
49
49
logger = get_logger (
@@ -170,11 +170,8 @@ def main(cfg: "DictConfig"): # noqa: F821
170
170
with torch .no_grad ():
171
171
data = discriminator (data )
172
172
d_rewards = - torch .log (1 - data ["d_logits" ] + 1e-8 )
173
- d_rewards = torch .log (data ["d_logits" ] + 1e-8 ) - torch .log (
174
- 1 - data ["d_logits" ] + 1e-8
175
- )
176
173
177
- # set d_rewards to tensordict
174
+ # Set discriminator rewards to tensordict
178
175
data .set (("next" , "reward" ), d_rewards )
179
176
180
177
# Get training rewards and episode lengths
You can’t perform that action at this time.
0 commit comments