diff --git a/asQ/ensemble.py b/asQ/ensemble.py index cd774a63..fd50b06c 100644 --- a/asQ/ensemble.py +++ b/asQ/ensemble.py @@ -1,7 +1,8 @@ +import weakref from firedrake import COMM_WORLD, Ensemble -from pyop2.mpi import internal_comm +from pyop2.mpi import MPI, internal_comm, is_pyop2_comm, PyOP2CommError -__all__ = ['create_ensemble', 'split_ensemble', 'EnsembleConnector'] +__all__ = ['create_ensemble', 'split_ensemble'] def create_ensemble(time_partition, comm=COMM_WORLD): @@ -42,35 +43,82 @@ def split_ensemble(ensemble, split_size, **kwargs): split_rank = ensemble.ensemble_comm.rank // split_size # create split_ensemble.global_comm - split_comm = ensemble.global_comm.Split(color=split_rank, - key=ensemble.global_comm.rank) + split_global_comm = ensemble.global_comm.Split(color=split_rank, + key=ensemble.global_comm.rank) - return EnsembleConnector(split_comm, ensemble.comm, split_size, **kwargs) + # create split_ensemble.ensemble_comm + split_ensemble_comm = ensemble.ensemble_comm.Split(color=split_rank, + key=ensemble.ensemble_comm.rank) + new_ensemble = ManualEnsemble(split_global_comm, ensemble.comm, split_ensemble_comm, **kwargs) -class EnsembleConnector(Ensemble): - def __init__(self, global_comm, local_comm, nmembers, **kwargs): + # make sure the new comms are cleaned up when the split ensemble goes out of scope + weakref.finalize(new_ensemble, split_global_comm.Free) + weakref.finalize(new_ensemble, split_ensemble_comm.Free) + + return new_ensemble + + +class ManualEnsemble(Ensemble): + def __init__(self, global_comm, spatial_comm, ensemble_comm, **kwargs): """ - An Ensemble created from provided spatial communicators (ensemble.comm). + An Ensemble created from provided comms. :arg global_comm: global communicator the Ensemble is defined over. - :arg local_comm: communicator to use for the Ensemble.comm member. - :arg nmembers: number of Ensemble members (ensemble.ensemble_comm.size). + :arg spatial_comm: communicator to use for the Ensemble.comm member. 
+ :arg ensemble_comm: communicator to use for the Ensemble.ensemble_comm member. + + The global_comm, spatial_comm, and ensemble_comm must have the same logical meaning + as they do in firedrake.Ensemble, i.e. the global_comm is the union of a cartesian + product of multiple spatial_comms and ensemble_comms. + - ManualEnsemble is logically defined over all ranks in global_comm. + - Each rank in global_comm belongs to only one spatial_comm and one ensemble_comm. + - The size of the intersection of any (spatial_comm, ensemble_comm) pair is 1. + + WARNING: Not meeting these requirements may result in errors, hangs, and nonsensical results. + + ManualEnsemble will not Free any of the comms. This is the responsibility of the user. """ - if nmembers*local_comm.size != global_comm.size: - msg = "The global ensemble must have the same number of ranks as the sum of the local comms" - raise ValueError(msg) + # are we handed user comms? + + for comm in (global_comm, spatial_comm, ensemble_comm): + if is_pyop2_comm(comm): + raise PyOP2CommError("Cannot construct Ensemble from PyOP2 internal comm") + + # check cartesian product consistency + + if spatial_comm.size*ensemble_comm.size != global_comm.size: + msg = "The global comm must have the same number of ranks as the product of spatial and ensemble comms" + raise PyOP2CommError(msg) + + global_group = global_comm.Get_group() + spatial_group = spatial_comm.Get_group() + ensemble_group = ensemble_comm.Get_group() + + if MPI.Group.Intersection(spatial_group, ensemble_group).size != 1: + raise PyOP2CommError("spatial and ensemble comms must be cartesian product in global_comm") + + is_subgroup = lambda sub, group: MPI.Group.Compare(sub, MPI.Group.Intersection(sub, group)) in {MPI.IDENT, MPI.CONGRUENT} + + if not is_subgroup(spatial_group, global_group): + raise PyOP2CommError("spatial_comm must be subgroup of global_comm") + if not is_subgroup(ensemble_group, global_group): + raise PyOP2CommError("ensemble_comm must be subgroup 
of global_comm") + + # create internal duplicates and name comms for debugging + ensemble_name = kwargs.get("name", "Ensemble") - ensemble_name = kwargs.get("ensemble_name", "Ensemble") self.global_comm = global_comm + if not hasattr(self.global_comm, "name"): + self.global_comm.name = f"{ensemble_name} global comm" self._comm = internal_comm(self.global_comm, self) - self.comm = local_comm - self.comm.name = f"{ensemble_name} spatial comm" + self.comm = spatial_comm + if not hasattr(self.comm, "name"): + self.comm.name = f"{ensemble_name} spatial comm" self._spatial_comm = internal_comm(self.comm, self) - self.ensemble_comm = self.global_comm.Split(color=self.comm.rank, - key=global_comm.rank) - self.ensemble_comm.name = f"{ensemble_name} ensemble comm" - + self.ensemble_comm = ensemble_comm + if not hasattr(self.ensemble_comm, "name"): + self.ensemble_comm.name = f"{ensemble_name} ensemble comm" self._ensemble_comm = internal_comm(self.ensemble_comm, self)