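See crash at: https://github.com/tengyifei/playground/blob/master/torch-xla-device.ipynb

Minimal repro:

```python
import torch
import torch_xla

# Two tensors placed on two different XLA devices.
a = torch.tensor([1.0], device="xla:1")
b = torch.tensor([1.0], device="xla:2")
print(a.device, b.device)

# Combine tensors that live on different XLA devices.
a + b
# Combine an XLA tensor with a CPU tensor.
a + b.cpu()

import torch_xla.runtime as xr

# Enable SPMD mode, then create a tensor on the generic "xla" device.
xr.use_spmd()
c = torch.tensor([1.0], device="xla")
c.device, c + c
```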