diff --git a/MultiCudaTensor.lua b/MultiCudaTensor.lua index 6ff5ee3..202dc79 100644 --- a/MultiCudaTensor.lua +++ b/MultiCudaTensor.lua @@ -31,6 +31,10 @@ function MCT:size(dim) return torch.LongStorage(self._size) end +function MCT:dim() + return self:size():size() +end + function MCT:zero() for i,tensor in ipairs(self.tensors) do cutorch.withDevice(tensor:getDevice(), function() @@ -277,3 +281,22 @@ function MCT:clone() f:close() return clone end + +function MCT:uniform(lower, upper) + for i,tensor in ipairs(self.tensors) do + cutorch.withDevice(tensor:getDevice(), function() tensor:uniform(lower, upper) end) + end + return self +end + +-- math.pow(gradParam:norm(),2) +function MCT:norm(...) + assert(#{...} == 0) + local norm = 0 + for i,tensor in ipairs(self.tensors) do + norm = norm + cutorch.withDevice(tensor:getDevice(), function() return math.pow(tensor:norm(),2) end) + end + return math.sqrt(norm) +end + +assert(not MCT.storage, "If you ever define storage, you will need to modify Module.sharedClone in dpnn.Module") diff --git a/test/test.lua b/test/test.lua index 036cdeb..65d9f1c 100644 --- a/test/test.lua +++ b/test/test.lua @@ -156,6 +156,9 @@ function torchxtest.MultiCudaTensor() mytester:assert(mweight:size(1) == inputsize) mytester:assert(mweight:size(2) == outputsize) + -- test dim + mytester:assert(mweight:dim() == 2) + -- test transpose local mwt = mweight:t() mytester:assert(mwt.catdim == 1) @@ -210,6 +213,11 @@ function torchxtest.MultiCudaTensor() mytester:assertTensorEq(output, output2, 0.0001) + -- test norm + local norm = mweight:norm() + local norm2 = weight:norm() + mytester:assert(math.abs(norm - norm2) < 0.0001) + -- test zero mweight:zero() for i=1,2 do @@ -244,7 +252,13 @@ function torchxtest.MultiCudaTensor() cutorch.withDevice(1, function() mytester:assertTensorEq(mw2.tensors[1], mweight.tensors[1], 0.000001) end) cutorch.withDevice(2, function() mytester:assertTensorEq(mw2.tensors[2], mweight.tensors[2], 0.000001) end) + -- test uniform + mw2:uniform(-2, -1) + cutorch.withDevice(1, function() mytester:assert(mw2.tensors[1]:min() >= -2 and mw2.tensors[1]:max() <= -1) end) + cutorch.withDevice(2, function() mytester:assert(mw2.tensors[2]:min() >= -2 and mw2.tensors[2]:max() <= -1) end) + mytester:assert(cutorch.getDevice() == origdevice) + end function torchx.test(tests)