@@ -354,7 +354,7 @@ def rechunk_for_cohorts(
 def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) -> DaskArray:
     """
     Rechunks array so that group boundaries line up with chunk boundaries, allowing
-    embarassingly parallel group reductions.
+    embarrassingly parallel group reductions.
 
     This only works when the groups are sequential
     (e.g. labels = ``[0,0,0,1,1,1,1,2,2]``).
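# A minimal usage sketch for the function above (not part of the diff), assuming
# flox and dask are installed and that `rechunk_for_blockwise` is importable from
# flox.core, where this diff places it.
import dask.array as da
import numpy as np
from flox.core import rechunk_for_blockwise

labels = np.array([0, 0, 0, 1, 1, 1, 1, 2, 2])  # sequential groups, as required
array = da.ones((9,), chunks=4)                 # the boundary at index 4 splits group 1

rechunked = rechunk_for_blockwise(array, axis=-1, labels=labels)
# After rechunking, chunk boundaries should coincide with group boundaries, so each
# group sits inside a single chunk and can be reduced blockwise.
print(rechunked.chunks)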
@@ -849,7 +849,7 @@ def _finalize_results(
     """
     squeezed = _squeeze_results(results, axis)
 
-    if agg.min_count is not None:
+    if agg.min_count > 0:
         counts = squeezed["intermediates"][-1]
         squeezed["intermediates"] = squeezed["intermediates"][:-1]
 
@@ -860,7 +860,7 @@ def _finalize_results(
     else:
         finalized[agg.name] = agg.finalize(*squeezed["intermediates"], **agg.finalize_kwargs)
 
-    if agg.min_count is not None:
+    if agg.min_count > 0:
         count_mask = counts < agg.min_count
         if count_mask.any():
             # For one count_mask.any() prevents promoting bool to dtype(fill_value) unless
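# Simplified sketch (illustrative names, not flox internals) of the masking that the
# `min_count > 0` branches above guard: groups whose count of valid values falls
# below min_count are overwritten with fill_value.
import numpy as np

sums = np.array([6.0, 5.0, 0.0])   # per-group intermediate sums
counts = np.array([3, 1, 0])       # per-group counts of non-NaN values
min_count, fill_value = 2, np.nan

count_mask = counts < min_count
result = np.where(count_mask, fill_value, sums)  # -> [6., nan, nan]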
@@ -1598,7 +1598,11 @@ def _lazy_factorize_wrapper(*by: T_By, **kwargs) -> np.ndarray:
 
 
 def _factorize_multiple(
-    by: T_Bys, expected_groups: T_ExpectIndexTuple, any_by_dask: bool, reindex: bool
+    by: T_Bys,
+    expected_groups: T_ExpectIndexTuple,
+    any_by_dask: bool,
+    reindex: bool,
+    sort: bool = True,
 ) -> tuple[tuple[np.ndarray], tuple[np.ndarray, ...], tuple[int, ...]]:
     if any_by_dask:
         import dask.array
@@ -1617,6 +1621,7 @@ def _factorize_multiple(
             expected_groups=expected_groups,
             fastpath=True,
             reindex=reindex,
+            sort=sort,
         )
 
         fg, gs = [], []
@@ -1643,6 +1648,7 @@ def _factorize_multiple(
             expected_groups=expected_groups,
             fastpath=True,
             reindex=reindex,
+            sort=sort,
         )
 
     return (group_idx,), found_groups, grp_shape
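# Why threading `sort` through to factorize_ matters: factorization can label groups
# either in order of first appearance or in sorted order, and the integer codes
# differ. A small pandas analogue (flox's factorize_ is its own implementation):
import pandas as pd

codes_appearance, uniques_appearance = pd.factorize(["b", "a", "b", "c"], sort=False)
codes_sorted, uniques_sorted = pd.factorize(["b", "a", "b", "c"], sort=True)
# sort=False -> codes [0, 1, 0, 2], uniques ['b', 'a', 'c']
# sort=True  -> codes [1, 0, 1, 2], uniques ['a', 'b', 'c']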
@@ -1653,10 +1659,16 @@ def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) ->
         return (None,) * nby
 
     if nby == 1 and not isinstance(expected_groups, tuple):
-        if isinstance(expected_groups, pd.Index):
+        if isinstance(expected_groups, (pd.Index, np.ndarray)):
             return (expected_groups,)
         else:
-            return (np.asarray(expected_groups),)
+            array = np.asarray(expected_groups)
+            if np.issubdtype(array.dtype, np.integer):
+                # preserve default dtypes
+                # on pandas 1.5/2, on windows
+                # when a list is passed
+                array = array.astype(np.int64)
+            return (array,)
 
     if nby > 1 and not isinstance(expected_groups, tuple):  # TODO: test for list
         raise ValueError(
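# Sketch of the dtype issue the int64 cast above works around: converting a Python
# list of ints with np.asarray yields int32 on Windows (NumPy's default C long there),
# while other platforms give int64. Casting keeps the expected-group codes consistent.
import numpy as np

array = np.asarray([1, 2, 3])       # int32 on Windows, int64 on Linux/macOS
if np.issubdtype(array.dtype, np.integer):
    array = array.astype(np.int64)  # now int64 everywhere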
@@ -1833,21 +1845,28 @@ def groupby_reduce(
     # (pd.IntervalIndex or not)
     expected_groups = _convert_expected_groups_to_index(expected_groups, isbins, sort)
 
-    is_binning = any([isinstance(e, pd.IntervalIndex) for e in expected_groups])
-
-    # TODO: could restrict this to dask-only
-    factorize_early = (nby > 1) or (
-        is_binning and method == "cohorts" and is_duck_dask_array(array)
+    # Don't factorize "early only when
+    # grouping by dask arrays, and not having expected_groups
+    factorize_early = not (
+        # can't do it if we are grouping by dask array but don't have expected_groups
+        any(is_dask and ex_ is None for is_dask, ex_ in zip(by_is_dask, expected_groups))
     )
     if factorize_early:
         bys, final_groups, grp_shape = _factorize_multiple(
-            bys, expected_groups, any_by_dask=any_by_dask, reindex=reindex
+            bys,
+            expected_groups,
+            any_by_dask=any_by_dask,
+            # This is the only way it makes sense I think.
+            # reindex controls what's actually allocated in chunk_reduce
+            # At this point, we care about an accurate conversion to codes.
+            reindex=True,
+            sort=sort,
         )
         expected_groups = (pd.RangeIndex(math.prod(grp_shape)),)
 
     assert len(bys) == 1
-    by_ = bys[0]
-    expected_groups = expected_groups[0]
+    (by_,) = bys
+    (expected_groups,) = expected_groups
 
     if axis is None:
         axis_ = tuple(array.ndim + np.arange(-by_.ndim, 0))
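# Illustration of the new factorize_early condition (variable names mirror the diff,
# the sample values are made up): early factorization is skipped only when some `by`
# is a dask array *and* has no expected_groups to map its labels onto.
by_is_dask = (True, False)
expected_groups = (None, ["a", "b"])  # dask grouper without expected groups

factorize_early = not (
    any(is_dask and ex_ is None for is_dask, ex_ in zip(by_is_dask, expected_groups))
)
# -> False here; with expected_groups = ([1, 2], ["a", "b"]) it would be True.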
@@ -1898,7 +1917,12 @@ def groupby_reduce(
         min_count = 1
 
     # TODO: set in xarray?
-    if min_count is not None and func in ["nansum", "nanprod"] and fill_value is None:
+    if (
+        min_count is not None
+        and min_count > 0
+        and func in ["nansum", "nanprod"]
+        and fill_value is None
+    ):
         # nansum, nanprod have fill_value=0, 1
         # overwrite than when min_count is set
         fill_value = np.nan
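# Sketch of why fill_value is overridden only for a *positive* min_count: with
# min_count unset (or 0), an all-NaN group legitimately sums to 0, but once
# min_count >= 1 it should come out as NaN instead (pandas-style skipna behaviour).
import numpy as np

all_nan = np.array([np.nan, np.nan])
print(np.nansum(all_nan))                             # 0.0, the min_count == 0 result

min_count = 1
count = int(np.sum(~np.isnan(all_nan)))               # 0 valid values in the group
result = np.nansum(all_nan) if count >= min_count else np.nan
print(result)                                         # nan, the min_count >= 1 result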