Skip to content

Commit 9b64440

Browse files
committed
add Huawei ACL case for device_get
1 parent 25e832a commit 9b64440

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

python/jtorch/__init__.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,12 @@ def retain_grad(x:Tensor, value:bool=True):
101101
Tensor.nelement = lambda self: self.numel()
102102
Tensor.cuda = lambda self: self
103103
def device_get(x:Tensor):
104-
return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda")
104+
if jt.has_cuda and jt.flags.use_cuda:
105+
return device("cuda")
106+
elif jt.compiler.has_acl and jt.flags.use_acl:
107+
return device("cuda")
108+
else:
109+
return device("cpu")
105110
Tensor.device = property(device_get)
106111

107112
def argmax(x: Var, dim=None, keepdim: bool = False):

0 commit comments

Comments
 (0)