Skip to content

Commit 2dd9746

Browse files
alexfiklinducer
authored andcommitted
make code containers functions instead of methods
1 parent 4e576a0 commit 2dd9746

File tree

8 files changed

+96
-98
lines changed

8 files changed

+96
-98
lines changed

pytential/collection.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -274,12 +274,12 @@ def _get_qbx_discretization(self, geometry, discr_stage):
274274
except KeyError:
275275
dofdesc = sym.DOFDescriptor(geometry, discr_stage)
276276

277+
from pytential.qbx.refinement import refiner_code_container
278+
wrangler = refiner_code_container(lpot_source._setup_actx).get_wrangler()
279+
277280
from pytential.qbx.refinement import _refine_for_global_qbx
278281
# NOTE: this adds the required discretizations to the cache
279-
_refine_for_global_qbx(self, dofdesc,
280-
lpot_source.refiner_code_container.get_wrangler(),
281-
_copy_collection=False)
282-
282+
_refine_for_global_qbx(self, dofdesc, wrangler, _copy_collection=False)
283283
discr = self._get_discr_from_cache(geometry, discr_stage)
284284

285285
return discr

pytential/linalg/proxy.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,11 @@ def partition_by_nodes(
8686
discr = places.get_discretization(dofdesc.geometry, dofdesc.discr_stage)
8787

8888
if tree_kind is not None:
89+
from pytential.qbx.utils import tree_code_container
90+
tcc = tree_code_container(lpot_source._setup_actx)
91+
8992
from arraycontext import thaw
90-
builder = lpot_source.tree_code_container.build_tree()
91-
tree, _ = builder(actx.queue,
93+
tree, _ = tcc.build_tree()(actx.queue,
9294
particles=flatten(
9395
thaw(discr.nodes(), actx), actx, leaf_class=DOFArray
9496
),
@@ -594,12 +596,12 @@ def prg():
594596

595597
# {{{ perform area query
596598

597-
builder = lpot_source.tree_code_container.build_tree()
598-
tree, _ = builder(actx.queue, sources,
599-
max_particles_in_box=max_particles_in_box)
599+
from pytential.qbx.utils import tree_code_container
600+
tcc = tree_code_container(lpot_source._setup_actx)
600601

601-
builder = lpot_source.tree_code_container.build_area_query()
602-
query, _ = builder(actx.queue, tree, pxy.centers, pxy.radii)
602+
tree, _ = tcc.build_tree()(actx.queue, sources,
603+
max_particles_in_box=max_particles_in_box)
604+
query, _ = tcc.build_area_query()(actx.queue, tree, pxy.centers, pxy.radii)
603605

604606
tree = tree.get(actx.queue)
605607
query = query.get(actx.queue)

pytential/qbx/__init__.py

+9-58
Original file line numberDiff line numberDiff line change
@@ -341,60 +341,6 @@ def copy(
341341

342342
# }}}
343343

344-
# {{{ code containers
345-
346-
@property
347-
def tree_code_container(self):
348-
@memoize_in(self._setup_actx, (
349-
QBXLayerPotentialSource, "tree_code_container"))
350-
def make_container():
351-
from pytential.qbx.utils import TreeCodeContainer
352-
return TreeCodeContainer(self._setup_actx)
353-
354-
return make_container()
355-
356-
@property
357-
def refiner_code_container(self):
358-
@memoize_in(self._setup_actx, (
359-
QBXLayerPotentialSource, "refiner_code_container"))
360-
def make_container():
361-
from pytential.qbx.refinement import RefinerCodeContainer
362-
return RefinerCodeContainer(
363-
self._setup_actx, self.tree_code_container)
364-
365-
return make_container()
366-
367-
@property
368-
def target_association_code_container(self):
369-
@memoize_in(self._setup_actx, (
370-
QBXLayerPotentialSource, "target_association_code_container"))
371-
def make_container():
372-
from pytential.qbx.target_assoc import TargetAssociationCodeContainer
373-
return TargetAssociationCodeContainer(
374-
self._setup_actx, self.tree_code_container)
375-
376-
return make_container()
377-
378-
@property
379-
def qbx_fmm_geometry_data_code_container(self):
380-
@memoize_in(self._setup_actx, (
381-
QBXLayerPotentialSource, "qbx_fmm_geometry_data_code_container"))
382-
def make_container(
383-
debug, ambient_dim, well_sep_is_n_away,
384-
from_sep_smaller_crit):
385-
from pytential.qbx.geometry import QBXFMMGeometryDataCodeContainer
386-
return QBXFMMGeometryDataCodeContainer(
387-
self._setup_actx,
388-
ambient_dim, self.tree_code_container, debug,
389-
_well_sep_is_n_away=well_sep_is_n_away,
390-
_from_sep_smaller_crit=from_sep_smaller_crit)
391-
392-
return make_container(
393-
self.debug, self.ambient_dim,
394-
self._well_sep_is_n_away, self._from_sep_smaller_crit)
395-
396-
# }}}
397-
398344
# {{{ internal API
399345

400346
@memoize_method
@@ -409,11 +355,16 @@ def qbx_fmm_geometry_data(self, places, name,
409355
:class:`pytential.target.TargetBase`
410356
instance
411357
"""
412-
from pytential.qbx.geometry import QBXFMMGeometryData
358+
from pytential.qbx.geometry import qbx_fmm_geometry_data_code_container
359+
code_container = qbx_fmm_geometry_data_code_container(
360+
self._setup_actx, self.ambient_dim,
361+
debug=self.debug,
362+
well_sep_is_n_away=self._well_sep_is_n_away,
363+
from_sep_smaller_crit=self._from_sep_smaller_crit)
413364

414-
return QBXFMMGeometryData(places, name,
415-
self.qbx_fmm_geometry_data_code_container,
416-
target_discrs_and_qbx_sides,
365+
from pytential.qbx.geometry import QBXFMMGeometryData
366+
return QBXFMMGeometryData(
367+
places, name, code_container, target_discrs_and_qbx_sides,
417368
target_association_tolerance=self.target_association_tolerance,
418369
tree_kind=self._tree_kind,
419370
debug=self.debug)

pytential/qbx/geometry.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import numpy as np
2525

26-
from pytools import memoize_method, log_process
26+
from pytools import memoize_method, memoize_in, log_process
2727
from arraycontext import PyOpenCLArrayContext, flatten, freeze
2828
from meshmode.dof_array import DOFArray
2929

@@ -107,16 +107,17 @@ class target_state(Enum): # noqa
107107

108108

109109
class QBXFMMGeometryDataCodeContainer(TreeCodeContainerMixin):
110-
def __init__(self, actx: PyOpenCLArrayContext, ambient_dim,
111-
tree_code_container, debug,
112-
_well_sep_is_n_away, _from_sep_smaller_crit):
110+
def __init__(self,
111+
actx: PyOpenCLArrayContext, ambient_dim: int, debug: bool,
112+
_well_sep_is_n_away: int, _from_sep_smaller_crit: str) -> None:
113+
self._setup_actx = actx
113114
self.ambient_dim = ambient_dim
114-
self.tree_code_container = tree_code_container
115+
self.debug = debug
115116
self._well_sep_is_n_away = _well_sep_is_n_away
116117
self._from_sep_smaller_crit = _from_sep_smaller_crit
117118

118-
self._setup_actx = actx.clone()
119-
self.debug = debug
119+
from pytential.qbx.utils import tree_code_container
120+
self.tree_code_container = tree_code_container(actx)
120121

121122
@memoize_method
122123
def copy_targets_kernel(self):
@@ -260,6 +261,26 @@ def rotation_classes_builder(self):
260261
from boxtree.rotation_classes import RotationClassesBuilder
261262
return RotationClassesBuilder(self._setup_actx.context)
262263

264+
265+
def qbx_fmm_geometry_data_code_container(
266+
actx: PyOpenCLArrayContext, ambient_dim: int, *,
267+
debug: bool,
268+
well_sep_is_n_away: int,
269+
from_sep_smaller_crit: str) -> QBXFMMGeometryDataCodeContainer:
270+
@memoize_in(actx, (
271+
QBXFMMGeometryDataCodeContainer, qbx_fmm_geometry_data_code_container))
272+
def make_container(
273+
_ambient_dim, _debug,
274+
_well_sep_is_n_away, _from_sep_smaller_crit):
275+
return QBXFMMGeometryDataCodeContainer(
276+
actx, _ambient_dim, _debug,
277+
_well_sep_is_n_away=_well_sep_is_n_away,
278+
_from_sep_smaller_crit=_from_sep_smaller_crit)
279+
280+
return make_container(
281+
ambient_dim, debug,
282+
well_sep_is_n_away, from_sep_smaller_crit)
283+
263284
# }}}
264285

265286

@@ -759,9 +780,9 @@ def user_target_to_center(self):
759780
PointsTarget(target_info.targets[:, self.ncenters:]),
760781
target_side_prefs.astype(np.int32))]
761782

783+
from pytential.qbx.target_assoc import target_association_code_container
762784
target_association_wrangler = (
763-
self.lpot_source.target_association_code_container
764-
.get_wrangler(actx))
785+
target_association_code_container(actx).get_wrangler(actx))
765786

766787
tgt_assoc_result = associate_targets_to_qbx_centers(
767788
self.places,

pytential/qbx/refinement.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from arraycontext import PyOpenCLArrayContext, flatten
3232
from meshmode.dof_array import DOFArray
3333

34-
from pytools import memoize_method
34+
from pytools import memoize_method, memoize_in
3535
from boxtree.area_query import AreaQueryElementwiseTemplate
3636
from boxtree.tools import InlineBinarySearch
3737
from pytential.qbx.utils import (
@@ -219,9 +219,11 @@
219219

220220
class RefinerCodeContainer(TreeCodeContainerMixin):
221221

222-
def __init__(self, actx: PyOpenCLArrayContext, tree_code_container):
222+
def __init__(self, actx: PyOpenCLArrayContext):
223223
self.array_context = actx
224-
self.tree_code_container = tree_code_container
224+
225+
from pytential.qbx.utils import tree_code_container
226+
self.tree_code_container = tree_code_container(actx)
225227

226228
@memoize_method
227229
def expansion_disk_undisturbed_by_sources_checker(
@@ -271,6 +273,14 @@ def element_prop_threshold_checker(self):
271273
def get_wrangler(self):
272274
return RefinerWrangler(self.array_context, self)
273275

276+
277+
def refiner_code_container(actx: PyOpenCLArrayContext) -> RefinerCodeContainer:
278+
@memoize_in(actx, (RefinerCodeContainer, refiner_code_container))
279+
def make_container():
280+
return RefinerCodeContainer(actx)
281+
282+
return make_container()
283+
274284
# }}}
275285

276286

@@ -964,8 +974,8 @@ def refine_geometry_collection(places,
964974
if not isinstance(lpot_source, QBXLayerPotentialSource):
965975
continue
966976

967-
_refine_for_global_qbx(places, dofdesc,
968-
lpot_source.refiner_code_container.get_wrangler(),
977+
wrangler = refiner_code_container(lpot_source._setup_actx).get_wrangler()
978+
_refine_for_global_qbx(places, dofdesc, wrangler,
969979
group_factory=group_factory,
970980
kernel_length_scale=kernel_length_scale,
971981
scaled_max_curvature_threshold=scaled_max_curvature_threshold,

pytential/qbx/target_assoc.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import numpy as np
2626

27-
from pytools import memoize_method
27+
from pytools import memoize_method, memoize_in
2828
from boxtree.tools import DeviceDataRecord
2929
from boxtree.area_query import AreaQueryElementwiseTemplate
3030
from boxtree.tools import InlineBinarySearch
@@ -440,9 +440,11 @@ class QBXTargetAssociation(DeviceDataRecord):
440440

441441
class TargetAssociationCodeContainer(TreeCodeContainerMixin):
442442

443-
def __init__(self, actx: PyOpenCLArrayContext, tree_code_container):
443+
def __init__(self, actx: PyOpenCLArrayContext):
444444
self.array_context = actx
445-
self.tree_code_container = tree_code_container
445+
446+
from pytential.qbx.utils import tree_code_container
447+
self.tree_code_container = tree_code_container(actx)
446448

447449
@property
448450
def cl_context(self):
@@ -493,6 +495,16 @@ def get_wrangler(self, actx: PyOpenCLArrayContext):
493495
return TargetAssociationWrangler(actx, code_container=self)
494496

495497

498+
def target_association_code_container(
499+
actx: PyOpenCLArrayContext) -> TargetAssociationCodeContainer:
500+
@memoize_in(actx, (
501+
TargetAssociationCodeContainer, target_association_code_container))
502+
def make_container():
503+
return TargetAssociationCodeContainer(actx)
504+
505+
return make_container()
506+
507+
496508
class TargetAssociationWrangler(TreeWranglerBase):
497509

498510
@log_process(logger)

pytential/qbx/utils.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import numpy as np
2626

27-
from pytools import memoize_method, log_process
27+
from pytools import memoize_method, memoize_in, log_process
2828
from arraycontext import PyOpenCLArrayContext
2929
from meshmode.dof_array import DOFArray
3030

@@ -91,6 +91,14 @@ def build_area_query(self):
9191
from boxtree.area_query import AreaQueryBuilder
9292
return AreaQueryBuilder(self.array_context.context)
9393

94+
95+
def tree_code_container(actx: PyOpenCLArrayContext) -> TreeCodeContainer:
96+
@memoize_in(actx, (TreeCodeContainer, tree_code_container))
97+
def make_container():
98+
return TreeCodeContainer(actx)
99+
100+
return make_container()
101+
94102
# }}}
95103

96104

@@ -110,6 +118,7 @@ def peer_list_finder(self):
110118
def particle_list_filter(self):
111119
return self.tree_code_container.particle_list_filter()
112120

121+
113122
# }}}
114123

115124

test/test_global_qbx.py

+4-11
Original file line numberDiff line numberDiff line change
@@ -401,11 +401,8 @@ def targets_from_sources(sign, dist, dim=2):
401401
# {{{ run target associator and check
402402

403403
from pytential.qbx.target_assoc import (
404-
TargetAssociationCodeContainer, associate_targets_to_qbx_centers)
405-
406-
from pytential.qbx.utils import TreeCodeContainer
407-
code_container = TargetAssociationCodeContainer(
408-
actx, TreeCodeContainer(actx))
404+
target_association_code_container, associate_targets_to_qbx_centers)
405+
code_container = target_association_code_container(actx)
409406

410407
target_assoc = (
411408
associate_targets_to_qbx_centers(
@@ -543,13 +540,9 @@ def test_target_association_failure(actx_factory):
543540
)
544541

545542
from pytential.qbx.target_assoc import (
546-
TargetAssociationCodeContainer, associate_targets_to_qbx_centers,
543+
target_association_code_container, associate_targets_to_qbx_centers,
547544
QBXTargetAssociationFailedException)
548-
549-
from pytential.qbx.utils import TreeCodeContainer
550-
551-
code_container = TargetAssociationCodeContainer(
552-
actx, TreeCodeContainer(actx))
545+
code_container = target_association_code_container(actx)
553546

554547
with pytest.raises(QBXTargetAssociationFailedException):
555548
associate_targets_to_qbx_centers(

0 commit comments

Comments
 (0)