66from pytest_lazyfixture import lazy_fixture as lf
77from scipy import optimize
88import tensorflow as tf
9+ import jax
910
1011thisfile = os .path .abspath (__file__ )
1112modulepath = os .path .dirname (os .path .dirname (thisfile ))
@@ -47,7 +48,7 @@ def f(param):
4748
4849 param = torch .ones ([4 , n ], requires_grad = True )
4950 l = f_jit_torch (param )
50- l = l ** 2
51+ l = l ** 2
5152 l .backward ()
5253
5354 pg = param .grad
@@ -95,7 +96,7 @@ def f2(paramzz, paramx):
9596 np .testing .assert_allclose (pg [0 , 0 ], - 0.41609 , atol = 1e-5 )
9697
9798 def f3 (x ):
98- return tc .backend .real (x ** 2 )
99+ return tc .backend .real (x ** 2 )
99100
100101 f3_torch = tc .interfaces .torch_interface (f3 )
101102 param3 = torch .ones ([2 ], dtype = torch .complex64 , requires_grad = True )
@@ -106,6 +107,26 @@ def f3(x):
106107 np .testing .assert_allclose (pg , 2 * np .ones ([2 ]).astype (np .complex64 ), atol = 1e-5 )
107108
108109
110+ @pytest .mark .skipif (is_torch is False , reason = "torch not installed" )
111+ @pytest .mark .xfail (
112+ (int (tf .__version__ .split ("." )[1 ]) < 9 )
113+ or (int ("" .join (jax .__version__ .split ("." )[1 :])) < 314 ),
114+ reason = "version too low for tf or jax" ,
115+ )
116+ @pytest .mark .parametrize ("backend" , [lf ("tfb" ), lf ("jaxb" )])
117+ def test_torch_interface_dlpack_complex (backend ):
118+ def f3 (x ):
119+ return tc .backend .real (x ** 2 )
120+
121+ f3_torch = tc .interfaces .torch_interface (f3 , enable_dlpack = True )
122+ param3 = torch .ones ([2 ], dtype = torch .complex64 , requires_grad = True )
123+ l3 = f3_torch (param3 )
124+ l3 = torch .sum (l3 )
125+ l3 .backward ()
126+ pg = param3 .grad
127+ np .testing .assert_allclose (pg , 2 * np .ones ([2 ]).astype (np .complex64 ), atol = 1e-5 )
128+
129+
109130@pytest .mark .skipif (is_torch is False , reason = "torch not installed" )
110131@pytest .mark .xfail (reason = "see comment link below" )
111132@pytest .mark .parametrize ("backend" , [lf ("tfb" ), lf ("jaxb" )])
0 commit comments