@@ -214,34 +214,8 @@ def slices_from_chunks(chunks):
214214 return product (* slices )
215215
216216
217- @memoize
218- def find_group_cohorts (labels , chunks , merge : bool = True ) -> dict :
219- """
220- Finds groups labels that occur together aka "cohorts"
221-
222- If available, results are cached in a 1MB cache managed by `cachey`.
223- This allows us to be quick when repeatedly calling groupby_reduce
224- for arrays with the same chunking (e.g. an xarray Dataset).
225-
226- Parameters
227- ----------
228- labels : np.ndarray
229- mD Array of integer group codes, factorized so that -1
230- represents NaNs.
231- chunks : tuple
232- chunks of the array being reduced
233- merge : bool, optional
234- Attempt to merge cohorts when one cohort's chunks are a subset
235- of another cohort's chunks.
236-
237- Returns
238- -------
239- cohorts: dict_values
240- Iterable of cohorts
241- """
242- # To do this, we must have values in memory so casting to numpy should be safe
243- labels = np .asarray (labels )
244-
217+ def _compute_label_chunk_bitmask (labels , chunks ):
218+ assert isinstance (labels , np .ndarray )
245219 shape = tuple (sum (c ) for c in chunks )
246220 nchunks = math .prod (len (c ) for c in chunks )
247221
@@ -271,6 +245,47 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
271245 cols_array = np .concatenate (cols )
272246 data = np .broadcast_to (np .array (1 , dtype = np .uint8 ), rows_array .shape )
273247 bitmask = csc_array ((data , (rows_array , cols_array )), dtype = bool , shape = (nchunks , nlabels ))
248+
249+ return bitmask , nlabels , ilabels
250+
251+
252+ @memoize
253+ def find_group_cohorts (labels , chunks , merge : bool = True ) -> dict :
254+ """
255+ Finds groups labels that occur together aka "cohorts"
256+
257+ If available, results are cached in a 1MB cache managed by `cachey`.
258+ This allows us to be quick when repeatedly calling groupby_reduce
259+ for arrays with the same chunking (e.g. an xarray Dataset).
260+
261+ Parameters
262+ ----------
263+ labels : np.ndarray
264+ mD Array of integer group codes, factorized so that -1
265+ represents NaNs.
266+ chunks : tuple
267+ chunks of the array being reduced
268+ merge : bool, optional
269+ Attempt to merge cohorts when one cohort's chunks are a subset
270+ of another cohort's chunks.
271+
272+ Returns
273+ -------
274+ cohorts: dict_values
275+ Iterable of cohorts
276+ """
277+ if not is_duck_array (labels ):
278+ labels = np .asarray (labels )
279+
280+ if is_duck_dask_array (labels ):
281+ import dask
282+
283+ ((bitmask , nlabels , ilabels ),) = dask .compute (
284+ dask .delayed (_compute_label_chunk_bitmask )(labels , chunks )
285+ )
286+ else :
287+ bitmask , nlabels , ilabels = _compute_label_chunk_bitmask (labels , chunks )
288+
274289 label_chunks = {
275290 lab : bitmask .indices [slice (bitmask .indptr [lab ], bitmask .indptr [lab + 1 ])]
276291 for lab in range (nlabels )
@@ -2039,9 +2054,6 @@ def groupby_reduce(
20392054 "Try engine='numpy' or engine='numba' instead."
20402055 )
20412056
2042- if method == "cohorts" and any_by_dask :
2043- raise ValueError (f"method={ method !r} can only be used when grouping by numpy arrays." )
2044-
20452057 reindex = _validate_reindex (
20462058 reindex , func , method , expected_groups , any_by_dask , is_duck_dask_array (array )
20472059 )
@@ -2076,6 +2088,12 @@ def groupby_reduce(
20762088 # can't do it if we are grouping by dask array but don't have expected_groups
20772089 any (is_dask and ex_ is None for is_dask , ex_ in zip (by_is_dask , expected_groups ))
20782090 )
2091+
2092+ if method == "cohorts" and not factorize_early :
2093+ raise ValueError (
2094+ "method='cohorts' can only be used when grouping by dask arrays if `expected_groups` is provided."
2095+ )
2096+
20792097 if factorize_early :
20802098 bys , final_groups , grp_shape = _factorize_multiple (
20812099 bys ,
0 commit comments