Skip to content

Commit f4cdb3f

Browse files
authored
[Relax][PyTorch] Add support for ge, gt, le, mod, ne ops (#17664)
* Update fx_translator.py * Update test_frontend_from_fx.py
1 parent 0d42dc4 commit f4cdb3f

File tree

2 files changed

+238
-57
lines changed

2 files changed

+238
-57
lines changed

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,13 +662,18 @@ def create_convert_map(
662662
"add": self._binary_op(relax.op.add, operator.add),
663663
"eq": self._binary_op(relax.op.equal, operator.eq),
664664
"floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv),
665+
"ge": self._binary_op(relax.op.greater_equal, operator.ge),
666+
"gt": self._binary_op(relax.op.greater, operator.gt),
665667
"iadd": self._binary_op(relax.op.add, operator.add),
668+
"le": self._binary_op(relax.op.less_equal, operator.le),
666669
"lt": self._binary_op(relax.op.less, operator.lt),
667670
"matmul": self._binary_op(
668671
partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
669672
),
670673
"max": self._binary_op(relax.op.maximum, max),
674+
"mod": self._binary_op(relax.op.mod, operator.mod),
671675
"mul": self._binary_op(relax.op.multiply, operator.mul),
676+
"ne": self._binary_op(relax.op.not_equal, operator.ne),
672677
"pow": self._binary_op(relax.op.power, operator.pow),
673678
"sub": self._binary_op(relax.op.subtract, operator.sub),
674679
"truediv": self._binary_op(relax.op.divide, operator.truediv),

tests/python/relax/test_frontend_from_fx.py

Lines changed: 233 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1759,6 +1759,209 @@ def main(
17591759
verify_model(LT1(), input_info1, {}, expected13)
17601760
verify_model(LT2(), input_info2, {}, expected14)
17611761

1762+
# Mod
1763+
class Mod1(Module):
1764+
def forward(self, lhs, rhs):
1765+
return lhs % rhs
1766+
1767+
@tvm.script.ir_module
1768+
class expected15:
1769+
@R.function
1770+
def main(
1771+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
1772+
rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
1773+
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
1774+
# block 0
1775+
with R.dataflow():
1776+
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.mod(lhs_1, rhs_1)
1777+
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
1778+
R.output(gv)
1779+
return gv
1780+
1781+
class Mod2(Module):
1782+
def forward(self, lhs):
1783+
return lhs % 1.0
1784+
1785+
@tvm.script.ir_module
1786+
class expected16:
1787+
@R.function
1788+
def main(
1789+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
1790+
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
1791+
# block 0
1792+
with R.dataflow():
1793+
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.mod(lhs_1, R.const(1.0))
1794+
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
1795+
R.output(gv)
1796+
return gv
1797+
1798+
verify_model(Mod1(), input_info1, {}, expected15)
1799+
verify_model(Mod2(), input_info2, {}, expected16)
1800+
1801+
# Ge
1802+
class Ge1(Module):
1803+
def forward(self, lhs, rhs):
1804+
return lhs >= rhs
1805+
1806+
@tvm.script.ir_module
1807+
class expected17:
1808+
@R.function
1809+
def main(
1810+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
1811+
rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
1812+
) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
1813+
# block 0
1814+
with R.dataflow():
1815+
lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater_equal(lhs_1, rhs_1)
1816+
gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
1817+
R.output(gv)
1818+
1819+
return gv
1820+
1821+
class Ge2(Module):
1822+
def forward(self, lhs):
1823+
return lhs >= 1.0
1824+
1825+
@tvm.script.ir_module
1826+
class expected18:
1827+
@R.function
1828+
def main(
1829+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
1830+
) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
1831+
# block 0
1832+
with R.dataflow():
1833+
lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater_equal(lhs_1, R.const(1.0))
1834+
gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
1835+
R.output(gv)
1836+
1837+
return gv
1838+
1839+
verify_model(Ge1(), input_info1, {}, expected17)
1840+
verify_model(Ge2(), input_info2, {}, expected18)
1841+
1842+
# Gt
1843+
class Gt1(Module):
1844+
def forward(self, lhs, rhs):
1845+
return lhs > rhs
1846+
1847+
@tvm.script.ir_module
1848+
class expected19:
1849+
@R.function
1850+
def main(
1851+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
1852+
rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
1853+
) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
1854+
# block 0
1855+
with R.dataflow():
1856+
lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(lhs_1, rhs_1)
1857+
gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
1858+
R.output(gv)
1859+
1860+
return gv
1861+
1862+
class Gt2(Module):
1863+
def forward(self, lhs):
1864+
return lhs > 1.0
1865+
1866+
@tvm.script.ir_module
1867+
class expected20:
1868+
@R.function
1869+
def main(
1870+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
1871+
) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
1872+
# block 0
1873+
with R.dataflow():
1874+
lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(lhs_1, R.const(1.0))
1875+
gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
1876+
R.output(gv)
1877+
1878+
return gv
1879+
1880+
verify_model(Gt1(), input_info1, {}, expected19)
1881+
verify_model(Gt2(), input_info2, {}, expected20)
1882+
1883+
# Le
1884+
class Le1(Module):
1885+
def forward(self, lhs, rhs):
1886+
return lhs <= rhs
1887+
1888+
@tvm.script.ir_module
1889+
class expected21:
1890+
@R.function
1891+
def main(
1892+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
1893+
rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
1894+
) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
1895+
# block 0
1896+
with R.dataflow():
1897+
lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less_equal(lhs_1, rhs_1)
1898+
gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
1899+
R.output(gv)
1900+
1901+
return gv
1902+
1903+
class Le2(Module):
1904+
def forward(self, lhs):
1905+
return lhs <= 1.0
1906+
1907+
@tvm.script.ir_module
1908+
class expected22:
1909+
@R.function
1910+
def main(
1911+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
1912+
) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
1913+
# block 0
1914+
with R.dataflow():
1915+
lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less_equal(lhs_1, R.const(1.0))
1916+
gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
1917+
R.output(gv)
1918+
1919+
return gv
1920+
1921+
verify_model(Le1(), input_info1, {}, expected21)
1922+
verify_model(Le2(), input_info2, {}, expected22)
1923+
1924+
# Ne
1925+
class Ne1(Module):
1926+
def forward(self, lhs, rhs):
1927+
return lhs != rhs
1928+
1929+
@tvm.script.ir_module
1930+
class expected23:
1931+
@R.function
1932+
def main(
1933+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
1934+
rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
1935+
) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
1936+
# block 0
1937+
with R.dataflow():
1938+
lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.not_equal(lhs_1, rhs_1)
1939+
gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
1940+
R.output(gv)
1941+
1942+
return gv
1943+
1944+
class Ne2(Module):
1945+
def forward(self, lhs):
1946+
return lhs != 1.0
1947+
1948+
@tvm.script.ir_module
1949+
class expected24:
1950+
@R.function
1951+
def main(
1952+
lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
1953+
) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
1954+
# block 0
1955+
with R.dataflow():
1956+
lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.not_equal(lhs_1, R.const(1.0))
1957+
gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
1958+
R.output(gv)
1959+
1960+
return gv
1961+
1962+
verify_model(Ne1(), input_info1, {}, expected23)
1963+
verify_model(Ne2(), input_info2, {}, expected24)
1964+
17621965

17631966
def test_size():
17641967
input_info = [([1, 3, 10, 10], "float32")]
@@ -1981,6 +2184,36 @@ def main(
19812184
verify_model(Unary(), input_info, {}, expected_unary)
19822185

19832186

2187+
operator_bool_unary = [
2188+
(torch.isnan, R.isnan),
2189+
(torch.isinf, R.isinf),
2190+
(torch.isfinite, R.isfinite),
2191+
]
2192+
2193+
2194+
@pytest.mark.parametrize("pytorch_op, relax_op", operator_bool_unary)
2195+
def test_bool_unary_ops(pytorch_op, relax_op):
2196+
input_info = [([1, 3, 10, 10], "float32")]
2197+
2198+
class Unary(Module):
2199+
def forward(self, input):
2200+
return pytorch_op(input)
2201+
2202+
@tvm.script.ir_module
2203+
class expected_unary:
2204+
@R.function
2205+
def main(
2206+
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
2207+
) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
2208+
with R.dataflow():
2209+
lv: R.Tensor((1, 3, 10, 10), dtype="bool") = relax_op(input_1)
2210+
gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
2211+
R.output(gv)
2212+
return gv
2213+
2214+
verify_model(Unary(), input_info, {}, expected_unary)
2215+
2216+
19842217
def test_extended_unary_ops():
19852218
input_info = [([1, 3, 10, 10], "float32")]
19862219

@@ -2201,63 +2434,6 @@ def main(
22012434
verify_model(LogSoftmax(), input_info, {}, expected_log_softmax)
22022435
verify_model(LogSoftmax2(), input_info, {}, expected_log_softmax)
22032436

2204-
# isfinite
2205-
class IsFinite(Module):
2206-
def forward(self, input):
2207-
return torch.isfinite(input)
2208-
2209-
@tvm.script.ir_module
2210-
class expected_isfinite:
2211-
@R.function
2212-
def main(
2213-
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
2214-
) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
2215-
with R.dataflow():
2216-
lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.isfinite(input_1)
2217-
gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
2218-
R.output(gv)
2219-
return gv
2220-
2221-
verify_model(IsFinite(), input_info, {}, expected_isfinite)
2222-
2223-
# isinf
2224-
class IsInf(Module):
2225-
def forward(self, input):
2226-
return torch.isinf(input)
2227-
2228-
@tvm.script.ir_module
2229-
class expected_isinf:
2230-
@R.function
2231-
def main(
2232-
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
2233-
) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
2234-
with R.dataflow():
2235-
lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.isinf(input_1)
2236-
gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
2237-
R.output(gv)
2238-
return gv
2239-
2240-
verify_model(IsInf(), input_info, {}, expected_isinf)
2241-
2242-
# isnan
2243-
class IsNan(Module):
2244-
def forward(self, input):
2245-
return torch.isnan(input)
2246-
2247-
@tvm.script.ir_module
2248-
class expected_isnan:
2249-
@R.function
2250-
def main(
2251-
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
2252-
) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
2253-
with R.dataflow():
2254-
lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.isnan(input_1)
2255-
gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
2256-
R.output(gv)
2257-
return gv
2258-
2259-
verify_model(IsNan(), input_info, {}, expected_isnan)
2260-
22612437
# relu
22622438
class ReLU0(Module):
22632439
def __init__(self):

0 commit comments

Comments
 (0)