Description
See the crash in this notebook: https://github.com/tengyifei/playground/blob/master/torch-xla-device.ipynb
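Minimal repro:

```python
import torch
import torch_xla
import torch_xla.runtime as xr

# Two tensors placed on different XLA devices.
a = torch.tensor([1.0], device="xla:1")
b = torch.tensor([1.0], device="xla:2")
print(a.device, b.device)

# Binary op whose operands live on two different XLA devices.
print(a + b)

# Binary op mixing an XLA tensor with a CPU tensor.
print(a + b.cpu())

# Enable SPMD mode, then create a tensor on the default "xla" device.
xr.use_spmd()
c = torch.tensor([1.0], device="xla")
print(c.device, c + c)
```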