66from dask .distributed import Client , LocalCluster
77from numba import cuda
88
9+ import naive
10+
911try :
1012 from numba .errors import NumbaPerformanceWarning
1113except ModuleNotFoundError :
@@ -352,18 +354,65 @@ def test_snippets():
352354
353355@pytest .mark .parametrize ("T, m" , test_data )
354356def test_stimp (T , m ):
357+ if T .ndim > 1 :
358+ T = T .copy ()
359+ T = T [0 ]
360+ n = 3
361+ seed = np .random .randint (100000 )
362+
363+ np .random .seed (seed )
355364 ref = stumpy .aamp_stimp (T , m )
356- comp = stumpy .stimp (T , m , normalize = False )
357- npt .assert_almost_equal (ref .PAN_ , comp .PAN_ )
365+ for i in range (n ):
366+ ref .update ()
367+
368+ np .random .seed (seed )
369+ cmp = stumpy .stimp (T , m , normalize = False )
370+ for i in range (n ):
371+ cmp .update ()
372+
373+ # Compare raw pan
374+ ref_PAN = ref ._PAN
375+ cmp_PAN = cmp ._PAN
376+
377+ naive .replace_inf (ref_PAN )
378+ naive .replace_inf (cmp_PAN )
379+
380+ npt .assert_almost_equal (ref_PAN , cmp_PAN )
381+
382+ # Compare transformed pan
383+ npt .assert_almost_equal (ref .PAN_ , cmp .PAN_ )
358384
359385
360386@pytest .mark .filterwarnings ("ignore:\\ s+Port 8787 is already in use:UserWarning" )
361387@pytest .mark .parametrize ("T, m" , test_data )
362388def test_stimped (T , m , dask_cluster ):
389+ if T .ndim > 1 :
390+ T = T .copy ()
391+ T = T [0 ]
392+ n = 3
393+ seed = np .random .randint (100000 )
363394 with Client (dask_cluster ) as dask_client :
395+ np .random .seed (seed )
364396 ref = stumpy .aamp_stimped (dask_client , T , m )
365- comp = stumpy .stimped (dask_client , T , m , normalize = False )
366- npt .assert_almost_equal (ref .PAN_ , comp .PAN_ )
397+ for i in range (n ):
398+ ref .update ()
399+
400+ np .random .seed (seed )
401+ cmp = stumpy .stimped (dask_client , T , m , normalize = False )
402+ for i in range (n ):
403+ cmp .update ()
404+
405+ # Compare raw pan
406+ ref_PAN = ref ._PAN
407+ cmp_PAN = cmp ._PAN
408+
409+ naive .replace_inf (ref_PAN )
410+ naive .replace_inf (cmp_PAN )
411+
412+ npt .assert_almost_equal (ref_PAN , cmp_PAN )
413+
414+ # Compare transformed pan
415+ npt .assert_almost_equal (ref .PAN_ , cmp .PAN_ )
367416
368417
369418@pytest .mark .filterwarnings ("ignore" , category = NumbaPerformanceWarning )
@@ -372,6 +421,30 @@ def test_gpu_stimp(T, m):
372421 if not cuda .is_available (): # pragma: no cover
373422 pytest .skip ("Skipping Tests No GPUs Available" )
374423
424+ if T .ndim > 1 :
425+ T = T .copy ()
426+ T = T [0 ]
427+ n = 3
428+ seed = np .random .randint (100000 )
429+
430+ np .random .seed (seed )
375431 ref = stumpy .gpu_aamp_stimp (T , m )
376- comp = stumpy .gpu_stimp (T , m , normalize = False )
377- npt .assert_almost_equal (ref .PAN_ , comp .PAN_ )
432+ for i in range (n ):
433+ ref .update ()
434+
435+ np .random .seed (seed )
436+ cmp = stumpy .gpu_stimp (T , m , normalize = False )
437+ for i in range (n ):
438+ cmp .update ()
439+
440+ # Compare raw pan
441+ ref_PAN = ref ._PAN
442+ cmp_PAN = cmp ._PAN
443+
444+ naive .replace_inf (ref_PAN )
445+ naive .replace_inf (cmp_PAN )
446+
447+ npt .assert_almost_equal (ref_PAN , cmp_PAN )
448+
449+ # Compare transformed pan
450+ npt .assert_almost_equal (ref .PAN_ , cmp .PAN_ )
0 commit comments