Skip to content

Commit 7ed33f4

Browse files
committed
add recursive clustering and skeletonization
1 parent a3c0351 commit 7ed33f4

13 files changed

+622
-176
lines changed

doc/conf.py

+4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
"DOFDescriptorLike": "pytential.symbolic.dof_desc.DOFDescriptorLike",
2323
}
2424

25+
nitpick_ignore_regex = [
26+
["py:class", r"_ProxyNeighborEvaluationResult"],
27+
]
28+
2529
intersphinx_mapping = {
2630
"https://docs.python.org/3/": None,
2731
"https://numpy.org/doc/stable/": None,

doc/linalg.rst

+18-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ scheme is used:
88
component of the Stokeslet.
99
* ``cluster`` refers to a piece of a ``block`` as used by the recursive
1010
proxy-based skeletonization of the direct solver algorithms. Clusters
11-
are represented by a :class:`~pytential.linalg.TargetAndSourceClusterList`.
11+
are represented by a :class:`~pytential.linalg.utils.TargetAndSourceClusterList`.
1212

1313
GMRES
1414
-----
@@ -20,17 +20,31 @@ GMRES
2020
Hierarchical Direct Solver
2121
--------------------------
2222

23+
.. note::
24+
25+
High-level API for direct solvers is in progress.
26+
27+
Low-level Functionality
28+
-----------------------
29+
2330
.. warning::
2431

2532
All the classes and routines in this module are experimental and the
2633
API can change at any point.
2734

35+
.. automodule:: pytential.linalg.skeletonization
36+
.. automodule:: pytential.linalg.cluster
2837
.. automodule:: pytential.linalg.proxy
29-
.. automodule:: pytential.linalg.utils
3038

31-
Internal Functionality
32-
----------------------
39+
Internal Functionality and Utilities
40+
------------------------------------
41+
42+
.. warning::
3343

44+
All the classes and routines in this module are experimental and the
45+
API can change at any point.
46+
47+
.. automodule:: pytential.linalg.utils
3448
.. automodule:: pytential.linalg.direct_solver_symbolic
3549

3650
.. vim: sw=4:tw=75:fdm=marker

pytential/linalg/__init__.py

-16
Original file line numberDiff line numberDiff line change
@@ -25,25 +25,9 @@
2525
make_index_list, make_index_cluster_cartesian_product,
2626
interp_decomp,
2727
)
28-
from pytential.linalg.proxy import (
29-
ProxyClusterGeometryData, ProxyPointTarget, ProxyPointSource,
30-
ProxyGeneratorBase, ProxyGenerator, QBXProxyGenerator,
31-
partition_by_nodes, gather_cluster_neighbor_points,
32-
)
33-
from pytential.linalg.skeletonization import (
34-
SkeletonizationWrangler, make_skeletonization_wrangler,
35-
SkeletonizationResult, skeletonize_by_proxy,
36-
)
3728

3829
__all__ = (
3930
"IndexList", "TargetAndSourceClusterList",
4031
"make_index_list", "make_index_cluster_cartesian_product",
4132
"interp_decomp",
42-
43-
"ProxyClusterGeometryData", "ProxyPointTarget", "ProxyPointSource",
44-
"ProxyGeneratorBase", "ProxyGenerator", "QBXProxyGenerator",
45-
"partition_by_nodes", "gather_cluster_neighbor_points",
46-
47-
"SkeletonizationWrangler", "make_skeletonization_wrangler",
48-
"SkeletonizationResult", "skeletonize_by_proxy",
4933
)

pytential/linalg/cluster.py

+277
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
__copyright__ = "Copyright (C) 2022 Alexandru Fikl"
2+
3+
__license__ = """
4+
Permission is hereby granted, free of charge, to any person obtaining a copy
5+
of this software and associated documentation files (the "Software"), to deal
6+
in the Software without restriction, including without limitation the rights
7+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
copies of the Software, and to permit persons to whom the Software is
9+
furnished to do so, subject to the following conditions:
10+
11+
The above copyright notice and this permission notice shall be included in
12+
all copies or substantial portions of the Software.
13+
14+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20+
THE SOFTWARE.
21+
"""
22+
23+
from dataclasses import dataclass, replace
24+
from functools import singledispatch
25+
from typing import Optional
26+
27+
import numpy as np
28+
29+
from pytools import T, memoize_method
30+
31+
from arraycontext import PyOpenCLArrayContext
32+
from boxtree.tree import Tree
33+
from pytential import sym, GeometryCollection
34+
from pytential.linalg.utils import IndexList, TargetAndSourceClusterList
35+
36+
__doc__ = """
37+
Clustering
38+
~~~~~~~~~~
39+
40+
.. autoclass:: ClusterTreeLevel
41+
42+
.. autofunction:: cluster
43+
.. autofunction:: partition_by_nodes
44+
"""
45+
46+
# FIXME: this is just an arbitrary value
47+
_DEFAULT_MAX_PARTICLES_IN_BOX = 32
48+
49+
50+
# {{{ cluster tree
51+
52+
@dataclass(frozen=True)
53+
class ClusterTreeLevel:
54+
"""
55+
.. attribute:: level
56+
57+
Current level that is represented.
58+
59+
.. attribute:: nlevels
60+
61+
Total number of levels in the tree.
62+
63+
.. attribute:: nclusters
64+
65+
Number of clusters on the current level (same as number of boxes
66+
in :attr:`partition_box_ids`).
67+
68+
.. attribute:: partition_box_ids
69+
70+
Box IDs on the current level.
71+
72+
.. attribute:: partition_parent_ids
73+
74+
Parent box IDs for :attr:`partition_box_ids`.
75+
76+
.. attribute:: partition_parent_map
77+
78+
An object :class:`~numpy.ndarray`, where each entry maps a parent
79+
to its children indices in :attr:`partition_box_ids`. This can be used to
80+
:func:`cluster` all indices in ``partition_parent_map[i]`` into their
81+
parent.
82+
83+
.. automethod:: parent
84+
"""
85+
86+
# level info
87+
level: int
88+
partition_box_ids: np.ndarray
89+
90+
# tree info
91+
nlevels: int
92+
box_parent_ids: np.ndarray
93+
94+
# NOTE: only here to allow easier debugging + testing
95+
_tree: Optional[Tree]
96+
97+
@property
98+
def nclusters(self):
99+
return self.partition_box_ids.size
100+
101+
@property
102+
def partition_parent_ids(self):
103+
return self.box_parent_ids[self.partition_box_ids]
104+
105+
@property
106+
@memoize_method
107+
def partition_parent_map(self):
108+
# NOTE: np.unique returns a sorted array
109+
unique_parent_ids = np.unique(self.partition_parent_ids)
110+
# find the index of each parent id
111+
unique_parent_index = np.searchsorted(
112+
unique_parent_ids, self.partition_parent_ids
113+
)
114+
115+
unique_parent_map = np.empty(unique_parent_ids.size, dtype=object)
116+
for i in range(unique_parent_ids.size):
117+
unique_parent_map[i], = np.nonzero(unique_parent_index == i)
118+
119+
return unique_parent_map
120+
121+
def parent(self) -> "ClusterTreeLevel":
122+
"""
123+
:returns: a new :class:`ClusterTreeLevel` that represents the parent of
124+
the current one, with appropriately updated :attr:`partition_box_ids`,
125+
etc.
126+
"""
127+
128+
if self.nclusters == 1:
129+
assert self.level == 0
130+
return self
131+
132+
return replace(self,
133+
level=self.level - 1,
134+
partition_box_ids=np.unique(self.partition_parent_ids))
135+
136+
137+
@singledispatch
138+
def cluster(obj: T, ctree: ClusterTreeLevel) -> T:
139+
"""Merge together elements of *obj* into their parent object, as described
140+
by *ctree*.
141+
"""
142+
raise NotImplementedError(type(obj).__name__)
143+
144+
145+
@cluster.register(IndexList)
146+
def _cluster_index_list(obj: IndexList, ctree: ClusterTreeLevel) -> IndexList:
147+
assert obj.nclusters == ctree.nclusters
148+
149+
if ctree.nclusters == 1:
150+
return obj
151+
152+
sizes = [
153+
sum(obj.cluster_size(j) for j in ppm)
154+
for ppm in ctree.partition_parent_map
155+
]
156+
return replace(obj, starts=np.cumsum([0] + sizes))
157+
158+
159+
@cluster.register(TargetAndSourceClusterList)
160+
def _cluster_target_and_source_cluster_list(
161+
obj: TargetAndSourceClusterList, ctree: ClusterTreeLevel,
162+
) -> TargetAndSourceClusterList:
163+
assert obj.nclusters == ctree.nclusters
164+
165+
if ctree.nclusters == 1:
166+
return obj
167+
168+
return replace(obj,
169+
targets=cluster(obj.targets, ctree),
170+
sources=cluster(obj.sources, ctree))
171+
172+
# }}}
173+
174+
175+
# {{{ cluster generation
176+
177+
def _build_binary_ish_tree_from_indices(starts: np.ndarray) -> ClusterTreeLevel:
178+
partition_box_ids = np.arange(starts.size - 1)
179+
180+
box_ids = partition_box_ids
181+
182+
box_parent_ids = []
183+
offset = box_ids.size
184+
while box_ids.size > 1:
185+
# NOTE: this is probably not the most efficient way to do it, but this
186+
# code is mostly meant for debugging using a simple tree
187+
clusters = np.array_split(box_ids, box_ids.size // 2)
188+
parent_ids = offset + np.arange(len(clusters))
189+
box_parent_ids.append(np.repeat(parent_ids, [len(c) for c in clusters]))
190+
191+
box_ids = parent_ids
192+
offset += box_ids.size
193+
194+
# NOTE: make the root point to itself
195+
box_parent_ids.append(np.array([offset - 1]))
196+
nlevels = len(box_parent_ids)
197+
198+
return ClusterTreeLevel(
199+
level=nlevels - 1,
200+
partition_box_ids=partition_box_ids,
201+
nlevels=nlevels,
202+
box_parent_ids=np.concatenate(box_parent_ids),
203+
_tree=None)
204+
205+
206+
def partition_by_nodes(
207+
actx: PyOpenCLArrayContext, places: GeometryCollection, *,
208+
dofdesc: Optional[sym.DOFDescriptorLike] = None,
209+
tree_kind: Optional[str] = "adaptive-level-restricted",
210+
max_particles_in_box: Optional[int] = None) -> IndexList:
211+
"""Generate equally sized ranges of nodes. The partition is created at the
212+
lowest level of granularity, i.e. nodes. This results in balanced ranges
213+
of points, but will split elements across different ranges.
214+
215+
:arg dofdesc: a :class:`~pytential.symbolic.dof_desc.DOFDescriptor` for
216+
the geometry in *places* which should be partitioned.
217+
:arg tree_kind: if not *None*, it is passed to :class:`boxtree.TreeBuilder`.
218+
:arg max_particles_in_box: value used to control the number of points
219+
in each partition (and thus the number of partitions). See the documentation
220+
in :class:`boxtree.TreeBuilder`.
221+
"""
222+
if dofdesc is None:
223+
dofdesc = places.auto_source
224+
dofdesc = sym.as_dofdesc(dofdesc)
225+
226+
if max_particles_in_box is None:
227+
max_particles_in_box = _DEFAULT_MAX_PARTICLES_IN_BOX
228+
229+
lpot_source = places.get_geometry(dofdesc.geometry)
230+
discr = places.get_discretization(dofdesc.geometry, dofdesc.discr_stage)
231+
232+
if tree_kind is not None:
233+
from pytential.qbx.utils import tree_code_container
234+
tcc = tree_code_container(lpot_source._setup_actx)
235+
236+
from arraycontext import flatten
237+
from meshmode.dof_array import DOFArray
238+
tree, _ = tcc.build_tree()(actx.queue,
239+
particles=flatten(
240+
actx.thaw(discr.nodes()), actx, leaf_class=DOFArray
241+
),
242+
max_particles_in_box=max_particles_in_box,
243+
kind=tree_kind)
244+
245+
from boxtree import box_flags_enum
246+
tree = tree.get(actx.queue)
247+
leaf_boxes, = (tree.box_flags & box_flags_enum.HAS_CHILDREN == 0).nonzero()
248+
249+
indices = np.empty(len(leaf_boxes), dtype=object)
250+
starts = None
251+
252+
for i, ibox in enumerate(leaf_boxes):
253+
box_start = tree.box_source_starts[ibox]
254+
box_end = box_start + tree.box_source_counts_cumul[ibox]
255+
indices[i] = tree.user_source_ids[box_start:box_end]
256+
257+
ctree = ClusterTreeLevel(
258+
level=tree.nlevels - 1,
259+
nlevels=tree.nlevels,
260+
box_parent_ids=tree.box_parent_ids,
261+
partition_box_ids=leaf_boxes,
262+
_tree=tree)
263+
else:
264+
if discr.ambient_dim != 2 and discr.dim == 1:
265+
raise ValueError("only curves are supported for 'tree_kind=None'")
266+
267+
nclusters = max(discr.ndofs // max_particles_in_box, 2)
268+
indices = np.arange(0, discr.ndofs, dtype=np.int64)
269+
starts = np.linspace(0, discr.ndofs, nclusters + 1, dtype=np.int64)
270+
assert starts[-1] == discr.ndofs
271+
272+
ctree = _build_binary_ish_tree_from_indices(starts)
273+
274+
from pytential.linalg import make_index_list
275+
return make_index_list(indices, starts=starts), ctree
276+
277+
# }}}

0 commit comments

Comments
 (0)