We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
device_get
1 parent 25e832a commit 9b64440Copy full SHA for 9b64440
python/jtorch/__init__.py
@@ -101,7 +101,12 @@ def retain_grad(x:Tensor, value:bool=True):
101
Tensor.nelement = lambda self: self.numel()
102
Tensor.cuda = lambda self: self
103
def device_get(x:Tensor):
104
- return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda")
+ if jt.has_cuda and jt.flags.use_cuda:
105
+ return device("cuda")
106
+ elif jt.compiler.has_acl and jt.flags.use_acl:
107
108
+ else:
109
+ return device("cpu")
110
Tensor.device = property(device_get)
111
112
def argmax(x: Var, dim=None, keepdim: bool = False):
0 commit comments