Note
Go to the end to download the full example code.
Compute Workgroups
A simple compute example demonstrating GPU workgroups and invocation IDs.
Each thread writes its global, local, and workgroup IDs into a storage buffer so the relationship between them can be inspected.
import numpy as np
import wgpu
# Dispatch geometry: invocations per workgroup, and number of workgroups.
workgroup_size = 4
workgroups = 3
total_threads = workgroups * workgroup_size

# Buffer sizing: every thread stores three uint32 values (its global, local
# and workgroup ID), and each uint32 takes four bytes.
output_elements = 3 * total_threads
output_bytes = 4 * output_elements
# compute shader (WGSL source, built with a Python f-string)
# - {workgroup_size} is substituted into the @workgroup_size attribute, so the
#   shader's workgroup width follows the Python-side configuration.
# - The doubled braces {{ and }} are f-string escapes for the literal braces
#   of the WGSL function body.
# - `out` is a read/write storage buffer of u32 at group 0, binding 0.
# - Each invocation writes three consecutive u32 values starting at index
#   global_id.x * 3: the x components of its global invocation ID, local
#   invocation ID and workgroup ID, in that order.
shader_source = f"""
@group(0) @binding(0)
var<storage, read_write> out: array<u32>;
@compute
@workgroup_size({workgroup_size}, 1, 1)
fn main(
@builtin(global_invocation_id) global_id : vec3<u32>,
@builtin(local_invocation_id) local_id : vec3<u32>,
@builtin(workgroup_id) wg_id : vec3<u32>,
) {{
let base: u32 = global_id.x * 3u;
out[base] = global_id.x;
out[base + 1] = local_id.x;
out[base + 2] = wg_id.x;
}}
"""
# Adapter and device: request an adapter with a preference for the
# high-performance GPU, then open a device on it (both calls are synchronous).
adapter = wgpu.gpu.request_adapter_sync(power_preference="high-performance")
device = adapter.request_device_sync()

# Storage buffer that receives the shader output.
# STORAGE lets it be bound as a storage buffer; COPY_SRC is presumably what
# allows its contents to be read back afterwards.
buffer_usage = wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_SRC
output_buffer = device.create_buffer(size=output_bytes, usage=buffer_usage)
# Compile the shader source into a module, then build a compute pipeline
# whose compute stage runs that module's "main" entry point.
shader_module = device.create_shader_module(code=shader_source)
compute_stage = dict(module=shader_module, entry_point="main")
compute_pipeline = device.create_compute_pipeline(
    layout="auto",
    compute=compute_stage,
)
# Bind group: expose the whole output buffer (offset 0, full size) at
# binding 0, using the layout of group 0 taken from the pipeline.
output_binding = {
    "binding": 0,
    "resource": {
        "buffer": output_buffer,
        "offset": 0,
        "size": output_buffer.size,
    },
}
bind_group = device.create_bind_group(
    layout=compute_pipeline.get_bind_group_layout(0),
    entries=[output_binding],
)
# Record a single compute pass: select the pipeline, attach the bind group
# at index 0, and dispatch `workgroups` workgroups along x (1 along y and z).
command_encoder = device.create_command_encoder()
compute_pass = command_encoder.begin_compute_pass()
compute_pass.set_pipeline(compute_pipeline)
compute_pass.set_bind_group(0, bind_group)
compute_pass.dispatch_workgroups(workgroups, 1, 1)
compute_pass.end()

# Finish recording and hand the resulting command buffer to the queue.
command_buffer = command_encoder.finish()
device.queue.submit([command_buffer])
# Results: read the output buffer back and reinterpret its bytes as uint32.
raw = device.queue.read_buffer(output_buffer)
values = np.frombuffer(raw, dtype=np.uint32)

print(
    f"Dispatched {workgroups} workgroup(s) of {workgroup_size} thread(s) each "
    f"({total_threads} threads total).\n"
)
print(f"{'Thread':>6} {'global_id':>9} {'local_id':>8} {'workgroup_id':>12} ")

# View the flat array as one row per thread: (global_id, local_id, workgroup_id).
rows = values.reshape(total_threads, 3)
for i, (global_id, local_id, workgroup_id) in enumerate(rows):
    print(f"{i:>6} {global_id:>9} {local_id:>8} {workgroup_id:>12}")

    # Verify the invocation ID relationships for this thread: the global ID is
    # the thread index, which splits into (workgroup, local) by workgroup_size.
    assert global_id == i
    assert local_id == i % workgroup_size
    assert workgroup_id == i // workgroup_size

print("Invocation ID mapping verified")
Gallery generated by Sphinx-Gallery
Interactive example
This uses Pyodide. If this does not work, your browser may not have sufficient support for wasm/pyodide/wgpu (check your browser dev console). Stdout (print statements) will also appear in the browser console.