From 425e78fb986cb08207c0a1d44e4fe15ee96145f3 Mon Sep 17 00:00:00 2001
From: cchung100m
Date: Sat, 1 Nov 2025 17:44:03 +0800
Subject: [PATCH] [#18362][relax.frontend.torch] Add temporary solution for pytorch op 'randn'

---
 .../torch/exported_program_translator.py      | 26 +++++++
 .../test_frontend_from_exported_program.py    | 70 +++++++++++++++++++
 2 files changed, 96 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index a84c35e62234..180f167ec36c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -64,6 +64,31 @@ def _reciprocal(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x))

+    def _randn(self, node: fx.Node) -> relax.Var:
+        """Translate ``aten.randn`` into a zero-filled tensor of the requested shape.
+
+        The size is taken from the first argument when it is a list/tuple, otherwise
+        from all positional arguments. A missing or ``None`` ``dtype`` falls back to
+        PyTorch's current default dtype, as ``torch.randn`` does.
+        """
+        args = self.retrieve_args(node)
+
+        size = args[0] if isinstance(args[0], (list, tuple)) else args
+
+        # An absent dtype and an explicit dtype=None both mean "use the default dtype".
+        dtype = node.kwargs.get("dtype")
+        if dtype is None:
+            dtype = torch.get_default_dtype()
+        if isinstance(dtype, torch.dtype):
+            dtype = self._convert_data_type(dtype)
+
+        shape = relax.ShapeExpr(size)
+
+        # TODO: This is a temporary solution that returns zeros instead of random values
+        # since random initialization is mainly used during training, not inference.
+        # This should be updated once Relax adds proper random number generation support.
+        return self.block_builder.emit(relax.op.zeros(shape, dtype))
+
     ########## Neural Network ##########

     def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
@@ -835,6 +860,7 @@ def create_convert_map(
             "pad.default": self._pad,
             "pixel_shuffle.default": self._pixel_shuffle,
             "prelu.default": self._prelu,
+            "randn.default": self._randn,
             "reciprocal.default": self._reciprocal,
             "relu.default": self._unary_op(relax.op.nn.relu),
             "relu_.default": self._unary_op(relax.op.nn.relu),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 657ade455bd7..136cd1418856 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6121,5 +6121,75 @@ def forward(self, x):
     np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)


+def test_advanced_indexing_with_randn():
+    """Test model with randn and advanced indexing write returning a tuple."""
+    N = 5
+
+    class AdvancedIndexingModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.elu = nn.ELU()
+
+        def forward(self, x):
+            L = torch.zeros(N, N, dtype=x.dtype, device=x.device)
+            idx = torch.arange(N, device=x.device)
+            v = torch.randn(N, device=x.device)
+            v = self.elu(v) + 1.0 + 1e-8
+            L[idx, idx] = v
+            y = x + 1
+            return y, L
+
+    torch.manual_seed(0)
+    example_input = torch.randn(2, N)
+    model = AdvancedIndexingModel().eval()
+
+    exported_program = export(model, (example_input,))
+
+    mod = from_exported_program(exported_program)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 5), dtype="float32")
+        ) -> R.Tuple(R.Tensor((2, 5), dtype="float32"), R.Tensor((5, 5), dtype="float32")):
+            with R.dataflow():
+                lv0 = R.zeros((5, 5), dtype="float32")
+
+                # Use zeros instead of random normal distribution
+                lv1 = R.zeros((5,), dtype="float32")
+
+                lv2 = R.nn.relu(lv1)
+                lv3 = R.add(lv2, R.const(1.0, "float32"))
+                v = R.add(lv3, R.const(1e-8, "float32"))
+
+                idx = R.arange(
+                    R.const(0, "int64"), R.const(5, "int64"), R.const(1, "int64"), dtype="int64"
+                )
+
+                L = R.tensor_update(lv0, (idx, idx), v)
+                y = R.add(x, R.const(1, "float32"))
+
+                gv = R.tuple(y, L)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+    target = "llvm"
+    dev = tvm.cpu()
+
+    exe = relax.build(mod, target=target)
+    vm = relax.VirtualMachine(exe, dev)
+    tvm_res = vm["main"](tvm.nd.array(example_input.numpy()))
+
+    torch_res = model(example_input)
+
+    np.testing.assert_allclose(torch_res[0].numpy(), tvm_res[0].numpy(), rtol=1e-7, atol=1e-7)
+
+    assert tvm_res[1].shape == (N, N)
+    assert tvm_res[1].dtype == "float32"
+
+
 if __name__ == "__main__":
     tvm.testing.main()