25
25
import dpctl
26
26
import unittest
27
27
28
+
28
29
class TestGetNumPlatforms (unittest .TestCase ):
29
30
@unittest .skipIf (not dpctl .has_sycl_platforms (),
30
31
"No SYCL platforms available" )
31
32
def test_dpctl_get_num_platforms (self ):
32
33
if (dpctl .has_sycl_platforms ):
33
34
self .assertGreaterEqual (dpctl .get_num_platforms (), 1 )
34
35
36
+
35
37
@unittest .skipIf (not dpctl .has_sycl_platforms (), "No SYCL platforms available" )
36
38
class TestDumpMethods (unittest .TestCase ):
37
39
def test_dpctl_dump (self ):
@@ -47,6 +49,7 @@ def test_dpctl_dump_device_info (self):
47
49
except Exception :
48
50
self .fail ("Encountered an exception inside dump_device_info()." )
49
51
52
+
50
53
@unittest .skipIf (not dpctl .has_sycl_platforms (), "No SYCL platforms available" )
51
54
class TestIsInDeviceContext (unittest .TestCase ):
52
55
@@ -65,6 +68,35 @@ def test_is_in_device_context_inside_nested_device_ctxt (self):
65
68
self .assertTrue (dpctl .is_in_device_context ())
66
69
self .assertFalse (dpctl .is_in_device_context ())
67
70
71
+
72
+ @unittest .skipIf (not dpctl .has_sycl_platforms (), "No SYCL platforms available" )
73
+ class TestIsInDeviceContext (unittest .TestCase ):
74
+
75
+ def test_get_current_device_type_outside_device_ctxt (self ):
76
+ self .assertEqual (dpctl .get_current_device_type (), None )
77
+
78
+ def test_get_current_device_type_inside_device_ctxt (self ):
79
+ self .assertEqual (dpctl .get_current_device_type (), None )
80
+
81
+ with dpctl .device_context (dpctl .device_type .gpu ):
82
+ self .assertEqual (dpctl .get_current_device_type (), dpctl .device_type .gpu )
83
+
84
+ self .assertEqual (dpctl .get_current_device_type (), None )
85
+
86
+ @unittest .skipIf (not dpctl .has_cpu_queues (), "No CPU platforms available" )
87
+ def test_get_current_device_type_inside_nested_device_ctxt (self ):
88
+ self .assertEqual (dpctl .get_current_device_type (), None )
89
+
90
+ with dpctl .device_context (dpctl .device_type .cpu ):
91
+ self .assertEqual (dpctl .get_current_device_type (), dpctl .device_type .cpu )
92
+
93
+ with dpctl .device_context (dpctl .device_type .gpu ):
94
+ self .assertEqual (dpctl .get_current_device_type (), dpctl .device_type .gpu )
95
+ self .assertEqual (dpctl .get_current_device_type (), dpctl .device_type .cpu )
96
+
97
+ self .assertEqual (dpctl .get_current_device_type (), None )
98
+
99
+
68
100
@unittest .skipIf (not dpctl .has_sycl_platforms (), "No SYCL platforms available" )
69
101
class TestGetCurrentQueueInMultipleThreads (unittest .TestCase ):
70
102
@@ -96,5 +128,6 @@ def SessionThread (self):
96
128
Session1 .start ()
97
129
Session2 .start ()
98
130
131
+
99
132
if __name__ == '__main__' :
100
133
unittest .main ()
0 commit comments