This repository has been archived by the owner on Mar 20, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy path: interpolation.py
90 lines (74 loc) · 2.83 KB
/
interpolation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from firedrake import *
import firedrake.supermeshing as supermesh
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["supermesh_project", "point_interpolate", "adjoint_supermesh_project"]
def supermesh_project(src, tgt, check_mass=False, mixed_mass_matrix=None, solver=None):
    """
    Hand-coded supermesh projection.

    Projects ``src`` into the function space of ``tgt`` by solving
    ``Mt * tgt = Mst * src``, where ``Mst`` is the mixed (supermesh) mass
    matrix between the source and target spaces and ``Mt`` is the mass
    matrix of the target space.

    :arg src: source field.
    :arg tgt: target field; overwritten with the projected values.
    :kwarg check_mass: if ``True``, check after projecting that
        ``assemble(src*dx)`` and ``assemble(tgt*dx)`` agree (``np.allclose``).
    :kwarg mixed_mass_matrix: optional pre-assembled mixed mass matrix
        ``Mst``. If omitted, it is assembled with
        ``supermesh.assemble_mixed_mass_matrix(source_space, target_space)``.
    :kwarg solver: optional pre-configured PETSc KSP for the target mass
        matrix system. If omitted, a new KSP is created, given the assembled
        target mass matrix as its operator, and configured from the PETSc
        options database.
    :return: ``tgt``.
    :raises AssertionError: if ``check_mass`` is set and the integrals differ.
    """
    # Imported locally: ``np`` is not guaranteed to be provided by
    # ``from firedrake import *``.
    import numpy as np

    source_space = src.function_space()
    target_space = tgt.function_space()

    # Step 1: Form the RHS:
    #   rhs := Mst * src
    Mst = mixed_mass_matrix
    if Mst is None:
        Mst = supermesh.assemble_mixed_mass_matrix(source_space, target_space)
    with tgt.dat.vec_ro as vt:
        # Only the parallel layout of the target vector is needed here;
        # ``mult`` overwrites every entry, so duplicate instead of copying values.
        rhs = vt.duplicate()
    with src.dat.vec_ro as vs:
        Mst.mult(vs, rhs)

    # Step 2: Solve the linear system for the target:
    #   Mt * tgt = rhs
    ksp = solver
    if ksp is None:
        ksp = PETSc.KSP().create()
        Mt = assemble(inner(TrialFunction(target_space), TestFunction(target_space))*dx).M.handle
        ksp.setOperators(Mt)
        ksp.setFromOptions()
    with tgt.dat.vec as vt:
        ksp.solve(rhs, vt)

    # Explicit check rather than a bare ``assert`` so that requesting
    # ``check_mass=True`` still has an effect under ``python -O``.
    if check_mass and not np.allclose(assemble(src*dx), assemble(tgt*dx)):
        raise AssertionError("supermesh_project: mass of src and tgt differ")
    return tgt
def point_interpolate(src, tgt, tol=1.0e-10):
    """
    Hand-coded point interpolation operator.

    Evaluates ``src`` at each vertex coordinate of the target mesh with
    ``src.at`` and stores the value at the same index of ``tgt``'s data.
    Only degree-1 Lagrange fields are accepted.

    :arg src: source field.
    :arg tgt: target field, or an object without ``function_space`` (e.g. a
        function space), in which case ``Function(tgt)`` is created and filled.
    :kwarg tol: tolerance passed to ``src.at`` when locating each point.
    :return: the (possibly newly created) target field.
    :raises NotImplementedError: if ``src`` or ``tgt`` is not a degree-1
        Lagrange element.
    """
    # Explicit checks instead of ``assert`` so that validation is not
    # stripped when Python runs with ``-O``. Order matches the original:
    # src family, src degree, tgt family, tgt degree.
    for field in (src, tgt):
        element = field.ufl_element()
        if element.family() != 'Lagrange' or element.degree() != 1:
            raise NotImplementedError(
                "point_interpolate only supports degree-1 Lagrange fields"
            )
    if not hasattr(tgt, 'function_space'):
        tgt = Function(tgt)
    mesh = tgt.function_space().mesh()
    # Coordinates are only read, so use the read-only accessor.
    target_coords = mesh.coordinates.dat.data_ro
    # Look up the target data array once rather than on every iteration.
    target_data = tgt.dat.data
    for i in range(mesh.num_vertices()):
        target_data[i] = src.at(target_coords[i], tolerance=tol)
    return tgt
def adjoint_supermesh_project(tgt_b, src_b, mixed_mass_matrix=None, solver=None):
    """
    Hand-coded adjoint of a supermesh projection.

    Applies the transposes of the two steps of the forward projection in
    reverse order: first solve ``Mt^T * sol = tgt_b``, then set
    ``src_b := Mst^T * sol``.

    :arg tgt_b: seed vector in target space.
    :arg src_b: adjoint supermesh projection into source space; overwritten
        with the result.
    :kwarg mixed_mass_matrix: optional pre-assembled mixed mass matrix
        ``Mst``. If omitted, it is assembled with
        ``supermesh.assemble_mixed_mass_matrix(source_space, target_space)``.
    :kwarg solver: optional pre-configured PETSc KSP used for the target mass
        matrix solve. If omitted, a new KSP is created with the transposed
        target mass matrix as its operator and configured from the PETSc
        options database.
    :return: ``src_b``.
    """
    source_space = src_b.function_space()
    target_space = tgt_b.function_space()

    # Adjoint of step 2: Solve the linear system for the target:
    #   Mt^T * sol = tgt_b
    ksp = solver
    if ksp is None:
        ksp = PETSc.KSP().create()
        Mt = assemble(inner(TrialFunction(target_space), TestFunction(target_space))*dx).M.handle
        ksp.setOperators(Mt.transpose())
        ksp.setFromOptions()
    with tgt_b.dat.vec_ro as rhs:
        sol = rhs.copy()
        ksp.solve(rhs, sol)

    # Adjoint of step 1: Multiply with the transpose mixed mass matrix
    #   src_b := Mst^T * sol
    Mst = mixed_mass_matrix
    if Mst is None:
        Mst = supermesh.assemble_mixed_mass_matrix(source_space, target_space)
    # ``multTranspose`` writes its output into ``src_b``, so the Dat must be
    # opened with a write access mode; the original ``vec_ro`` is a read-only
    # context and is the wrong mode for modifying the data.
    with src_b.dat.vec_wo as vs:
        Mst.multTranspose(sol, vs)
    return src_b