From d574a10f0adb2d7649ec6449364ceaa60de3a395 Mon Sep 17 00:00:00 2001 From: Iz Beltagy Date: Thu, 30 Apr 2020 04:29:43 -0700 Subject: [PATCH] create mask in the correct device --- longformer/diagonaled_mm_tvm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/longformer/diagonaled_mm_tvm.py b/longformer/diagonaled_mm_tvm.py index ce3a260..7ffb1cb 100644 --- a/longformer/diagonaled_mm_tvm.py +++ b/longformer/diagonaled_mm_tvm.py @@ -291,7 +291,7 @@ def _get_invalid_locations_mask_fixed_dilation(seq_len: int, w: int, d: int): diagonal_mask = torch.zeros(seq_len, device='cpu', dtype=torch.uint8) diagonal_mask[:-j] = 1 diagonals_list.append(diagonal_mask) - return torch.stack(diagonals_list, dim=-1).cuda() + return torch.stack(diagonals_list, dim=-1) @lru_cache() def _get_invalid_locations_mask(w: int, d: Union[torch.Tensor,int], autoregressive: bool, device: str):