17 changes: 17 additions & 0 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -64,6 +64,22 @@ def _reciprocal(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x))

    def _randn(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)

        # The size may arrive either as a single list/tuple argument or as a
        # variadic sequence of ints.
        size = args[0] if isinstance(args[0], (list, tuple)) else args

        dtype = node.kwargs.get("dtype", "float32")
        if isinstance(dtype, torch.dtype):
            dtype = self._convert_data_type(dtype)
        elif dtype is None:
            # Exported graphs may carry an explicit dtype=None; fall back to float32.
            dtype = "float32"

        shape = relax.ShapeExpr(size)

        # TODO: Temporary workaround that returns zeros instead of random values,
        # since random initialization is mainly used during training, not inference.
        # Update this once Relax adds proper random number generation support.
        return self.block_builder.emit(relax.op.zeros(shape, dtype))

    ########## Neural Network ##########

    def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
@@ -835,6 +851,7 @@ def create_convert_map(
"pad.default": self._pad,
"pixel_shuffle.default": self._pixel_shuffle,
"prelu.default": self._prelu,
"randn.default": self._randn,
"reciprocal.default": self._reciprocal,
"relu.default": self._unary_op(relax.op.nn.relu),
"relu_.default": self._unary_op(relax.op.nn.relu),
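The registration key above is the ATen op name that torch.export records in the graph. As a quick sanity check (a minimal sketch; the module `M` and the shapes are illustrative, not part of this PR), printing an exported graph shows the `randn.default` call target that the table dispatches to `_randn`:

import torch
from torch.export import export


class M(torch.nn.Module):
    def forward(self, x):
        # torch.randn shows up as an aten.randn.default node in the exported graph.
        return x + torch.randn(2, 4)


ep = export(M(), (torch.zeros(2, 4),))
print(ep.graph)  # look for a call_function node targeting aten.randn.default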
70 changes: 70 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -6121,5 +6121,75 @@ def forward(self, x):
    np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)


def test_advanced_indexing_with_randn():
    """Test a model that mixes randn with an advanced-indexing write and returns a tuple."""
    N = 5

    class AdvancedIndexingModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.elu = nn.ELU()

        def forward(self, x):
            L = torch.zeros(N, N, dtype=x.dtype, device=x.device)
            idx = torch.arange(N, device=x.device)
            v = torch.randn(N, device=x.device)
            v = self.elu(v) + 1.0 + 1e-8
            L[idx, idx] = v
            y = x + 1
            return y, L

    torch.manual_seed(0)
    example_input = torch.randn(2, N)
    model = AdvancedIndexingModel().eval()

    exported_program = export(model, (example_input,))

    mod = from_exported_program(exported_program)

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 5), dtype="float32")
        ) -> R.Tuple(R.Tensor((2, 5), dtype="float32"), R.Tensor((5, 5), dtype="float32")):
            with R.dataflow():
                lv0 = R.zeros((5, 5), dtype="float32")

                # randn is lowered to zeros instead of a random normal draw
                lv1 = R.zeros((5,), dtype="float32")

                lv2 = R.nn.relu(lv1)
                lv3 = R.add(lv2, R.const(1.0, "float32"))
                v = R.add(lv3, R.const(1e-8, "float32"))

                idx = R.arange(
                    R.const(0, "int64"), R.const(5, "int64"), R.const(1, "int64"), dtype="int64"
                )

                L = R.tensor_update(lv0, (idx, idx), v)
                y = R.add(x, R.const(1, "float32"))

                gv = R.tuple(y, L)
                R.output(gv)
            return gv

    tvm.ir.assert_structural_equal(mod, Expected)

    target = "llvm"
    dev = tvm.cpu()

    exe = relax.build(mod, target=target)
    vm = relax.VirtualMachine(exe, dev)
    tvm_res = vm["main"](tvm.nd.array(example_input.numpy()))

    torch_res = model(example_input)

    np.testing.assert_allclose(torch_res[0].numpy(), tvm_res[0].numpy(), rtol=1e-7, atol=1e-7)

    assert tvm_res[1].shape == (N, N)
    assert tvm_res[1].dtype == "float32"


if __name__ == "__main__":
    tvm.testing.main()
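Because the placeholder lowering emits zeros, translated graphs that contain randn are deterministic. A minimal end-to-end sketch of that behavior (assuming the standard `from_exported_program` entry point; the `RandnOnly` module is illustrative, not part of this PR):

import numpy as np
import torch
import tvm
from torch.export import export
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program


class RandnOnly(torch.nn.Module):
    def forward(self, x):
        return x + torch.randn(3, 3)


ep = export(RandnOnly().eval(), (torch.zeros(3, 3),))
mod = from_exported_program(ep)

exe = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(exe, tvm.cpu())
res = vm["main"](tvm.nd.array(np.zeros((3, 3), dtype="float32")))

# The frontend may wrap outputs in a tuple; unwrap if needed.
out = res[0] if not isinstance(res, tvm.nd.NDArray) else res
# With the zeros placeholder, the "random" term contributes nothing.
np.testing.assert_array_equal(out.numpy(), np.zeros((3, 3), dtype="float32"))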