Skip to content

Commit

Permalink
fix reward
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-leonard committed Aug 4, 2016
1 parent 1afea7e commit fa73097
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
4 changes: 1 addition & 3 deletions BinaryClassReward.lua
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function BinaryClassReward:updateOutput(input, target)
end

-- reward = scale when correctly classified
self._reward = input.new()
self._reward = self._reward or input.new()
self._reward:eq(self._binary, target)
self.reward = self.reward or input.new()
self.reward:resize(self._reward:size(1)):copy(self._reward)
Expand Down Expand Up @@ -64,11 +64,9 @@ function BinaryClassReward:updateGradInput(inputTable, target)

-- zero gradInput (this criterion has no gradInput for class pred)
self.gradInput[1]:resizeAs(input):zero()
self.gradInput[1] = self.gradInput[1]

-- learn the baseline reward
self.gradInput[2] = self.criterion:backward(baseline, self.reward)
self.gradInput[2] = self.gradInput[2]

return self.gradInput
end
Expand Down
2 changes: 1 addition & 1 deletion VRClassReward.lua
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ function VRClassReward:updateOutput(input, target)
end

-- reward = scale when correctly classified
self._reward = self._maxIdx.new()
self._reward = self._reward or self._maxIdx.new()
self._reward:eq(self._maxIdx, target)
self.reward = self.reward or input.new()
self.reward:resize(self._reward:size(1)):copy(self._reward)
Expand Down

0 comments on commit fa73097

Please sign in to comment.