diff --git a/treemax.lua b/treemax.lua index 18fc6a6..8ca9b9f 100644 --- a/treemax.lua +++ b/treemax.lua @@ -7,11 +7,17 @@ function torch.treemax(tensor, treeSize) tmb = { mean = tensor.new(), max = tensor.new(), - idx = torch.LongTensor() + idx = torch.LongTensor(), + copy = tensor.new() } treeMaxBuffer[torch.type(tensor)] = tmb end + if not tensor:isContiguous() then + tmb.copy:resizeAs(tensor):copy(tensor) + tensor = tmb.copy + end + local lvl = tensor local maxIdx, maxVal = 1, 0 for i=1,#treeSize do