Skip to content

Commit

Permalink
create mask in the correct device
Browse files Browse the repository at this point in the history
  • Loading branch information
ibeltagy committed Apr 30, 2020
1 parent 89e3980 commit d574a10
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion longformer/diagonaled_mm_tvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def _get_invalid_locations_mask_fixed_dilation(seq_len: int, w: int, d: int):
diagonal_mask = torch.zeros(seq_len, device='cpu', dtype=torch.uint8)
diagonal_mask[:-j] = 1
diagonals_list.append(diagonal_mask)
return torch.stack(diagonals_list, dim=-1).cuda()
return torch.stack(diagonals_list, dim=-1)

@lru_cache()
def _get_invalid_locations_mask(w: int, d: Union[torch.Tensor,int], autoregressive: bool, device: str):
Expand Down

0 comments on commit d574a10

Please sign in to comment.