1616from mpisppy import MPI
1717from mpisppy.cylinders.spcommunicator import RecvArray, SendArray, SPCommunicator
1818from math import inf
19- from mpisppy.cylinders.spoke import ConvergerSpokeType
2019
2120from mpisppy import global_toc
2221
@@ -51,6 +50,8 @@ def __init__(self, spbase_object, fullcomm, strata_comm, cylinder_comm, communic
5150
5251 self.extension_recv = set()
5352
53+ self.initialize_bound_values()
54+
5455 return
5556
5657 @abc.abstractmethod
@@ -233,14 +234,12 @@ def receive_innerbounds(self):
233234 (but should be harmless to call if there are none)
234235 """
235236 logging.debug("Hub is trying to receive from InnerBounds")
236- for idx in self.innerbound_spoke_indices:
237- key = self._make_key(Field.OBJECTIVE_INNER_BOUND, idx)
238- recv_buf = self.receive_buffers[key]
237+ for idx, cls, recv_buf in self.receive_field_spcomms[Field.OBJECTIVE_INNER_BOUND]:
239238 is_new = self.hub_from_spoke(recv_buf, idx, Field.OBJECTIVE_INNER_BOUND)
240239 if is_new:
241240 bound = recv_buf[0]
242241 logging.debug("!! new InnerBound to opt {}".format(bound))
243- self.BestInnerBound = self.InnerBoundUpdate(bound, idx)
242+ self.BestInnerBound = self.InnerBoundUpdate(bound, cls, idx)
244243 logging.debug("ph back from InnerBounds")
245244
246245 def receive_outerbounds(self):
@@ -249,37 +248,35 @@ def receive_outerbounds(self):
249248 (but should be harmless to call if there are none)
250249 """
251250 logging.debug("Hub is trying to receive from OuterBounds")
252- for idx in self.outerbound_spoke_indices:
253- key = self._make_key(Field.OBJECTIVE_OUTER_BOUND, idx)
254- recv_buf = self.receive_buffers[key]
251+ for idx, cls, recv_buf in self.receive_field_spcomms[Field.OBJECTIVE_OUTER_BOUND]:
255252 is_new = self.hub_from_spoke(recv_buf, idx, Field.OBJECTIVE_OUTER_BOUND)
256253 if is_new:
257254 bound = recv_buf[0]
258255 logging.debug("!! new OuterBound to opt {}".format(bound))
259- self.BestOuterBound = self.OuterBoundUpdate(bound, idx)
256+ self.BestOuterBound = self.OuterBoundUpdate(bound, cls, idx)
260257 logging.debug("ph back from OuterBounds")
261258
262- def OuterBoundUpdate(self, new_bound, idx=None, char='*'):
259+ def OuterBoundUpdate(self, new_bound, cls=None, idx=None, char='*'):
263260 current_bound = self.BestOuterBound
264261 if self._outer_bound_update(new_bound, current_bound):
265- if idx is None:
262+ if cls is None:
266263 self.latest_ob_char = char
267264 self.last_ob_idx = 0
268265 else:
269- self.latest_ob_char = self.outerbound_spoke_chars[idx]
266+ self.latest_ib_char = cls.converger_spoke_char
270267 self.last_ob_idx = idx
271268 return new_bound
272269 else:
273270 return current_bound
274271
275- def InnerBoundUpdate(self, new_bound, idx=None, char='*'):
272+ def InnerBoundUpdate(self, new_bound, cls=None, idx=None, char='*'):
276273 current_bound = self.BestInnerBound
277274 if self._inner_bound_update(new_bound, current_bound):
278- if idx is None:
275+ if cls is None:
279276 self.latest_ib_char = char
280277 self.last_ib_idx = 0
281278 else:
282- self.latest_ib_char = self.innerbound_spoke_chars[idx]
279+ self.latest_ib_char = cls.converger_spoke_char
283280 self.last_ib_idx = idx
284281 return new_bound
285282 else:
@@ -297,28 +294,6 @@ def initialize_bound_values(self):
297294 self._inner_bound_update = lambda new, old : (new > old)
298295 self._outer_bound_update = lambda new, old : (new < old)
299296
300- def initialize_outer_bound_buffers(self):
301- """ Initialize outer bound receive buffers
302- """
303- self.outerbound_receive_buffers = dict()
304- for idx in self.outerbound_spoke_indices:
305- self.outerbound_receive_buffers[idx] = self.register_recv_field(
306- Field.OBJECTIVE_OUTER_BOUND, idx, 1,
307- )
308- ## End for
309- return
310-
311- def initialize_inner_bound_buffers(self):
312- """ Initialize inner bound receive buffers
313- """
314- self.innerbound_receive_buffers = dict()
315- for idx in self.innerbound_spoke_indices:
316- self.innerbound_receive_buffers[idx] = self.register_recv_field(
317- Field.OBJECTIVE_INNER_BOUND, idx, 1
318- )
319- ## End for
320- return
321-
322297 def _populate_boundsout_cache(self, buf):
323298 """ Populate a given buffer with the current bounds
324299 """
@@ -327,62 +302,26 @@ def _populate_boundsout_cache(self, buf):
327302
328303 def send_boundsout(self):
329304 """ Send bounds to the appropriate spokes
330- This is called only for spokes which are bounds only.
331- w and nonant spokes are passed bounds through the w and nonant buffers
332305 """
333306 my_bounds = self.send_buffers[Field.BEST_OBJECTIVE_BOUNDS]
334307 self._populate_boundsout_cache(my_bounds.array())
335308 logging.debug("hub is sending bounds={}".format(my_bounds))
336309 self.hub_to_spoke(my_bounds, Field.BEST_OBJECTIVE_BOUNDS)
337310 return
338311
339- def initialize_spoke_indices (self):
312+ def register_receive_fields (self):
340313 """ Figure out what types of spokes we have,
341314 and sort them into the appropriate classes.
342315
343316 Note:
344317 Some spokes may be multiple types (e.g. outerbound and nonant),
345318 though not all combinations are supported.
346319 """
347- self.outerbound_spoke_indices = set()
348- self.innerbound_spoke_indices = set()
349- self.nonant_spoke_indices = set()
350- self.w_spoke_indices = set()
351-
352- self.outerbound_spoke_chars = dict()
353- self.innerbound_spoke_chars = dict()
354-
355- for (i, spoke) in enumerate(self.communicators):
356- if i == self.strata_rank:
357- continue
358- spoke_class = spoke["spcomm_class"]
359- if hasattr(spoke_class, "converger_spoke_types"):
360- for cst in spoke_class.converger_spoke_types:
361- if cst == ConvergerSpokeType.OUTER_BOUND:
362- self.outerbound_spoke_indices.add(i)
363- self.outerbound_spoke_chars[i] = spoke_class.converger_spoke_char
364- elif cst == ConvergerSpokeType.INNER_BOUND:
365- self.innerbound_spoke_indices.add(i)
366- self.innerbound_spoke_chars[i] = spoke_class.converger_spoke_char
367- elif cst == ConvergerSpokeType.W_GETTER:
368- self.w_spoke_indices.add(i)
369- elif cst == ConvergerSpokeType.NONANT_GETTER:
370- self.nonant_spoke_indices.add(i)
371- else:
372- raise RuntimeError(f"Unrecognized converger_spoke_type {cst}")
373-
374- else: ##this isn't necessarily wrong, i.e., cut generators
375- logger.debug(f"Spoke class {spoke_class} not recognized by hub")
376-
377- # all _BoundSpoke spokes get hub bounds so we determine which spokes
378- # are "bounds only"
379- self.bounds_only_indices = \
380- (self.outerbound_spoke_indices | self.innerbound_spoke_indices) - \
381- (self.w_spoke_indices | self.nonant_spoke_indices)
320+ super().register_receive_fields()
382321
383322 # Not all opt classes may have extensions
384323 if getattr(self.opt, "extensions", None) is not None:
385- self.opt.extobject.initialize_spoke_indices ()
324+ self.opt.extobject.register_receive_fields ()
386325
387326 return
388327
@@ -511,31 +450,14 @@ def setup_hub(self):
511450 "Cannot call setup_hub before memory windows are constructed"
512451 )
513452
514- self.initialize_spoke_indices()
515- self.initialize_bound_values()
516-
517- self.initialize_outer_bound_buffers()
518- self.initialize_inner_bound_buffers()
519-
520- ## Do some checking for things we currently don't support
521- if len(self.outerbound_spoke_indices & self.innerbound_spoke_indices) > 0:
522- raise RuntimeError(
523- "A Spoke providing both inner and outer "
524- "bounds is currently unsupported"
525- )
526- if len(self.w_spoke_indices & self.nonant_spoke_indices) > 0:
527- raise RuntimeError(
528- "A Spoke needing both Ws and nonants is currently unsupported"
529- )
530-
531453 ## Generate some warnings if nothing is giving bounds
532- if not self.outerbound_spoke_indices :
454+ if not self.receive_field_spcomms[Field.OBJECTIVE_OUTER_BOUND] :
533455 logger.warn(
534456 "No OuterBound Spokes defined, this converger "
535457 "will not cause the hub to terminate"
536458 )
537459
538- if not self.innerbound_spoke_indices :
460+ if not self.receive_field_spcomms[Field.OBJECTIVE_INNER_BOUND] :
539461 logger.warn(
540462 "No InnerBound Spokes defined, this converger "
541463 "will not cause the hub to terminate"
@@ -578,7 +500,7 @@ def is_converged(self):
578500 if self.opt.best_bound_obj_val is not None:
579501 self.BestOuterBound = self.OuterBoundUpdate(self.opt.best_bound_obj_val)
580502
581- if not self.innerbound_spoke_indices :
503+ if not self.receive_field_spcomms[Field.OBJECTIVE_INNER_BOUND] :
582504 if self.opt._PHIter == 1:
583505 logger.warning(
584506 "PHHub cannot compute convergence without "
@@ -591,7 +513,7 @@ def is_converged(self):
591513
592514 return False
593515
594- if not self.outerbound_spoke_indices :
516+ if not self.receive_field_spcomms[Field.OBJECTIVE_OUTER_BOUND] :
595517 if self.opt._PHIter == 1 and not self._hub_algo_best_bound_provider:
596518 global_toc(
597519 "Without outer bound spokes, no progress "
@@ -660,24 +582,8 @@ def setup_hub(self):
660582 "Cannot call setup_hub before memory windows are constructed"
661583 )
662584
663- self.initialize_spoke_indices()
664- self.initialize_bound_values()
665-
666- self.initialize_outer_bound_buffers()
667- self.initialize_inner_bound_buffers()
668-
669- ## Do some checking for things we currently
670- ## do not support
671- if self.w_spoke_indices:
672- raise RuntimeError("LShaped hub does not compute dual weights (Ws)")
673- if len(self.outerbound_spoke_indices & self.innerbound_spoke_indices) > 0:
674- raise RuntimeError(
675- "A Spoke providing both inner and outer "
676- "bounds is currently unsupported"
677- )
678-
679585 ## Generate some warnings if nothing is giving bounds
680- if not self.innerbound_spoke_indices :
586+ if not self.receive_field_spcomms[Field.OBJECTIVE_INNER_BOUND] :
681587 logger.warn(
682588 "No InnerBound Spokes defined, this converger "
683589 "will not cause the hub to terminate"
0 commit comments