Skip to content

Commit 42a6e27

Browse files
committed
Fix assemble(slate.Tensor, diagonal=True)
1 parent 6e5c84a commit 42a6e27

File tree

3 files changed

+7
-12
lines changed

3 files changed

+7
-12
lines changed

firedrake/assemble.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,8 @@ def _cache_key(cls, form, bcs=None, form_compiler_parameters=None, needs_zeroing
11831183
@FormAssembler._skip_if_initialised
11841184
def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True,
11851185
zero_bc_nodes=True, diagonal=False, weight=1.0):
1186+
if diagonal and isinstance(form, slate.TensorBase) and len(form.arguments()) == 2:
1187+
form = slate.DiagonalTensor(form)
11861188
super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters, needs_zeroing=needs_zeroing)
11871189
self._weight = weight
11881190
self._diagonal = diagonal

firedrake/slate/slate.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,13 +1387,12 @@ def arg_function_spaces(self):
13871387
"""Returns a tuple of function spaces that the tensor
13881388
is defined on.
13891389
"""
1390-
tensor, = self.operands
1391-
return tuple(arg.function_space() for arg in tensor.arguments())
1390+
return tuple(arg.function_space() for arg in self.arguments())
13921391

13931392
def arguments(self):
13941393
"""Returns a tuple of arguments associated with the tensor."""
13951394
tensor, = self.operands
1396-
return tensor.arguments()
1395+
return tensor.arguments()[:1]
13971396

13981397
def _output_string(self, prec=None):
13991398
"""Creates a string representation of the diagonal of a tensor."""

tests/firedrake/slate/test_linear_algebra.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,8 @@ def test_inverse_action(mat_type, rhs_type):
152152
assert np.allclose(x.dat.data, f.dat.data, rtol=1.e-13)
153153

154154

155-
@pytest.mark.parametrize("mat_type, rhs_type", [
156-
("slate", "slate"), ("slate", "form"), ("slate", "cofunction"),
157-
("aij", "cofunction"), ("aij", "form"),
158-
("matfree", "cofunction"), ("matfree", "form")])
155+
@pytest.mark.parametrize("rhs_type", ["slate", "form", "cofunction"])
156+
@pytest.mark.parametrize("mat_type", ["slate", "aij", "matfree"])
159157
def test_solve_interface(mat_type, rhs_type):
160158
mesh = UnitSquareMesh(1, 1)
161159
V = FunctionSpace(mesh, "HDivT", 0)
@@ -180,12 +178,8 @@ def test_solve_interface(mat_type, rhs_type):
180178
else:
181179
raise ValueError("Invalid rhs type")
182180

183-
sp = None
184-
if mat_type == "matfree":
185-
sp = {"pc_type": "none"}
186-
187181
x = Function(V)
188182
problem = LinearVariationalProblem(A, b, x, bcs=bcs)
189-
solver = LinearVariationalSolver(problem, solver_parameters=sp)
183+
solver = LinearVariationalSolver(problem)
190184
solver.solve()
191185
assert np.allclose(x.dat.data, f.dat.data, rtol=1.e-13)

0 commit comments

Comments
 (0)