Skip to content

Commit f8d0250

Browse files
authored
Add dpctl.get_curent_device_type() (#88)
1 parent 7f3b868 commit f8d0250

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

dpctl/sycl_core.pyx

+19
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,16 @@ cdef class SyclDevice:
111111
'''
112112
return self._device_name.decode()
113113

114+
def get_device_type (self):
115+
''' Returns the type of the device as a `device_type` enum
116+
'''
117+
if DPPLDevice_IsGPU(self._device_ref):
118+
return device_type.gpu
119+
elif DPPLDevice_IsCPU(self._device_ref):
120+
return device_type.cpu
121+
else:
122+
raise ValueError("Unknown device type.")
123+
114124
def get_vendor_name (self):
115125
''' Returns the device vendor name as a string
116126
'''
@@ -515,6 +525,14 @@ cdef class _SyclQueueManager:
515525
else:
516526
return False
517527

528+
def get_current_device_type (self):
529+
''' Returns current device type as `device_type` enum
530+
'''
531+
if self.is_in_device_context():
532+
return self.get_current_queue().get_sycl_device().get_device_type()
533+
else:
534+
return None
535+
518536

519537
# This private instance of the _SyclQueueManager should not be directly
520538
# accessed outside the module.
@@ -523,6 +541,7 @@ _qmgr = _SyclQueueManager()
523541
# Global bound functions
524542
dump = _qmgr.dump
525543
get_current_queue = _qmgr.get_current_queue
544+
get_current_device_type = _qmgr.get_current_device_type
526545
get_num_platforms = _qmgr.get_num_platforms
527546
get_num_activated_queues = _qmgr.get_num_activated_queues
528547
has_cpu_queues = _qmgr.has_cpu_queues

dpctl/tests/test_sycl_queue_manager.py

+33
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@
2525
import dpctl
2626
import unittest
2727

28+
2829
class TestGetNumPlatforms (unittest.TestCase):
2930
@unittest.skipIf(not dpctl.has_sycl_platforms(),
3031
"No SYCL platforms available")
3132
def test_dpctl_get_num_platforms (self):
3233
if(dpctl.has_sycl_platforms):
3334
self.assertGreaterEqual(dpctl.get_num_platforms(), 1)
3435

36+
3537
@unittest.skipIf(not dpctl.has_sycl_platforms(), "No SYCL platforms available")
3638
class TestDumpMethods (unittest.TestCase):
3739
def test_dpctl_dump (self):
@@ -47,6 +49,7 @@ def test_dpctl_dump_device_info (self):
4749
except Exception:
4850
self.fail("Encountered an exception inside dump_device_info().")
4951

52+
5053
@unittest.skipIf(not dpctl.has_sycl_platforms(), "No SYCL platforms available")
5154
class TestIsInDeviceContext (unittest.TestCase):
5255

@@ -65,6 +68,35 @@ def test_is_in_device_context_inside_nested_device_ctxt (self):
6568
self.assertTrue(dpctl.is_in_device_context())
6669
self.assertFalse(dpctl.is_in_device_context())
6770

71+
72+
@unittest.skipIf(not dpctl.has_sycl_platforms(), "No SYCL platforms available")
73+
class TestIsInDeviceContext (unittest.TestCase):
74+
75+
def test_get_current_device_type_outside_device_ctxt (self):
76+
self.assertEqual(dpctl.get_current_device_type(), None)
77+
78+
def test_get_current_device_type_inside_device_ctxt (self):
79+
self.assertEqual(dpctl.get_current_device_type(), None)
80+
81+
with dpctl.device_context(dpctl.device_type.gpu):
82+
self.assertEqual(dpctl.get_current_device_type(), dpctl.device_type.gpu)
83+
84+
self.assertEqual(dpctl.get_current_device_type(), None)
85+
86+
@unittest.skipIf(not dpctl.has_cpu_queues(), "No CPU platforms available")
87+
def test_get_current_device_type_inside_nested_device_ctxt (self):
88+
self.assertEqual(dpctl.get_current_device_type(), None)
89+
90+
with dpctl.device_context(dpctl.device_type.cpu):
91+
self.assertEqual(dpctl.get_current_device_type(), dpctl.device_type.cpu)
92+
93+
with dpctl.device_context(dpctl.device_type.gpu):
94+
self.assertEqual(dpctl.get_current_device_type(), dpctl.device_type.gpu)
95+
self.assertEqual(dpctl.get_current_device_type(), dpctl.device_type.cpu)
96+
97+
self.assertEqual(dpctl.get_current_device_type(), None)
98+
99+
68100
@unittest.skipIf(not dpctl.has_sycl_platforms(), "No SYCL platforms available")
69101
class TestGetCurrentQueueInMultipleThreads (unittest.TestCase):
70102

@@ -96,5 +128,6 @@ def SessionThread (self):
96128
Session1.start()
97129
Session2.start()
98130

131+
99132
if __name__ == '__main__':
100133
unittest.main()

0 commit comments

Comments
 (0)