 import boost_histogram.storage as storage
 import dask
 import dask.array as da
-from dask.bag.core import empty_safe_aggregate, partition_all
 from dask.base import DaskMethodsMixin, dont_optimize, is_dask_collection, tokenize
 from dask.context import globalmethod
 from dask.delayed import Delayed, delayed
 from tlz import first
 
 from dask_histogram.bins import normalize_bins_range
-from dask_histogram.core import AggHistogram, _get_optimization_function, factory
+from dask_histogram.core import (
+    AggHistogram,
+    _get_optimization_function,
+    _partitioned_histogram_multifill,
+    _reduction,
+)
 
 if TYPE_CHECKING:
     from dask_histogram.typing import (
 
 __all__ = ("Histogram", "histogram", "histogram2d", "histogramdd")
 
 
-def _build_staged_tree_reduce(
-    stages: list[AggHistogram], split_every: int | bool
-) -> HighLevelGraph:
-    if not split_every:
-        split_every = len(stages)
-
-    reducer = sum
-
-    token = tokenize(stages, reducer, split_every)
-
-    k = len(stages)
-    b = ""
-    fmt = f"staged-fill-aggregate-{token}"
-    depth = 0
-
-    dsk = {}
-
-    if k > 1:
-        while k > split_every:
-            c = fmt + str(depth)
-            for i, inds in enumerate(partition_all(split_every, range(k))):
-                dsk[(c, i)] = (
-                    empty_safe_aggregate,
-                    reducer,
-                    [
-                        (stages[j].name if depth == 0 else b, 0 if depth == 0 else j)
-                        for j in inds
-                    ],
-                    False,
-                )
-
-            k = i + 1
-            b = c
-            depth += 1
-
-        dsk[(fmt, 0)] = (
-            empty_safe_aggregate,
-            reducer,
-            [
-                (stages[j].name if depth == 0 else b, 0 if depth == 0 else j)
-                for j in range(k)
-            ],
-            True,
-        )
-        return fmt, HighLevelGraph.from_collections(fmt, dsk, dependencies=stages)
-
-    return stages[0].name, stages[0].dask
-
-
 class Histogram(bh.Histogram, DaskMethodsMixin, family=dask_histogram):
     """Histogram object capable of lazy computation.
 
@@ -97,9 +52,6 @@ class Histogram(bh.Histogram, DaskMethodsMixin, family=dask_histogram):
         type is :py:class:`boost_histogram.storage.Double`.
     metadata : Any
         Data that is passed along if a new histogram is created.
-    split_every : int | bool | None, default None
-        Width of aggregation layers for staged fills.
-        If False, all staged fills are added in one layer (memory intensive!).
 
     See Also
     --------
@@ -139,7 +91,7 @@ def __init__(
     ) -> None:
         """Construct a Histogram object."""
         super().__init__(*axes, storage=storage, metadata=metadata)
-        self._staged: list[AggHistogram] | None = None
+        self._staged: AggHistogram | None = None
         self._dask_name: str | None = (
             f"empty-histogram-{tokenize(*axes, storage, metadata)}"
         )
@@ -148,29 +100,19 @@ def __init__(
             {},
         )
         self._split_every = split_every
-        if self._split_every is None:
-            self._split_every = dask.config.get("histogram.aggregation.split_every", 8)
 
     @property
     def _histref(self):
         return (
             tuple(self.axes),
-            self.storage_type,
+            self.storage_type(),
             self.metadata,
         )
 
     def __iadd__(self, other):
-        if self.staged_fills() and other.staged_fills():
-            self._staged += other._staged
-        elif not self.staged_fills() and other.staged_fills():
-            self._staged = other._staged
-        if self.staged_fills():
-            new_name, new_graph = _build_staged_tree_reduce(
-                self._staged, self._split_every
-            )
-            self._dask = new_graph
-            self._dask_name = new_name
-        return self
+        raise NotImplementedError(
+            "dask-boost-histograms are not addable, please sum them after computation!"
+        )
 
     def __add__(self, other):
         return self.__iadd__(other)
@@ -234,6 +176,8 @@ def _in_memory_type(self) -> type[bh.Histogram]:
 
     @property
     def dask_name(self) -> str:
+        if self._dask_name == "__not_yet_calculated__" and self._dask is None:
+            self._build_taskgraph()
         if self._dask_name is None:
             raise RuntimeError(
                 "The dask name should never be None when it's requested."
@@ -242,12 +186,45 @@ def dask_name(self) -> str:
 
     @property
     def dask(self) -> HighLevelGraph:
+        if self._dask_name == "__not_yet_calculated__" and self._dask is None:
+            self._build_taskgraph()
         if self._dask is None:
             raise RuntimeError(
                 "The dask graph should never be None when it's requested."
             )
         return self._dask
 
+    def _build_taskgraph(self):
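+        # Gather the positional data and the per-fill weight/sample keyword
+        # arguments recorded by fill() while the graph build was deferred.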
+        data_list = []
+        weights = []
+        samples = []
+
+        for afill in self._staged:
+            data_list.append(afill["args"])
+            weights.append(afill["kwargs"]["weight"])
+            samples.append(afill["kwargs"]["sample"])
+
+        if all(weight is None for weight in weights):
+            weights = None
+
+        if not all(sample is None for sample in samples):
+            samples = None
+
+        split_every = self._split_every or dask.config.get(
+            "histogram.aggregation.split-every", 8
+        )
+
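+        # Fill every input partition in a single multifill layer, then
+        # tree-reduce the per-partition histograms into one aggregated histogram.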
+        fills = _partitioned_histogram_multifill(
+            data_list, self._histref, weights, samples
+        )
+
+        output_hist = _reduction(fills, split_every)
+
+        self._staged = None
+        self._staged_result = output_hist
+        self._dask = output_hist.dask
+        self._dask_name = output_hist.name
+
     def fill(  # type: ignore
         self,
         *args: DaskCollection,
@@ -318,14 +295,13 @@ def fill(  # type: ignore
         else:
             raise ValueError(f"Cannot interpret input data: {args}")
 
-        new_fill = factory(*args, histref=self._histref, weights=weight, sample=sample)
+        new_fill = {"args": args, "kwargs": {"weight": weight, "sample": sample}}
         if self._staged is None:
             self._staged = [new_fill]
         else:
-            self._staged += [new_fill]
-        new_name, new_graph = _build_staged_tree_reduce(self._staged, self._split_every)
-        self._dask = new_graph
-        self._dask_name = new_name
+            self._staged.append(new_fill)
+        self._dask = None
+        self._dask_name = "__not_yet_calculated__"
 
         return self
 
@@ -383,7 +359,8 @@ def to_delayed(self) -> Delayed:
 
         """
         if self._staged is not None:
-            return sum(self._staged[1:], start=self._staged[0]).to_delayed()
+            self._build_taskgraph()
+            return self._staged_result.to_delayed()
         return delayed(bh.Histogram(self))
 
     def __repr__(self) -> str:
@@ -449,7 +426,8 @@ def to_dask_array(self, flow: bool = False, dd: bool = True) -> Any:
 
         """
         if self._staged is not None:
-            return sum(self._staged).to_dask_array(flow=flow, dd=dd)
+            self._build_taskgraph()
+            return self._staged_result.to_dask_array(flow=flow, dd=dd)
         else:
            counts, edges = self.to_numpy(flow=flow, dd=True, view=False)
            counts = da.from_array(counts)
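
A minimal usage sketch of the deferred graph building introduced in this diff. It assumes the public dask_histogram.boost.Histogram API shown above; array shapes, chunk sizes, and variable names are illustrative only.

import boost_histogram as bh
import dask.array as da
import dask_histogram.boost as dhb

# Chunked dask arrays to histogram; sizes/chunks are arbitrary examples.
x1 = da.random.standard_normal(size=(1_000_000,), chunks=250_000)
x2 = da.random.standard_normal(size=(1_000_000,), chunks=250_000)
w1 = da.random.uniform(size=(1_000_000,), chunks=250_000)
w2 = da.random.uniform(size=(1_000_000,), chunks=250_000)

h = dhb.Histogram(bh.axis.Regular(50, -3, 3))
h.fill(x1, weight=w1)  # fill is only staged; no task graph is built yet
h.fill(x2, weight=w2)  # further fills are appended to the staged list

# The graph is materialized lazily the first time .dask / .dask_name is
# requested, e.g. during compute(). Adding two lazy histograms now raises
# NotImplementedError, so sum the computed results instead.
result = h.compute()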