File tree Expand file tree Collapse file tree 2 files changed +8
-1
lines changed Expand file tree Collapse file tree 2 files changed +8
-1
lines changed Original file line number Diff line number Diff line change 2828
2929- Fixed ` TorchLayer ` parameter list auto registeration
3030
31+ - Pytorch interface is now device aware (#25 )
32+
3133## 0.2.1
3234
3335### Added
Original file line number Diff line number Diff line change @@ -70,7 +70,8 @@ def forward(ctx: Any, *x: Any) -> Any: # type: ignore
7070 # (x, )
7171 if len (ctx .xdtype ) == 1 :
7272 ctx .xdtype = ctx .xdtype [0 ]
73-
73+ if not enable_dlpack :
74+ ctx .device = (backend .tree_flatten (x )[0 ][0 ]).device
7475 x = general_args_to_backend (x , enable_dlpack = enable_dlpack )
7576 y = fun (* x )
7677 ctx .ydtype = backend .tree_map (lambda s : s .dtype , y )
@@ -80,6 +81,8 @@ def forward(ctx: Any, *x: Any) -> Any: # type: ignore
8081 y = general_args_to_backend (
8182 y , target_backend = "pytorch" , enable_dlpack = enable_dlpack
8283 )
84+ if not enable_dlpack :
85+ y = backend .tree_map (lambda s : s .to (device = ctx .device ), y )
8386 return y
8487
8588 @staticmethod
@@ -101,6 +104,8 @@ def backward(ctx: Any, *grad_y: Any) -> Any:
101104 target_backend = "pytorch" ,
102105 enable_dlpack = enable_dlpack ,
103106 )
107+ if not enable_dlpack :
108+ r = backend .tree_map (lambda s : s .to (device = ctx .device ), r )
104109 if not is_sequence (r ):
105110 return (r ,)
106111 return r
You can’t perform that action at this time.
0 commit comments