diff --git a/arraycontext/context.py b/arraycontext/context.py index 36a7acee..2b7d623a 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -275,6 +275,7 @@ class ArrayContext(ABC): .. automethod:: tag .. automethod:: tag_axis .. automethod:: compile + .. automethod:: solve """ array_types: Tuple[type, ...] = () @@ -536,6 +537,9 @@ def permits_advanced_indexing(self) -> bool: *True* if the arrays support :mod:`numpy`'s advanced indexing semantics. """ + + def solve(self, + # }}} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 842d108e..bab04565 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1549,6 +1549,37 @@ class ExampleTag(Tag): assert not ary.axes[1].tags_of_type(ExampleTag) +# {{{ test_solve_sqrt + +class SquareRootSolve(Tag): + pass + + +def test_solve_sqrt(actx_factory): + actx = actx_factory() + + def get_root(b): + def f(x): + return x**2 - b + + def jacobian(x): + return 2*x + + initial = actx.from_numpy(2) + return actx.solve(f, initial, jacobian=jacobian, + tol=1.48e-08, maxiter=50, rtol=0.0, + tags=(SquareRootSolve(),)) + + get_root_compiled = actx.compile(get_root) + + ... test that get_root computes square roots ... + + if isinstance(actx, (EagerJAXArrayContext, PytatoJAX...)): + ... check that the derivative does what we want ... + +# }}} + + if __name__ == "__main__": import sys if len(sys.argv) > 1: