-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathroll_out.py
51 lines (45 loc) · 2 KB
/
roll_out.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch.nn.functional as F
import copy
class Rollout(object):
    """Monte-Carlo rollout helper for policy-gradient sequence training.

    Keeps a slowly-blended deep copy of the generator (``rolling_model``)
    that completes partial sequences, so reward estimates stay a bit more
    random/stable than sampling from the live model directly.
    """

    def __init__(self, model, update_rate):
        """
        :param model: the model used to perform the rollouts.
        :param update_rate: interpolation factor used by :meth:`update_param`
            when blending the rolling copy toward the original model
            (in order to keep the sampled data a bit more random).
        """
        self.ori_model = model
        # Deep copy so rollouts never touch the live model's parameters.
        self.rolling_model = copy.deepcopy(model)
        self.update_rate = update_rate

    def get_reward(self, data, num, discriminator):
        """
        Estimate the reward of every action (prefix) of the sequences.

        :param data: input sequences for the discriminator,
            shape ``[batch_size, seq_len]``.
        :param num: number of Monte-Carlo sample rounds to average over
            (larger is more accurate but costs more time and computation).
        :param discriminator: model producing 2-class logits; the softmax
            probability of class 1 is used as the reward.
        :return: tensor of shape ``[batch_size, seq_len]`` holding the
            average reward for each time step.
        """
        seq_len = data.size(1)
        rewards = []
        for i in range(num):
            for j in range(1, seq_len + 1):
                # Complete the length-j prefix into a full sequence.
                # NOTE(review): partial_sample is assumed to return a
                # [batch_size, seq_len] tensor — confirm against the model.
                temp_data = self.rolling_model.partial_sample(seq_len, data[:, :j])
                pred_reward = F.softmax(discriminator(temp_data), 1)
                if i == 0:
                    # First round: create the per-step accumulator columns.
                    rewards.append(pred_reward[:, 1].unsqueeze(1))
                else:
                    rewards[j - 1] += pred_reward[:, 1].unsqueeze(1)
        return torch.cat(rewards, dim=1) / num  # [batch_size, seq_len]

    def update_param(self):
        """
        Blend the rolling model toward the original model's parameters.

        The rolling copy keeps ``update_rate`` of its own weights and takes
        ``1 - update_rate`` from the original model. The original model is
        left untouched.

        :raises ValueError: if the two models' parameter names diverge.
        """
        for (name1, param1), (name2, param2) in \
                zip(self.ori_model.named_parameters(), self.rolling_model.named_parameters()):
            if name1 != name2:
                raise ValueError("The models parameter has been change")
            # BUGFIX: write into the ROLLING model's parameters (param2),
            # not the original's (param1) — the old code corrupted the
            # original model on every call and never updated the copy.
            param2.data = self.update_rate * param2.data + (1 - self.update_rate) * param1.data