torch_xla::runtime::PjRtComputationClient::ExecuteReplicated() crash after RuntimeError #8973

Open, opened by @tengyifei

The crash can be seen in this notebook: https://github.com/tengyifei/playground/blob/master/torch-xla-device.ipynb

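Minimal repro:

```python
import torch
import torch_xla

# Two tensors placed on two different XLA devices.
a = torch.tensor([1.0], device="xla:1")
b = torch.tensor([1.0], device="xla:2")
print(a.device, b.device)

# Combine tensors that live on different XLA devices.
a + b

# Combine an XLA tensor with a CPU tensor.
a + b.cpu()

# Switch the runtime to SPMD mode after XLA devices are already in use.
import torch_xla.runtime as xr
xr.use_spmd()

# Run a computation on the default XLA device under SPMD.
c = torch.tensor([1.0], device="xla")
c.device, c + c
```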

Assignees: no one assigned

Labels: bug (Something isn't working), xla:tpu (TPU specific issues and PRs)