[Feature] WebGPU backend #1790

Open
zcbenz opened this issue Jan 24, 2025 · 0 comments

zcbenz commented Jan 24, 2025

I have been evaluating the possibility of using WebGPU as a new GPU backend for MLX, and I'm confident that it is entirely feasible.

What is WebGPU

WebGPU is a set of APIs and a new shading language that get translated to native graphics stacks (Metal, Vulkan, DirectX, OpenGL). This means we can write code once and get support for all major platforms.

While it has "web" in its name, the WebGPU project also provides C/C++ APIs that work across different implementations, and these can be used in MLX.

(Using WebGPU also has the bonus of supporting the web, i.e. compiling MLX to wasm and running it in browsers, though I won't cover that here.)

Approach to WebGPU in MLX

I'm very interested in making this happen, but I'm also aware that the MLX team does not have the resources to maintain a new backend or review all of the WebGPU code. So I propose splitting the work into separate layers: I write the WebGPU kernels in a separate project, while upstream MLX maintains a minimal backend implementation that uses them.
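
The layering could look roughly like the sketch below. It assumes a hypothetical `KernelLibrary` interface standing in for the external kernel project, and a hypothetical `MinimalGpuBackend` standing in for the thin code that would live in MLX; neither name comes from MLX or betann. `ReferenceKernels` runs on the CPU only so the example is runnable without a GPU.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical interface implemented by the separate kernel project.
// MLX would depend only on this boundary, not on any WGSL code.
class KernelLibrary {
 public:
  virtual ~KernelLibrary() = default;
  // Elementwise binary op named `op` on two equally sized float buffers.
  virtual std::vector<float> binary(const std::string& op,
                                    const std::vector<float>& a,
                                    const std::vector<float>& b) = 0;
};

// CPU stand-in for a real WebGPU kernel library, used only for illustration.
class ReferenceKernels : public KernelLibrary {
 public:
  std::vector<float> binary(const std::string& op,
                            const std::vector<float>& a,
                            const std::vector<float>& b) override {
    assert(a.size() == b.size());
    std::vector<float> out(a.size());
    for (std::size_t i = 0; i < a.size(); ++i) {
      if (op == "add") {
        out[i] = a[i] + b[i];
      } else if (op == "multiply") {
        out[i] = a[i] * b[i];
      } else {
        throw std::invalid_argument("unsupported op: " + op);
      }
    }
    return out;
  }
};

// Hypothetical thin backend on the MLX side: it owns no kernel code and
// only forwards primitives to whichever KernelLibrary it was given.
class MinimalGpuBackend {
 public:
  explicit MinimalGpuBackend(std::unique_ptr<KernelLibrary> kernels)
      : kernels_(std::move(kernels)) {}

  std::vector<float> add(const std::vector<float>& a,
                         const std::vector<float>& b) {
    return kernels_->binary("add", a, b);
  }

  std::vector<float> multiply(const std::vector<float>& a,
                              const std::vector<float>& b) {
    return kernels_->binary("multiply", a, b);
  }

 private:
  std::unique_ptr<KernelLibrary> kernels_;
};
```

With this split, swapping `ReferenceKernels` for a WebGPU-backed implementation would not require touching the MLX-side class, which is the point of keeping the upstream part minimal.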

As a proof of concept, I have implemented the binary ops in WGSL (WebGPU Shading Language) with a wrapper on top of Dawn (a WebGPU implementation): https://github.com/frost-beta/betann, and added a naive backend in MLX: #1789.
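
To illustrate what a WGSL binary op involves, here is a minimal sketch of an elementwise add compute shader together with a CPU emulation of how it would be dispatched. This is an illustrative example under the assumption of a workgroup size of 64, not the actual kernel from betann; the names `kAddShaderWGSL`, `num_workgroups` and `emulate_add_dispatch` are hypothetical. Each invocation handles one element, and the bounds check covers the padding in the last workgroup.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Illustrative WGSL shader for elementwise add (hypothetical, not betann's).
constexpr const char* kAddShaderWGSL = R"(
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let i = gid.x;
  if (i < arrayLength(&c)) {
    c[i] = a[i] + b[i];
  }
}
)";

// Must match @workgroup_size in the shader above.
constexpr uint32_t kWorkgroupSize = 64;

// Number of workgroups to dispatch so every element gets one invocation.
uint32_t num_workgroups(uint32_t n) {
  return (n + kWorkgroupSize - 1) / kWorkgroupSize;
}

// CPU emulation of the dispatch: each (workgroup, local id) pair plays the
// role of one shader invocation with global_invocation_id.x == i.
std::vector<float> emulate_add_dispatch(const std::vector<float>& a,
                                        const std::vector<float>& b) {
  assert(a.size() == b.size());
  const auto n = static_cast<uint32_t>(a.size());
  std::vector<float> c(n);
  for (uint32_t wg = 0; wg < num_workgroups(n); ++wg) {
    for (uint32_t local = 0; local < kWorkgroupSize; ++local) {
      const uint32_t i = wg * kWorkgroupSize + local;
      if (i < n) {  // same bounds check as the shader
        c[i] = a[i] + b[i];
      }
    }
  }
  return c;
}
```

A real backend would compile `kAddShaderWGSL` into a compute pipeline and record a dispatch of `num_workgroups(n)` workgroups; the emulation only mirrors the indexing logic.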

What is needed from MLX

To support a WebGPU backend, a few changes are required in MLX:

  • The array class needs APIs to manage GPU data separately from CPU data, and to transfer data between the two.
  • There is a new array status: the array has been evaluated, but its data has not been read back to the CPU. In my PR I simply always load the data to the CPU after evaluation, but in production the transfer should only happen when the user needs the data.
  • Alongside the Metal API, there should be a new API for general GPU backends.
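
The first two points could be modeled roughly as below. This is a minimal sketch assuming a hypothetical `LazyArray` with three states and a `GpuBuffer` that simply wraps host memory in place of a real WebGPU buffer; none of these names are MLX APIs. The key behavior is that the GPU to CPU copy happens only on the first host read, not right after evaluation.

```cpp
#include <cassert>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical array states (illustrative, not MLX's actual enum).
enum class Status {
  Unevaluated,     // graph node, nothing computed yet
  EvaluatedOnGpu,  // result lives only in a GPU buffer
  AvailableOnCpu,  // result has been copied back to host memory
};

// Stand-in for a device buffer; here it is just host memory.
struct GpuBuffer {
  std::vector<float> contents;
};

class LazyArray {
 public:
  Status status() const { return status_; }

  // Called by the GPU backend after a kernel has written the result.
  void set_gpu_result(GpuBuffer buf) {
    gpu_ = std::move(buf);
    cpu_.reset();
    status_ = Status::EvaluatedOnGpu;
  }

  // Later kernels read the GPU copy directly, with no host transfer.
  const GpuBuffer& gpu_data() const {
    assert(gpu_.has_value());
    return *gpu_;
  }

  // User-facing read: copies GPU -> CPU only on first access.
  const std::vector<float>& cpu_data() {
    assert(status_ != Status::Unevaluated);
    if (!cpu_) {
      cpu_ = gpu_->contents;  // stands in for a buffer map / readback
      ++transfers_;
      status_ = Status::AvailableOnCpu;
    }
    return *cpu_;
  }

  int transfers() const { return transfers_; }

 private:
  Status status_ = Status::Unevaluated;
  std::optional<GpuBuffer> gpu_;
  std::optional<std::vector<float>> cpu_;
  int transfers_ = 0;
};
```

Chained GPU ops would only ever call `gpu_data()`, so intermediate results never leave the device; repeated `cpu_data()` calls reuse the cached host copy.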

It will take quite a while before I have written enough WebGPU kernels to make simple examples work, but I think it would be very helpful if MLX could start making its architecture compatible with general GPU backends. Even if the WebGPU backend does not happen in the end, that work will still help with adding other backends.
