Skip to content

Commit 078cdc3

Browse files
committed
update async usage
1 parent f190c3c commit 078cdc3

File tree

1 file changed

+38
-20
lines changed

1 file changed

+38
-20
lines changed

wgpu/backends/js_webgpu/__init__.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,6 @@ class GPU(classes.GPU):
136136
def __init__(self):
137137
self._internal = window.navigator.gpu # noqa: F821
138138

139-
def request_adapter_sync(self, **options):
140-
promise = self.request_adapter_async(**options)
141-
result = promise.sync_wait()
142-
return result
143-
144139
def request_adapter_async(self, loop=None, canvas=None, **options) -> GPUPromise["GPUAdapter"]:
145140
options = structs.RequestAdapterOptions(**options)
146141
js_options = to_js(options, eager_converter=simple_js_accessor)
@@ -160,6 +155,7 @@ def adapter_constructor(js_adapter):
160155

161156
# api diff not really useful, but needed for compatibility I guess?
162157
# because only the call to _async imports the auto module -> this method doesn't get overwritten... so nothing happens here.
158+
# TODO: most likely remove
163159
def enumerate_adapters_sync(self) -> list["GPUAdapter"]:
164160
print("we are in the jswebgpu backend, but unreachable :'(")
165161

@@ -215,17 +211,19 @@ def __init__(self, js_adapter, loop):
215211

216212
super().__init__(internal=internal, features=features, limits=py_limits, adapter_info=adapter_info, loop=loop)
217213

214+
def request_device_async(self, **kwargs) -> GPUPromise["GPUDevice"]:
215+
descriptor = structs.DeviceDescriptor(**kwargs)
216+
js_descriptor = to_js(descriptor, eager_converter=simple_js_accessor)
217+
js_device_promise = self._internal.requestDevice(js_descriptor)
218218

219+
label = kwargs.get("label", "")
220+
def device_constructor(js_device):
221+
# TODO: do we need to hand down a default_queue here?
222+
return GPUDevice(label, js_device, adapter=self)
219223

220-
def request_device_sync(self, **parameters):
221-
return run_sync(self.request_device_async(**parameters))
222-
# raise NotImplementedError("Cannot use sync API functions in JS.")
223-
224-
async def request_device_async(self, **parameters):
225-
label = parameters.get("label", "")
226-
js_device = await self._internal.requestDevice(**parameters)
227-
default_queue = parameters.get("default_queue", {})
228-
return GPUDevice(label, js_device, adapter=self)
224+
promise = GPUPromise("request_device", device_constructor, loop=self._loop)
225+
js_device_promise.then(promise._set_input)
226+
return promise
229227

230228

231229
class GPUDevice(classes.GPUDevice):
@@ -275,6 +273,7 @@ def create_buffer_with_data_(self, *, label="", data, usage: flags.BufferUsageFl
275273
data_size = (data.nbytes + 3) & ~3 # align to 4 bytes
276274

277275
# if it's a Descriptor you need the keywords
276+
# do we need to also need to modify the usages?
278277
js_buf = self._internal.createBuffer(label=label, size=data_size, usage=usage, mappedAtCreation=True)
279278
# print("created buffer", js_buf, dir(js_buf), js_buf.size)
280279
array_buf = js_buf.getMappedRange(0, data_size)
@@ -301,6 +300,20 @@ def create_compute_pipeline(self, **kwargs):
301300
label = kwargs.get("label", "")
302301
return GPUComputePipeline(label, js_cp, self)
303302

303+
# TODO: no example tests this!
304+
def create_compute_pipeline_async(self, **kwargs):
305+
descriptor = structs.ComputePipelineDescriptor(**kwargs)
306+
js_descriptor = to_js(descriptor, eager_converter=simple_js_accessor)
307+
js_promise = self._internal.createComputePipelineAsync(js_descriptor)
308+
309+
label = kwargs.get("label", "")
310+
def construct_compute_pipeline(js_cp):
311+
return GPUComputePipeline(label, js_cp, self)
312+
promise = GPUPromise("create_compute_pipeline", construct_compute_pipeline, loop=self._loop)
313+
js_promise.then(promise._set_input)
314+
315+
return promise
316+
304317
def create_bind_group(self, **kwargs) -> classes.GPUBindGroup:
305318
descriptor = structs.BindGroupDescriptor(**kwargs)
306319
js_descriptor = to_js(descriptor, eager_converter=simple_js_accessor)
@@ -404,12 +417,12 @@ def write_mapped(self, data, buffer_offset: int | None = None):
404417
array_buf = self._internal.getMappedRange(buffer_offset, size)
405418
Uint8Array.new(array_buf).assign(data)
406419

407-
def map_sync(self, mode=None, offset=0, size=None):
408-
return run_sync(self.map_async(mode, offset, size))
420+
def map_async(self, mode: flags.MapModeFlags | None, offset: int = 0, size: int | None = None) -> GPUPromise[None]:
421+
map_promise = self._internal.mapAsync(mode, offset, size)
409422

410-
async def map_async(self, mode: flags.MapModeFlags | None, offset: int = 0, size: int | None = None):
411-
res = await self._internal.mapAsync(mode, offset, size)
412-
return res
423+
promise = GPUPromise("buffer.map_async", lambda _: None, loop=self._device._loop)
424+
map_promise.then(promise._set_input) # presumably this signals via a none callback to nothing?
425+
return promise
413426

414427
def unmap(self):
415428
self._internal.unmap()
@@ -543,11 +556,12 @@ def read_buffer(self, buffer: GPUBuffer, buffer_offset: int=0, size: int | None
543556
)
544557

545558
js_encoder = device._internal.createCommandEncoder()
546-
# todo: somehow test if all the offset math is correct
559+
# TODO: somehow test if all the offset math is correct
547560
js_encoder.copyBufferToBuffer(buffer._internal, buffer_offset, js_temp_buffer, buffer_offset, data_length)
548561
self._internal.submit([js_encoder.finish()])
549562

550563
# best way to await the promise directly?
564+
# TODO: can we do more steps async before waiting?
551565
run_sync(js_temp_buffer.mapAsync(flags.MapMode.READ, 0, data_length))
552566
array_buf = js_temp_buffer.getMappedRange()
553567
res = array_buf.slice(0)
@@ -596,6 +610,10 @@ def write_buffer(
596610

597611
self._internal.writeBuffer(buffer._internal, buffer_offset, js_data, data_offset, size)
598612

613+
def on_submitted_work_done_async(self):
614+
# TODO: this could be interesting with promises now
615+
raise NotImplementedError("hope for some examples to test this against first")
616+
599617

600618

601619
class GPUTexture(classes.GPUTexture):

0 commit comments

Comments
 (0)