The upstream PyTorch implementation uses the index of each device, in the order the devices are enumerated from the SYCL API, as the torch device identity that refers to the underlying SYCL device.
It does no extra sorting of the enumeration results.
It does no extra tiling or sub-partitioning of the SYCL devices.
It does no extra filtering or reordering of iGPUs and dGPUs.
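As a rough illustration, the sketch below (assuming a PyTorch build with XPU support, where the `torch.xpu` module is available) prints the devices in the order PyTorch exposes them; given the behaviour above, index `i` is simply the i-th device returned by the raw SYCL enumeration.

```python
# Minimal sketch, assuming a PyTorch build with XPU (SYCL) support so that
# the torch.xpu module is available.
import torch

# Per the upstream behaviour described above, torch device index i is the
# i-th device returned by the SYCL runtime enumeration -- no sorting,
# sub-partitioning, or iGPU/dGPU filtering is applied on top.
for i in range(torch.xpu.device_count()):
    print(f"torch.device('xpu', {i}) -> {torch.xpu.get_device_name(i)}")
```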
To JIT-compile Triton kernels correctly within the PyTorch framework, Triton could enumerate the SYCL devices from the SYCL runtime following the same practice, so that a given torch device identity maps to the same underlying SYCL device.
In the long term, we want to decouple this logic between PyTorch and Triton.
We propose that PyTorch supply a method that returns the SYCL device without any assumption about how the SYCL devices are mapped.
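To make the proposed contract concrete, here is a purely hypothetical sketch of the kind of accessor we have in mind; neither the name nor the signature exists in PyTorch today.

```python
# Purely hypothetical sketch of the proposed accessor; PyTorch does not
# provide such a method today, and the name/signature are illustrative only.
def get_sycl_device_handle(index: int) -> int:
    """Return an opaque handle to the sycl::device backing
    torch.device('xpu', index), so that Triton never needs to re-enumerate
    the SYCL runtime and assume the same device ordering as PyTorch."""
    raise NotImplementedError("proposed API, not implemented in PyTorch")
```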
In the NVIDIA backend the active device is loaded directly from PyTorch: https://github.com/triton-lang/triton/blob/main/python/triton/backends/driver.py#L29
There is also a method for getting the current stream, which plays the role of the sycl::queue.
If we can retrieve both of those objects from PyTorch at the appropriate time, then it is not clear that we need the internal state we are storing in driver.c.
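For comparison, a minimal sketch of the analogous XPU-side hookup is below, assuming a PyTorch build where `torch.xpu` exposes `current_device`, `set_device`, and `current_stream`; extracting the raw `sycl::queue` handle from the returned stream object is exactly the piece that would need to come from PyTorch.

```python
# Minimal sketch mirroring how the NVIDIA GPUDriver wires these callbacks
# up from torch.cuda; this is not the actual XPU driver implementation.
# Assumes a PyTorch build with XPU support (torch.xpu is available).
import torch


class XPUDriverSketch:
    def __init__(self):
        self.get_current_device = torch.xpu.current_device
        self.set_current_device = torch.xpu.set_device
        # torch.xpu.current_stream(idx) returns a stream object; obtaining
        # the underlying sycl::queue handle from it is the part that would
        # need a PyTorch-provided accessor rather than Triton-side state.
        self.get_current_stream = lambda idx: torch.xpu.current_stream(idx)
```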
In driver.c, we enumerate the list of SYCL devices directly from the SYCL context and save it in an internal vector.
There may be an issue where IPEX uses a different indexing than Triton and therefore refers to a different SYCL device for the same index.