Skip to content

Commit 07424f7

Browse files
add higher version dlpack support
1 parent 405c3d8 commit 07424f7

File tree

3 files changed

+25
-2
lines changed

3 files changed

+25
-2
lines changed

tensorcircuit/backends/jax_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def __init__(self) -> None:
197197
"backend or install Jax."
198198
)
199199
from jax.experimental import sparse
200+
import jax.scipy
200201

201202
try:
202203
import optax

tensorcircuit/interfaces/torch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def forward(ctx: Any, *x: Any) -> Any: # type: ignore
8989
def backward(ctx: Any, *grad_y: Any) -> Any:
9090
if len(grad_y) == 1:
9191
grad_y = grad_y[0]
92+
grad_y = backend.tree_map(lambda s: s.contiguous(), grad_y)
9293
grad_y = general_args_to_backend(
9394
grad_y, dtype=ctx.ydtype, enable_dlpack=enable_dlpack
9495
)

tests/test_interfaces.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pytest_lazyfixture import lazy_fixture as lf
77
from scipy import optimize
88
import tensorflow as tf
9+
import jax
910

1011
thisfile = os.path.abspath(__file__)
1112
modulepath = os.path.dirname(os.path.dirname(thisfile))
@@ -47,7 +48,7 @@ def f(param):
4748

4849
param = torch.ones([4, n], requires_grad=True)
4950
l = f_jit_torch(param)
50-
l = l**2
51+
l = l ** 2
5152
l.backward()
5253

5354
pg = param.grad
@@ -95,7 +96,7 @@ def f2(paramzz, paramx):
9596
np.testing.assert_allclose(pg[0, 0], -0.41609, atol=1e-5)
9697

9798
def f3(x):
98-
return tc.backend.real(x**2)
99+
return tc.backend.real(x ** 2)
99100

100101
f3_torch = tc.interfaces.torch_interface(f3)
101102
param3 = torch.ones([2], dtype=torch.complex64, requires_grad=True)
@@ -106,6 +107,26 @@ def f3(x):
106107
np.testing.assert_allclose(pg, 2 * np.ones([2]).astype(np.complex64), atol=1e-5)
107108

108109

110+
@pytest.mark.skipif(is_torch is False, reason="torch not installed")
111+
@pytest.mark.xfail(
112+
(int(tf.__version__.split(".")[1]) < 9)
113+
or (int("".join(jax.__version__.split(".")[1:])) < 314),
114+
reason="version too low for tf or jax",
115+
)
116+
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
117+
def test_torch_interface_dlpack_complex(backend):
118+
def f3(x):
119+
return tc.backend.real(x ** 2)
120+
121+
f3_torch = tc.interfaces.torch_interface(f3, enable_dlpack=True)
122+
param3 = torch.ones([2], dtype=torch.complex64, requires_grad=True)
123+
l3 = f3_torch(param3)
124+
l3 = torch.sum(l3)
125+
l3.backward()
126+
pg = param3.grad
127+
np.testing.assert_allclose(pg, 2 * np.ones([2]).astype(np.complex64), atol=1e-5)
128+
129+
109130
@pytest.mark.skipif(is_torch is False, reason="torch not installed")
110131
@pytest.mark.xfail(reason="see comment link below")
111132
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])

0 commit comments

Comments
 (0)