
Commit 5634bed

justinrporter authored and aseyboldt committed
Fix grad_not_implemented import
import {theano,aesara}.gradient.grad_not_implemented as appropriate rather than always using aesara.gradient.grad_not_implemented. (cherry picked from commit c6a5dd6)
1 parent 82d9148 · commit 5634bed
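The change follows the same try/except compatibility pattern the file already uses for its other aesara/theano imports. A minimal sketch of the idea (imports only, assuming the aesara branch also aliases aesara.tensor as aet; the actual diff is below):

    try:
        import aesara
        import aesara.tensor as aet
        from aesara.gradient import grad_not_implemented
    except ModuleNotFoundError:
        import theano
        import theano.tensor as aet
        from theano.gradient import grad_not_implemented

Whichever branch runs, the name grad_not_implemented resolves to the backend that was actually imported, so the grad() methods no longer need to hard-code aesara.gradient.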

1 file changed: +4 −2

sunode/wrappers/as_aesara.py (+4 −2)
@@ -4,9 +4,11 @@
     from aesara.graph.basic import Constant, Variable
     from aesara.graph.fg import MissingInputError
     from aesara.graph.op import Op
+    from aesara.gradient import grad_not_implemented
 except ModuleNotFoundError:
     import theano
     import theano.tensor as aet
+    from theano.gradient import grad_not_implemented
     if hasattr(theano, "gof"):
         from theano.gof.fg import MissingInputError
         from theano.gof.var import Constant, Variable
@@ -218,7 +220,7 @@ def grad(self, inputs, g):
         return [
             aet.zeros_like(inputs[0]),
             aet.sum(g[:, None, :] * sens, (0, -1)),
-            aesara.gradient.grad_not_implemented(self, 2, inputs[-1])
+            grad_not_implemented(self, 2, inputs[-1])
         ]


@@ -257,7 +259,7 @@ def grad(self, inputs, g):
         y0, params, params_fixed = inputs
         backward = SolveODEAdjointBackward(self._solver, self._t0, self._tvals)
         lamda, gradient = backward(y0, params, params_fixed, g)
-        return [-lamda, gradient, aesara.gradient.grad_not_implemented(self, 2, params_fixed)]
+        return [-lamda, gradient, grad_not_implemented(self, 2, params_fixed)]


 class SolveODEAdjointBackward(Op):
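For context on the two grad() hunks: grad_not_implemented(op, idx, input) returns a symbolic placeholder gradient for the idx-th input of an Op, and an error is raised only when that placeholder ends up in a gradient the caller actually requests. The sketch below is a hypothetical, self-contained example of the same pattern (not sunode code; the ScaledShift op and every name in it are made up, and it imports aesara only, for brevity):

    import numpy as np

    import aesara
    import aesara.tensor as aet
    from aesara.graph.basic import Apply
    from aesara.graph.op import Op
    from aesara.gradient import grad_not_implemented


    class ScaledShift(Op):
        """Compute x * scale + shift, leaving the gradient w.r.t. shift unimplemented."""

        __props__ = ()

        def make_node(self, x, scale, shift):
            x = aet.as_tensor_variable(x)
            scale = aet.as_tensor_variable(scale)
            shift = aet.as_tensor_variable(shift)
            return Apply(self, [x, scale, shift], [x.type()])

        def perform(self, node, inputs, output_storage):
            x, scale, shift = inputs
            output_storage[0][0] = np.asarray(x * scale + shift, dtype=node.outputs[0].dtype)

        def grad(self, inputs, g):
            x, scale, shift = inputs
            return [
                g[0] * scale,                          # d(x * scale + shift) / dx
                aet.sum(g[0] * x),                     # d(x * scale + shift) / dscale
                grad_not_implemented(self, 2, shift),  # deliberately not implemented
            ]


    x = aet.dvector("x")
    scale = aet.dscalar("scale")
    shift = aet.dscalar("shift")
    loss = ScaledShift()(x, scale, shift).sum()

    aesara.grad(loss, [x, scale])  # fine: these gradients are implemented
    # aesara.grad(loss, shift)     # would raise, because grad() marked this
    #                              # input's gradient as not implemented

SolveODEAdjoint.grad above does the same thing for params_fixed: it returns real gradients for the inputs it supports and grad_not_implemented for the one it does not.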
