Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ set(WEBGPU_SRCS
runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp
runtime/ops/rope/RotaryEmbedding.cpp
runtime/ops/prepack/Prepack.cpp
runtime/ops/view_copy/ViewCopy.cpp
runtime/ops/select/Select.cpp
runtime/ops/sigmoid/UnaryOp.cpp
runtime/ops/squeeze/Squeeze.cpp
runtime/ops/unsqueeze/Unsqueeze.cpp
)

add_library(webgpu_backend ${WEBGPU_SRCS})
Expand Down
20 changes: 20 additions & 0 deletions backends/webgpu/runtime/WebGPUGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,16 @@ void WebGPUGraph::execute() {
// One pass per dispatch: enforces storage RAW ordering across deps.
for (size_t i = 0; i < n; i++) {
const auto& dispatch = dispatches_[i];
if (dispatch.kind == WebGPUDispatch::Kind::Copy) {
wgpuCommandEncoderCopyBufferToBuffer(
encoder,
dispatch.copy_src,
0,
dispatch.copy_dst,
0,
dispatch.copy_nbytes);
continue;
}
WGPUComputePassDescriptor pass_desc = {};
#ifdef WGPU_BACKEND_ENABLE_PROFILING
// tw must outlive BeginComputePass (the descriptor points at it).
Expand Down Expand Up @@ -757,6 +767,16 @@ void WebGPUGraph::execute() {
wgpuDeviceCreateCommandEncoder(device_, &enc_desc);

for (size_t i = start; i < end; i++) {
if (dispatches_[i].kind == WebGPUDispatch::Kind::Copy) {
wgpuCommandEncoderCopyBufferToBuffer(
encoder,
dispatches_[i].copy_src,
0,
dispatches_[i].copy_dst,
0,
dispatches_[i].copy_nbytes);
continue;
}
WGPUComputePassDescriptor pass_desc = {};
WGPUComputePassEncoder pass =
wgpuCommandEncoderBeginComputePass(encoder, &pass_desc);
Expand Down
17 changes: 17 additions & 0 deletions backends/webgpu/runtime/WebGPUGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ struct WebGPUDispatch {
WGPUBindGroup bind_group = nullptr;
uint32_t workgroup_count_x = 1;
std::string kernel_name; // bench label
// DMA copy command; default Compute keeps existing positional inits valid.
enum class Kind { Compute, Copy };
Kind kind = Kind::Compute;
WGPUBuffer copy_src = nullptr;
WGPUBuffer copy_dst = nullptr;
size_t copy_nbytes = 0;
};

struct OutputCopy {
Expand Down Expand Up @@ -189,6 +195,17 @@ class WebGPUGraph {
dispatches_.push_back(dispatch);
}

// Queue a buffer-to-buffer DMA (e.g. a flat copy) on the dispatch list, so it
// is encoded at its position among the other recorded dispatches.
//   src / dst : source and destination buffers of the copy.
//   nbytes    : number of bytes to copy.
void add_buffer_copy(WGPUBuffer src, WGPUBuffer dst, size_t nbytes) {
  WebGPUDispatch copy_cmd;
  copy_cmd.kernel_name = "flat_copy";
  copy_cmd.kind = WebGPUDispatch::Kind::Copy;
  copy_cmd.copy_nbytes = nbytes;
  copy_cmd.copy_dst = dst;
  copy_cmd.copy_src = src;
  dispatches_.push_back(copy_cmd);
}

// Materialize a recorded prepack-routed constant into dst via one CPU->GPU
// transfer. Build-time only (the .pte bytes are freed after build()).
// Mirrors Vulkan prepack_standard.
Expand Down
184 changes: 184 additions & 0 deletions backends/webgpu/runtime/ops/select/Select.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/TensorMeta.h>
#include <executorch/backends/webgpu/runtime/ops/select/select_wgsl.h>

#include <webgpu/webgpu.h>

#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

namespace executorch::backends::webgpu {

namespace {

// Host-side mirror of the select shader's `params` uniform (binding 4).
// The whole struct, padding included, is uploaded and bound with
// sizeof(SelectParams).
struct SelectParams {
  uint32_t dim; // input dimension being selected (non-negative)
  uint32_t index; // position taken along `dim` (non-negative)
  // Unused tail bringing the struct to 16 bytes; presumably keeps the uniform
  // buffer size 16-byte aligned -- TODO confirm against make_uniform.
  uint32_t _pad[2];
};
// The bind-group entry size is derived from this struct; catch accidental
// layout changes at compile time instead of at GPU validation time.
static_assert(
    sizeof(SelectParams) == 16,
    "SelectParams must stay 16 bytes to match the uniform binding size");

// Reads a mandatory integer argument (dim or index) from the graph.
// Anything that is not a plain Int value (e.g. a SymInt) is rejected; unlike
// slice there is no Null default to fall back to.
//   id   : graph value id of the argument.
//   what : argument name appended to the error message.
// Throws std::runtime_error when the value is not an Int.
int64_t read_scalar(WebGPUGraph& graph, int id, const char* what) {
  const bool is_plain_int =
      graph.get_value_type(id) == WebGPUGraph::ValueType::Int;
  if (!is_plain_int) {
    throw std::runtime_error(
        std::string("select: dynamic/unsupported ") + what);
  }
  return graph.get_int(id);
}

// Builds and records the compute dispatch for aten.select_copy.int (fp32).
//
// args: [self, dim, index, out]; output rank = in rank - 1. A negative dim is
// wrapped once by the input rank, a negative index once by the size of the
// selected dim.
//
// Throws std::runtime_error when a tensor buffer is null, when dim/index is
// not a static Int or is out of range, when the output shape is not the input
// shape with `dim` removed, or when an operand is not fp32-sized.
// std::out_of_range propagates from args.at() if fewer than four args are
// supplied.
void select_impl(WebGPUGraph& graph, const std::vector<int>& args) {
  const int in_id = args.at(0);
  const int out_id = args.at(3);

  WGPUDevice device = graph.device();
  const auto& in_tensor = graph.get_tensor(in_id);
  const auto& out_tensor = graph.get_tensor(out_id);
  if (in_tensor.buffer == nullptr || out_tensor.buffer == nullptr) {
    throw std::runtime_error("select: null buffer binding");
  }

  const int in_ndim = static_cast<int>(in_tensor.dims.size());
  int64_t dim = read_scalar(graph, args.at(1), "dim");
  if (dim < 0) {
    dim += in_ndim;
  }
  if (dim < 0 || dim >= in_ndim) {
    throw std::runtime_error("select: dim out of range");
  }
  const int64_t in_size = in_tensor.dims[dim];
  int64_t index = read_scalar(graph, args.at(2), "index");
  if (index < 0) {
    index += in_size;
  }
  if (index < 0 || index >= in_size) {
    throw std::runtime_error("select: index out of range");
  }

  // The shader maps out dim od -> in dim (od if od < dim else od + 1) and
  // reads in_meta.strides at that position, so the output must be exactly the
  // input shape with `dim` removed. Reject anything else up front rather than
  // letting the kernel gather from the wrong offsets.
  const size_t sel_dim = static_cast<size_t>(dim);
  const size_t out_ndim = out_tensor.dims.size();
  if (out_ndim + 1 != static_cast<size_t>(in_ndim)) {
    throw std::runtime_error("select: output rank must be input rank - 1");
  }
  for (size_t od = 0; od < out_ndim; od++) {
    const size_t id = od < sel_dim ? od : od + 1;
    if (static_cast<int64_t>(out_tensor.dims[od]) !=
        static_cast<int64_t>(in_tensor.dims[id])) {
      throw std::runtime_error("select: output shape mismatch");
    }
  }

  TensorMeta out_meta;
  TensorMeta in_meta;
  fill_tensor_meta(out_tensor, &out_meta);
  fill_tensor_meta(in_tensor, &in_meta);
  // The kernel declares both storage buffers as array<f32>; a byte size that
  // is not numel * 4 means some other element type.
  if (out_tensor.nbytes !=
          static_cast<size_t>(out_meta.numel) * sizeof(float) ||
      in_tensor.nbytes != static_cast<size_t>(in_meta.numel) * sizeof(float)) {
    throw std::runtime_error("select: non-fp32 operand (nbytes != numel * 4)");
  }

  SelectParams params = {};
  params.dim = static_cast<uint32_t>(dim);
  params.index = static_cast<uint32_t>(index);

  // One invocation per output element.
  const uint32_t wg_size =
      utils::clamp_workgroup_size(device, kSelectWorkgroupSizeX);
  const uint32_t workgroup_count = utils::compute_1d_workgroup_count(
      device, out_meta.numel, wg_size, "select");

  // Feeds the shader's `override wg_size` pipeline constant.
  WGPUConstantEntry wg_size_constant = {};
  wg_size_constant.key = {"wg_size", WGPU_STRLEN};
  wg_size_constant.value = static_cast<double>(wg_size);

  WGPUBuffer out_meta_buf =
      utils::make_uniform(device, &out_meta, sizeof(TensorMeta));
  WGPUBuffer in_meta_buf =
      utils::make_uniform(device, &in_meta, sizeof(TensorMeta));
  WGPUBuffer params_buf =
      utils::make_uniform(device, &params, sizeof(SelectParams));
  graph.add_uniform_buffer_bytes(2 * sizeof(TensorMeta) + sizeof(SelectParams));

  WGPUShaderSourceWGSL wgsl_desc = {};
  wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
  wgsl_desc.code = {kSelectWGSL, WGPU_STRLEN};
  WGPUShaderModuleDescriptor shader_desc = {};
  shader_desc.nextInChain = &wgsl_desc.chain;
  WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);

  // Bind group: in, out (rw), out_meta, in_meta, params (3 uniforms).
  constexpr uint32_t kNumBindings = 5;
  const WGPUBufferBindingType binding_types[kNumBindings] = {
      WGPUBufferBindingType_ReadOnlyStorage,
      WGPUBufferBindingType_Storage,
      WGPUBufferBindingType_Uniform,
      WGPUBufferBindingType_Uniform,
      WGPUBufferBindingType_Uniform};
  WGPUBindGroupLayoutEntry entries[kNumBindings] = {};
  for (uint32_t b = 0; b < kNumBindings; b++) {
    entries[b].binding = b;
    entries[b].visibility = WGPUShaderStage_Compute;
    entries[b].buffer.type = binding_types[b];
  }

  WGPUBindGroupLayoutDescriptor bgl_desc = {};
  bgl_desc.entryCount = kNumBindings;
  bgl_desc.entries = entries;
  WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);

  WGPUPipelineLayoutDescriptor pl_desc = {};
  pl_desc.bindGroupLayoutCount = 1;
  pl_desc.bindGroupLayouts = &bgl;
  WGPUPipelineLayout pipeline_layout =
      wgpuDeviceCreatePipelineLayout(device, &pl_desc);

  WGPUComputePipelineDescriptor pipeline_desc = {};
  pipeline_desc.layout = pipeline_layout;
  pipeline_desc.compute.module = shader;
  pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
  pipeline_desc.compute.constantCount = 1;
  pipeline_desc.compute.constants = &wg_size_constant;
  WGPUComputePipeline pipeline =
      wgpuDeviceCreateComputePipeline(device, &pipeline_desc);

  WGPUBindGroupEntry bg_entries[kNumBindings] = {};
  bg_entries[0].binding = 0;
  bg_entries[0].buffer = in_tensor.buffer;
  bg_entries[0].size = in_tensor.nbytes;
  bg_entries[1].binding = 1;
  bg_entries[1].buffer = out_tensor.buffer;
  bg_entries[1].size = out_tensor.nbytes;
  bg_entries[2].binding = 2;
  bg_entries[2].buffer = out_meta_buf;
  bg_entries[2].size = sizeof(TensorMeta);
  bg_entries[3].binding = 3;
  bg_entries[3].buffer = in_meta_buf;
  bg_entries[3].size = sizeof(TensorMeta);
  bg_entries[4].binding = 4;
  bg_entries[4].buffer = params_buf;
  bg_entries[4].size = sizeof(SelectParams);

  WGPUBindGroupDescriptor bg_desc = {};
  bg_desc.layout = bgl;
  bg_desc.entryCount = kNumBindings;
  bg_desc.entries = bg_entries;
  WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

  // The graph receives the pipeline and bind group handles; they are not
  // released here.
  graph.add_dispatch({pipeline, bind_group, workgroup_count});

  wgpuShaderModuleRelease(shader);
  wgpuBindGroupLayoutRelease(bgl);
  wgpuPipelineLayoutRelease(pipeline_layout);
  // Drop our refs; the bind group keeps the uniforms alive until release.
  wgpuBufferRelease(out_meta_buf);
  wgpuBufferRelease(in_meta_buf);
  wgpuBufferRelease(params_buf);
}

} // namespace

// Registers select_impl as the handler for the aten.select_copy.int operator
// through the OperatorRegistry.h macros.
WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(aten.select_copy.int, select_impl);
}

} // namespace executorch::backends::webgpu
41 changes: 41 additions & 0 deletions backends/webgpu/runtime/ops/select/select.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct TensorMeta {
ndim: u32,
numel: u32,
sizes: vec4<u32>,
strides: vec4<u32>,
}
@group(0) @binding(2) var<uniform> out_meta: TensorMeta;
@group(0) @binding(3) var<uniform> in_meta: TensorMeta;

struct Params {
dim: u32,
index: u32,
}
@group(0) @binding(4) var<uniform> params: Params;

override wg_size: u32 = 64u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let out_bufi = gid.x;
if (out_bufi >= out_meta.numel) {
return;
}

// Gather: out dim od -> in dim (od if od < dim else od+1); sel dim = index.
var rem = out_bufi;
var in_bufi: u32 = params.index * in_meta.strides[params.dim];
for (var od: u32 = 0u; od < out_meta.ndim; od = od + 1u) {
let coord = rem / out_meta.strides[od];
rem = rem % out_meta.strides[od];
var id = od;
if (od >= params.dim) {
id = od + 1u;
}
in_bufi = in_bufi + coord * in_meta.strides[id];
}
output[out_bufi] = input[in_bufi];
}
65 changes: 65 additions & 0 deletions backends/webgpu/runtime/ops/select/select_wgsl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>

namespace executorch::backends::webgpu {

// @generated from select.wgsl - DO NOT EDIT.
// wgsl-sha256: 200cf5a8190045aa0562e782f01c1cfaf9681f30f679f5112ccc3d347a0ed8df
inline constexpr const char* kSelectWGSL = R"(
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct TensorMeta {
ndim: u32,
numel: u32,
sizes: vec4<u32>,
strides: vec4<u32>,
}
@group(0) @binding(2) var<uniform> out_meta: TensorMeta;
@group(0) @binding(3) var<uniform> in_meta: TensorMeta;

struct Params {
dim: u32,
index: u32,
}
@group(0) @binding(4) var<uniform> params: Params;

override wg_size: u32 = 64u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let out_bufi = gid.x;
if (out_bufi >= out_meta.numel) {
return;
}

// Gather: out dim od -> in dim (od if od < dim else od+1); sel dim = index.
var rem = out_bufi;
var in_bufi: u32 = params.index * in_meta.strides[params.dim];
for (var od: u32 = 0u; od < out_meta.ndim; od = od + 1u) {
let coord = rem / out_meta.strides[od];
rem = rem % out_meta.strides[od];
var id = od;
if (od >= params.dim) {
id = od + 1u;
}
in_bufi = in_bufi + coord * in_meta.strides[id];
}
output[out_bufi] = input[in_bufi];
}
)";

inline constexpr uint32_t kSelectWorkgroupSizeX = 64;
inline constexpr uint32_t kSelectWorkgroupSizeY = 1;
inline constexpr uint32_t kSelectWorkgroupSizeZ = 1;

} // namespace executorch::backends::webgpu
Loading
Loading