Skip to content

Commit 88e019a

Browse files
authored
Random Cofunctions (#4237)
1 parent 6a5f4b1 commit 88e019a

File tree

3 files changed

+41
-6
lines changed

3 files changed

+41
-6
lines changed

firedrake/randomfunctiongen.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
========
44
55
This module wraps `numpy.random <https://numpy.org/doc/stable/reference/random/index.html>`__,
6-
and enables users to generate a randomised :class:`.Function` from a :class:`.FunctionSpace`.
6+
and enables users to generate a randomised :class:`.Function` from a :class:`.WithGeometry` or :class:`.FiredrakeDualSpace`.
77
This module inherits almost all attributes from `numpy.random <https://numpy.org/doc/stable/reference/random/index.html>`__ with the following changes:
88
99
Generator
1010
---------
1111
1212
A :class:`.Generator` wraps `numpy.random.Generator <https://numpy.org/doc/stable/reference/random/generator.html>`__.
1313
:class:`.Generator` inherits almost all distribution methods from `numpy.random.Generator <https://numpy.org/doc/stable/reference/random/generator.html>`__,
14-
and they can be used to generate a randomised :class:`.Function` by passing a :class:`.FunctionSpace` as the first argument.
14+
and they can be used to generate a randomised :class:`.Function` by passing a :class:`.WithGeometry` or :class:`.FiredrakeDualSpace` as the first argument.
1515
1616
Example:
1717
@@ -106,7 +106,7 @@
106106

107107
from firedrake.function import Function
108108
from pyop2.mpi import COMM_WORLD
109-
from ufl import FunctionSpace
109+
from ufl.functionspace import BaseFunctionSpace
110110

111111
_deprecated_attributes = ['RandomGenerator', ]
112112

@@ -280,7 +280,7 @@ def funcgen(c_a):
280280

281281
@add_doc_string(getattr(_Base, c_a).__doc__)
282282
def func(self, *args, **kwargs):
283-
if len(args) > 0 and isinstance(args[0], FunctionSpace):
283+
if len(args) > 0 and isinstance(args[0], BaseFunctionSpace):
284284
raise NotImplementedError("%s.%s does not take FunctionSpace as argument" % (module_attr, c_a))
285285
else:
286286
return getattr(super(_Wrapper, self), c_a)(*args, **kwargs)
@@ -294,7 +294,7 @@ def func(self, *args, **kwargs):
294294
def funcgen(c_a):
295295
@add_doc_string(getattr(_Base, c_a).__doc__)
296296
def func(self, *args, **kwargs):
297-
if len(args) > 0 and isinstance(args[0], FunctionSpace):
297+
if len(args) > 0 and isinstance(args[0], BaseFunctionSpace):
298298
# Extract size from V
299299
if 'size' in kwargs.keys():
300300
raise TypeError("Cannot specify 'size' when generating a random function from 'V'")

tests/firedrake/adjoint/test_split_and_subfunctions.py

-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,6 @@ def test_writing_to_subfunctions():
209209
u.assign(kappa)
210210
usub *= 2
211211
J = assemble(inner(u, u) * dx)
212-
print(f"{type(J) = }")
213212
rf = ReducedFunctional(J, Control(kappa), tape=tape)
214213
pause_annotation()
215214

tests/firedrake/randomfunctiongen/test_randomfunction.py

+36
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,39 @@ def test_randomfunc_generator_spawn():
129129
assert all([child.bit_generator._seed_seq.entropy == parent.bit_generator._seed_seq.entropy for child in children])
130130
assert all([child.bit_generator._seed_seq.spawn_key == parent.bit_generator._seed_seq.spawn_key + (i, ) for i, child in enumerate(children)])
131131
assert all([child.bit_generator._seed_seq.pool_size == parent.bit_generator._seed_seq.pool_size for child in children])
132+
133+
134+
@pytest.mark.parametrize("brng", ['PCG64'])
135+
@pytest.mark.parametrize("meth_args", [('beta', (0.3, 0.5))])
136+
def test_random_cofunction(brng, meth_args):
137+
meth, args = meth_args
138+
139+
mesh = UnitSquareMesh(10, 10)
140+
V0 = VectorFunctionSpace(mesh, "CG", 1)
141+
V1 = FunctionSpace(mesh, "CG", 1)
142+
V = V0 * V1
143+
V = V.dual()
144+
145+
seed = 123456789
146+
# Original
147+
bgen = getattr(randomgen, brng)(seed=seed)
148+
if brng == 'PCG64':
149+
state = bgen.state
150+
state['state'] = {'state': seed, 'inc': V.comm.Get_rank()}
151+
bgen.state = state
152+
rg_base = randomgen.Generator(bgen)
153+
# Firedrake wrapper
154+
fgen = getattr(randomfunctiongen, brng)(seed=seed)
155+
if brng == 'PCG64':
156+
state = fgen.state
157+
state['state'] = {'state': seed, 'inc': V.comm.Get_rank()}
158+
fgen.state = state
159+
rg_wrap = randomfunctiongen.Generator(fgen)
160+
for i in range(1, 10):
161+
f = getattr(rg_wrap, meth)(V, *args)
162+
with f.dat.vec_ro as v:
163+
if meth in ('rand', 'randn'):
164+
assert np.allclose(getattr(rg_base, meth)(v.local_size), v.array[:])
165+
else:
166+
kwargs = {'size': (v.local_size, )}
167+
assert np.allclose(getattr(rg_base, meth)(*args, **kwargs), v.array[:])

0 commit comments

Comments
 (0)