Skip to content

Commit 8fdb3d4

Browse files
authored
Fixed #576 incorrect stimp, stimped, gpu_stimp normalize rerouting (#577)
* Made excl_zone optional in _multi_distance_profile funcs * Fixed stimp, stimped, gpu_stimp normalized reroute
1 parent e58f533 commit 8fdb3d4

File tree

6 files changed

+107
-19
lines changed

6 files changed

+107
-19
lines changed

stumpy/aamp_stimp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def __init__(
182182
p : float, default 2.0
183183
The p-norm to apply for computing the Minkowski distance.
184184
"""
185-
self._T = T
185+
self._T = T.copy()
186186
self._T_min = np.min(self._T[np.isfinite(self._T)])
187187
self._T_max = np.max(self._T[np.isfinite(self._T)])
188188
self._p = p

stumpy/core.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,17 @@ def non_normalized(non_norm, exclude=None, replace=None):
8181
parameters when necessary.
8282
8383
```
84-
def non_norm_func(Q, T, A):
84+
def non_norm_func(Q, T, A_non_norm):
8585
...
8686
return
8787
8888
8989
@non_normalized(
9090
non_norm_func,
91-
exclude=["normalize", "A", "B"],
92-
replace={"A": None},
91+
exclude=["normalize", "p", "A", "B"],
92+
replace={"A_norm": "A_non_norm", "other_norm": None},
9393
)
94-
def norm_func(Q, T, B=None, normalize=True):
94+
def norm_func(Q, T, A_norm=None, other_norm=None, normalize=True, p=2.0):
9595
...
9696
return
9797
```
@@ -104,13 +104,16 @@ def norm_func(Q, T, B=None, normalize=True):
104104
105105
exclude : list, default None
106106
A list of function (or class) parameter names to exclude when comparing the
107-
function (or class) signatures
107+
function (or class) signatures. When `exlcude is None`, this parameter is
108+
automatically set to `exclude = ["normalize", "p"]` by default.
108109
109110
replace : dict, default None
110111
A dictionary of function (or class) parameter key-value pairs. Each key that
111112
is found as a parameter name in the `norm` function (or class) will be replaced
112113
by its corresponding or complementary parameter name in the `non_norm` function
113-
(or class).
114+
(or class) (e.g., {"norm_param": "non_norm_param"}). To remove any parameter in
115+
the `norm` function (or class) that does not exist in the `non_norm` function,
116+
simply set the value to `None` (i.e., {"norm_param": None}).
114117
115118
Returns
116119
-------

stumpy/gpu_stimp.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
from .stimp import _stimp
88

99

10-
@core.non_normalized(gpu_aamp_stimp)
10+
@core.non_normalized(
11+
gpu_aamp_stimp,
12+
exclude=["pre_scrump", "normalize", "p", "pre_scraamp"],
13+
replace={"pre_scrump": "pre_scraamp"},
14+
)
1115
class gpu_stimp(_stimp):
1216
"""
1317
Compute the Pan Matrix Profile with with one or more GPU devices

stumpy/stimp.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def __init__(
167167
mp_func : object, default stump
168168
The matrix profile function to use when `percentage = 1.0`
169169
"""
170-
self._T = T
170+
self._T = T.copy()
171171
if max_m is None:
172172
max_m = max(min_m + 1, core.get_max_window_size(self._T.shape[0]))
173173
M = np.arange(min_m, max_m + 1, step).astype(np.int64)
@@ -320,7 +320,11 @@ def M_(self):
320320
# return self._n_processed
321321

322322

323-
@core.non_normalized(aamp_stimp)
323+
@core.non_normalized(
324+
aamp_stimp,
325+
exclude=["pre_scrump", "normalize", "p", "pre_scraamp"],
326+
replace={"pre_scrump": "pre_scraamp"},
327+
)
324328
class stimp(_stimp):
325329
"""
326330
Compute the Pan Matrix Profile
@@ -464,7 +468,11 @@ def __init__(
464468
)
465469

466470

467-
@core.non_normalized(aamp_stimped)
471+
@core.non_normalized(
472+
aamp_stimped,
473+
exclude=["pre_scrump", "normalize", "p", "pre_scraamp"],
474+
replace={"pre_scrump": "pre_scraamp"},
475+
)
468476
class stimped(_stimp):
469477
"""
470478
Compute the Pan Matrix Profile with a distributed dask cluster

test.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ test_custom()
7979
# Test one or more user-defined functions repeatedly
8080
for VARIABLE in {1..10}
8181
do
82-
pytest -x -W ignore::RuntimeWarning -W ignore::DeprecationWarning tests/test_.py
82+
pytest -x -W ignore::DeprecationWarning tests/test_.py
8383
check_errs $?
8484
done
8585
clean_up
@@ -138,7 +138,7 @@ test_unit()
138138
pytest -rsx -W ignore::RuntimeWarning -W ignore::DeprecationWarning tests/test_gpu_aamp_stimp.py
139139
pytest -x -W ignore::RuntimeWarning -W ignore::DeprecationWarning tests/test_aamp_stimp.py
140140
check_errs $?
141-
pytest -x -W ignore::RuntimeWarning -W ignore::DeprecationWarning tests/test_non_normalized_decorator.py
141+
pytest -x -W ignore::DeprecationWarning tests/test_non_normalized_decorator.py
142142
check_errs $?
143143
}
144144

tests/test_non_normalized_decorator.py

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from dask.distributed import Client, LocalCluster
77
from numba import cuda
88

9+
import naive
10+
911
try:
1012
from numba.errors import NumbaPerformanceWarning
1113
except ModuleNotFoundError:
@@ -352,18 +354,65 @@ def test_snippets():
352354

353355
@pytest.mark.parametrize("T, m", test_data)
354356
def test_stimp(T, m):
357+
if T.ndim > 1:
358+
T = T.copy()
359+
T = T[0]
360+
n = 3
361+
seed = np.random.randint(100000)
362+
363+
np.random.seed(seed)
355364
ref = stumpy.aamp_stimp(T, m)
356-
comp = stumpy.stimp(T, m, normalize=False)
357-
npt.assert_almost_equal(ref.PAN_, comp.PAN_)
365+
for i in range(n):
366+
ref.update()
367+
368+
np.random.seed(seed)
369+
cmp = stumpy.stimp(T, m, normalize=False)
370+
for i in range(n):
371+
cmp.update()
372+
373+
# Compare raw pan
374+
ref_PAN = ref._PAN
375+
cmp_PAN = cmp._PAN
376+
377+
naive.replace_inf(ref_PAN)
378+
naive.replace_inf(cmp_PAN)
379+
380+
npt.assert_almost_equal(ref_PAN, cmp_PAN)
381+
382+
# Compare transformed pan
383+
npt.assert_almost_equal(ref.PAN_, cmp.PAN_)
358384

359385

360386
@pytest.mark.filterwarnings("ignore:\\s+Port 8787 is already in use:UserWarning")
361387
@pytest.mark.parametrize("T, m", test_data)
362388
def test_stimped(T, m, dask_cluster):
389+
if T.ndim > 1:
390+
T = T.copy()
391+
T = T[0]
392+
n = 3
393+
seed = np.random.randint(100000)
363394
with Client(dask_cluster) as dask_client:
395+
np.random.seed(seed)
364396
ref = stumpy.aamp_stimped(dask_client, T, m)
365-
comp = stumpy.stimped(dask_client, T, m, normalize=False)
366-
npt.assert_almost_equal(ref.PAN_, comp.PAN_)
397+
for i in range(n):
398+
ref.update()
399+
400+
np.random.seed(seed)
401+
cmp = stumpy.stimped(dask_client, T, m, normalize=False)
402+
for i in range(n):
403+
cmp.update()
404+
405+
# Compare raw pan
406+
ref_PAN = ref._PAN
407+
cmp_PAN = cmp._PAN
408+
409+
naive.replace_inf(ref_PAN)
410+
naive.replace_inf(cmp_PAN)
411+
412+
npt.assert_almost_equal(ref_PAN, cmp_PAN)
413+
414+
# Compare transformed pan
415+
npt.assert_almost_equal(ref.PAN_, cmp.PAN_)
367416

368417

369418
@pytest.mark.filterwarnings("ignore", category=NumbaPerformanceWarning)
@@ -372,6 +421,30 @@ def test_gpu_stimp(T, m):
372421
if not cuda.is_available(): # pragma: no cover
373422
pytest.skip("Skipping Tests No GPUs Available")
374423

424+
if T.ndim > 1:
425+
T = T.copy()
426+
T = T[0]
427+
n = 3
428+
seed = np.random.randint(100000)
429+
430+
np.random.seed(seed)
375431
ref = stumpy.gpu_aamp_stimp(T, m)
376-
comp = stumpy.gpu_stimp(T, m, normalize=False)
377-
npt.assert_almost_equal(ref.PAN_, comp.PAN_)
432+
for i in range(n):
433+
ref.update()
434+
435+
np.random.seed(seed)
436+
cmp = stumpy.gpu_stimp(T, m, normalize=False)
437+
for i in range(n):
438+
cmp.update()
439+
440+
# Compare raw pan
441+
ref_PAN = ref._PAN
442+
cmp_PAN = cmp._PAN
443+
444+
naive.replace_inf(ref_PAN)
445+
naive.replace_inf(cmp_PAN)
446+
447+
npt.assert_almost_equal(ref_PAN, cmp_PAN)
448+
449+
# Compare transformed pan
450+
npt.assert_almost_equal(ref.PAN_, cmp.PAN_)

0 commit comments

Comments
 (0)