22 datetime ,
33 timedelta ,
44)
5- from typing import Any
65
76import numpy as np
87import pytest
@@ -1949,54 +1948,10 @@ def test_rolling_timedelta_window_non_nanoseconds(unit, tz):
19491948 tm .assert_frame_equal (ref_df , df )
19501949
19511950
1952- class StandardWindowIndexer (BaseIndexer ):
1953- def __init__ (self , n , win_len ):
1954- self .n = n
1955- self .win_len = win_len
1956- super ().__init__ ()
1957-
1958- def get_window_bounds (
1959- self , num_values = None , min_periods = None , center = None , closed = None , step = None
1960- ):
1961- if num_values is None :
1962- num_values = self .n
1963- end = np .arange (num_values , dtype = "int64" ) + 1
1964- start = np .clip (end - self .win_len , 0 , num_values )
1965- return start , end
1966-
1967-
1968- class CustomLengthWindowIndexer (BaseIndexer ):
1969- def __init__ (self , rnd : np .random .Generator , n , win_len ):
1970- self .window = rnd .integers (win_len , size = n )
1971- super ().__init__ ()
1972-
1973- def get_window_bounds (
1974- self , num_values = None , min_periods = None , center = None , closed = None , step = None
1975- ):
1976- if num_values is None :
1977- num_values = len (self .window )
1978- end = np .arange (num_values , dtype = "int64" ) + 1
1979- start = np .clip (end - self .window , 0 , num_values )
1980- return start , end
1981-
1982-
1983- class ArbitraryWindowIndexer (BaseIndexer ):
1984- def __init__ (self , rnd : np .random .Generator , n , win_len ):
1985- start = rnd .integers (n , size = n )
1986- win_len = rnd .integers (win_len , size = n )
1987- end = np .where (start - win_len >= 0 , start - win_len , start + win_len )
1988-
1989- (start , end ) = (
1990- np .where (end >= start , start , end ),
1991- np .where (end >= start , end , start ),
1992- )
1993-
1994- # It is extremely unlikely that a random array would come sorted,
1995- # so we proceed with sort without checking if it is sorted.
1996- prm = sorted (range (len (start )), key = lambda i : (end [i ], start [i ]))
1997-
1998- self ._start = np .array (start )[prm ]
1999- self ._end = np .array (end )[prm ]
1951+ class PrescribedWindowIndexer (BaseIndexer ):
1952+ def __init__ (self , start , end ):
1953+ self ._start = start
1954+ self ._end = end
20001955 super ().__init__ ()
20011956
20021957 def get_window_bounds (
@@ -2010,109 +1965,46 @@ def get_window_bounds(
20101965
20111966
20121967class TestMinMax :
2013- # Pytest cache will not be a good choice here, because it appears
2014- # pytest persists data on disk, and we are not really interested
2015- # in flooding your hard drive with random numbers.
2016- # Thus we just cache control data in memory to avoid repetititve calculations.
2017- class Cache :
2018- def __init__ (self ) -> None :
2019- self .ctrl : dict [Any , Any ] = {}
2020-
2021- @pytest .fixture (scope = "class" )
2022- def cache (self ) -> Cache :
2023- return self .Cache ()
2024-
2025- @pytest .mark .parametrize ("is_max" , [True , False ])
2026- # @pytest.mark.parametrize("engine", ["python", "cython", "numba"])
2027- @pytest .mark .parametrize ("engine" , ["cython" ])
2028- @pytest .mark .parametrize (
2029- "seed, n, win_len, min_obs, frac_nan, indexer_t" ,
2030- [
2031- (42 , 1000 , 80 , 15 , 0.3 , CustomLengthWindowIndexer ),
2032- (52 , 1000 , 80 , 15 , 0.3 , ArbitraryWindowIndexer ),
2033- (1984 , 1000 , 40 , 25 , 0.3 , None ),
2034- ],
2035- )
2036- def test_minmax (
2037- self , is_max , engine , seed , n , win_len , min_obs , frac_nan , indexer_t , cache
2038- ):
2039- if seed is not None and isinstance (seed , np .random ._generator .Generator ):
2040- rng = np .random .default_rng (seed )
2041- rng .bit_generator .state = seed .bit_generator .state
2042- else :
2043- rng = np .random .default_rng (seed )
2044-
2045- if seed is None or isinstance (seed , np .random ._generator .Generator ):
2046- rng_state_for_key = (
2047- rng .bit_generator .state ["bit_generator" ],
2048- rng .bit_generator .state ["state" ]["state" ],
2049- rng .bit_generator .state ["state" ]["inc" ],
2050- rng .bit_generator .state ["has_uint32" ],
2051- rng .bit_generator .state ["uinteger" ],
2052- )
2053- else :
2054- rng_state_for_key = seed
2055- self .last_rng_state = rng .bit_generator .state
2056- vals = DataFrame ({"Data" : rng .random (n )})
2057- if frac_nan > 0 :
2058- is_nan = rng .random (len (vals )) < frac_nan
2059- vals .Data = np .where (is_nan , np .nan , vals .Data )
2060-
2061- ind_obj = (
2062- indexer_t (rng , len (vals ), win_len )
2063- if indexer_t
2064- else StandardWindowIndexer (len (vals ), win_len )
2065- )
2066- ind_param = ind_obj if indexer_t else win_len
1968+ TestData = [
1969+ (True , False , [3.0 , 5.0 , 2.0 , 5.0 , 1.0 , 5.0 , 6.0 , 7.0 , 8.0 , 9.0 ]),
1970+ (True , True , [3.0 , 4.0 , 2.0 , 4.0 , 1.0 , 4.0 , 6.0 , 7.0 , 7.0 , 9.0 ]),
1971+ (False , False , [3.0 , 2.0 , 2.0 , 1.0 , 1.0 , 0.0 , 0.0 , 0.0 , 7.0 , 0.0 ]),
1972+ (False , True , [3.0 , 2.0 , 2.0 , 1.0 , 1.0 , 1.0 , 6.0 , 6.0 , 7.0 , 1.0 ]),
1973+ ]
20671974
2068- (start , end ) = ind_obj .get_window_bounds ()
2069- ctrl_key = (is_max , rng_state_for_key , n , win_len , min_obs , frac_nan , indexer_t )
2070- if ctrl_key in cache .ctrl :
2071- ctrl = cache .ctrl [ctrl_key ]
1975+ @pytest .mark .parametrize ("is_max, has_nan, exp_list" , TestData )
1976+ def test_minmax (self , is_max , has_nan , exp_list , engine = None ):
1977+ nan_idx = [0 , 5 , 8 ]
1978+ df = DataFrame (
1979+ {
1980+ "data" : [5.0 , 4.0 , 3.0 , 2.0 , 1.0 , 0.0 , 6.0 , 7.0 , 8.0 , 9.0 ],
1981+ "start" : [2 , 0 , 3 , 0 , 4 , 0 , 5 , 5 , 7 , 3 ],
1982+ "end" : [3 , 4 , 4 , 5 , 5 , 6 , 7 , 8 , 9 , 10 ],
1983+ }
1984+ )
1985+ if has_nan :
1986+ df .loc [nan_idx , "data" ] = np .nan
1987+ expected = Series (exp_list , name = "data" )
1988+ r = df .data .rolling (
1989+ PrescribedWindowIndexer (df .start .to_numpy (), df .end .to_numpy ())
1990+ )
1991+ if is_max :
1992+ result = r .max (engine = engine )
20721993 else :
2073- # This is brute force calculation, and may get expensive when n is
2074- # large, so we cache it.
2075- ctrl = calc_minmax_control (vals .Data , start , end , min_obs , is_max )
2076- cache .ctrl [ctrl_key ] = ctrl
2077-
2078- r = vals .rolling (ind_param , min_periods = min_obs )
2079- f = r .max if is_max else r .min
2080- test = f (engine = engine )
2081- tm .assert_series_equal (test .Data , ctrl .Data )
2082-
2083- # @pytest.mark.parametrize("engine", ["python", "cython", "numba"])
2084- @pytest .mark .parametrize ("engine" , ["cython" ])
2085- @pytest .mark .parametrize (
2086- "seed, n, win_len, indexer_t" ,
2087- [
2088- (42 , 15 , 7 , ArbitraryWindowIndexer ),
2089- ],
2090- )
2091- def test_wrong_order (self , engine , seed , n , win_len , indexer_t ):
2092- rng = np .random .default_rng (seed )
2093- vals = DataFrame ({"Data" : rng .random (n )})
1994+ result = r .min (engine = engine )
1995+
1996+ tm .assert_series_equal (result , expected )
1997+
1998+ def test_wrong_order (self , engine = None ):
1999+ start = np .array (range (5 ), dtype = np .int64 )
2000+ end = start + 1
2001+ end [3 ] = end [2 ]
2002+ start [3 ] = start [2 ] - 1
20942003
2095- ind_obj = indexer_t (rng , len (vals ), win_len )
2096- ind_obj ._end [[14 , 7 ]] = ind_obj ._end [[7 , 14 ]]
2004+ df = DataFrame ({"data" : start * 1.0 , "start" : start , "end" : end })
20972005
2098- f = vals . rolling (ind_obj ). max
2006+ r = df . data . rolling (PrescribedWindowIndexer ( start , end ))
20992007 with pytest .raises (
2100- ValueError , match = "Start/End ordering requirement is violated at index 8 "
2008+ ValueError , match = "Start/End ordering requirement is violated at index 3 "
21012009 ):
2102- f (engine = engine )
2103-
2104-
2105- def calc_minmax_control (vals , start , end , min_periods , ismax ):
2106- func = np .nanmax if ismax else np .nanmin
2107- outp = np .full (vals .shape , np .nan )
2108- for i in range (len (start )):
2109- if start [i ] >= end [i ]:
2110- outp [i ] = np .nan
2111- else :
2112- rng = vals [start [i ] : end [i ]]
2113- non_nan_cnt = np .count_nonzero (~ np .isnan (rng ))
2114- if non_nan_cnt >= min_periods :
2115- outp [i ] = func (rng )
2116- else :
2117- outp [i ] = np .nan
2118- return DataFrame ({"Data" : outp })
2010+ r .max (engine = engine )
0 commit comments