Skip to content

Commit f210361

Browse files
add torch interface numpy path device identify
1 parent 80e7fa9 commit f210361

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828

2929
- Fixed `TorchLayer` parameter list auto registeration
3030

31+
- Pytorch interface is now device aware (#25)
32+
3133
## 0.2.1
3234

3335
### Added

tensorcircuit/interfaces/torch.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ def forward(ctx: Any, *x: Any) -> Any: # type: ignore
7070
# (x, )
7171
if len(ctx.xdtype) == 1:
7272
ctx.xdtype = ctx.xdtype[0]
73-
73+
if not enable_dlpack:
74+
ctx.device = (backend.tree_flatten(x)[0][0]).device
7475
x = general_args_to_backend(x, enable_dlpack=enable_dlpack)
7576
y = fun(*x)
7677
ctx.ydtype = backend.tree_map(lambda s: s.dtype, y)
@@ -80,6 +81,8 @@ def forward(ctx: Any, *x: Any) -> Any: # type: ignore
8081
y = general_args_to_backend(
8182
y, target_backend="pytorch", enable_dlpack=enable_dlpack
8283
)
84+
if not enable_dlpack:
85+
y = backend.tree_map(lambda s: s.to(device=ctx.device), y)
8386
return y
8487

8588
@staticmethod
@@ -101,6 +104,8 @@ def backward(ctx: Any, *grad_y: Any) -> Any:
101104
target_backend="pytorch",
102105
enable_dlpack=enable_dlpack,
103106
)
107+
if not enable_dlpack:
108+
r = backend.tree_map(lambda s: s.to(device=ctx.device), r)
104109
if not is_sequence(r):
105110
return (r,)
106111
return r

0 commit comments

Comments
 (0)