Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ set(WEBGPU_SRCS
runtime/ops/slice/Slice.cpp
runtime/ops/permute/Permute.cpp
runtime/ops/cat/Cat.cpp
runtime/ops/index/Index.cpp
)

add_library(webgpu_backend ${WEBGPU_SRCS})
Expand Down Expand Up @@ -193,4 +194,5 @@ if(EXECUTORCH_BUILD_WEBGPU_TEST)
target_compile_options(webgpu_op_test_util_test PRIVATE -fexceptions)
set_property(TARGET webgpu_op_test_util_test PROPERTY CXX_STANDARD 17)
endif()
add_webgpu_native_test(webgpu_index_test test/native/test_index.cpp)
endif()
189 changes: 189 additions & 0 deletions backends/webgpu/runtime/ops/index/Index.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/index/index_wgsl.h>

#include <webgpu/webgpu.h>

#include <cstdint>
#include <stdexcept>
#include <vector>

namespace executorch::backends::webgpu {

namespace {

struct IndexParams {
uint32_t numel;
uint32_t _pad[3]; // pad to 16 bytes
};

// aten.index.Tensor 1D-self gather out[i]=self[index[i]] (mirrors Vulkan).
void index_impl(WebGPUGraph& graph, const std::vector<int>& args) {
// args: [self, indices (Tensor?[] -> ValueList), out].
const int self_id = args.at(0);
const int list_id = args.at(1);
const int out_id = args.at(args.size() - 1);

if (graph.get_value_type(self_id) != WebGPUGraph::ValueType::Tensor) {
throw std::runtime_error("index: self arg is not a tensor");
}
if (graph.get_value_type(out_id) != WebGPUGraph::ValueType::Tensor) {
throw std::runtime_error("index: out arg is not a tensor");
}
if (graph.get_value_type(list_id) != WebGPUGraph::ValueType::ValueList) {
throw std::runtime_error("index: indices arg is not a ValueList");
}

// Exactly one non-Null index tensor (mirror Vulkan IndexTensor.cpp:67-69).
const std::vector<int>& ids = graph.get_value_list(list_id);
int index_id = -1;
for (int id : ids) {
if (graph.get_value_type(id) == WebGPUGraph::ValueType::Null) {
continue;
}
if (graph.get_value_type(id) != WebGPUGraph::ValueType::Tensor) {
throw std::runtime_error("index: index list element is not a tensor");
}
if (index_id != -1) {
throw std::runtime_error("index: expected exactly one index tensor");
}
index_id = id;
}
if (index_id == -1) {
throw std::runtime_error("index: no index tensor provided");
}

WGPUDevice device = graph.device();

const auto& self_tensor = graph.get_tensor(self_id);
const auto& index_tensor = graph.get_tensor(index_id);
const auto& out_tensor = graph.get_tensor(out_id);

if (self_tensor.buffer == nullptr || index_tensor.buffer == nullptr ||
out_tensor.buffer == nullptr) {
throw std::runtime_error("index: null buffer binding");
}
// 1D-self gather: the kernel flat-indexes self by a scalar; fail loud on a
// higher-rank self (mirrors Vulkan index_tensor_buffer's 1D-self contract).
if (self_tensor.dims.size() != 1) {
throw std::runtime_error("index: only 1D self is supported");
}

const size_t out_numel = out_tensor.nbytes / sizeof(float);
if (out_tensor.nbytes != out_numel * sizeof(float) ||
self_tensor.nbytes % sizeof(float) != 0) {
throw std::runtime_error("index: non-fp32 self/out (nbytes != numel * 4)");
}
// Index is the int32 downcast of the int64 advanced index (downcast_64_bit).
const size_t index_numel = index_tensor.nbytes / sizeof(int32_t);
if (index_tensor.nbytes != index_numel * sizeof(int32_t)) {
throw std::runtime_error("index: index buffer is not int32 (nbytes % 4)");
}
// out is one self element per index element (row_width == 1, 1D self).
if (out_numel != index_numel) {
throw std::runtime_error("index: out numel != index numel");
}

uint32_t num_elements = static_cast<uint32_t>(out_numel);
uint32_t wg_size = utils::clamp_workgroup_size(device, kIndexWorkgroupSizeX);
uint32_t workgroup_count =
utils::compute_1d_workgroup_count(device, num_elements, wg_size, "index");

WGPUConstantEntry wg_size_constant = {};
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
wg_size_constant.value = static_cast<double>(wg_size);

IndexParams params = {};
params.numel = num_elements;

WGPUBuffer uniform_buffer =
utils::make_uniform(device, &params, sizeof(IndexParams));
graph.add_uniform_buffer_bytes(sizeof(IndexParams));

WGPUShaderSourceWGSL wgsl_desc = {};
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
wgsl_desc.code = {kIndexWGSL, WGPU_STRLEN};
WGPUShaderModuleDescriptor shader_desc = {};
shader_desc.nextInChain = &wgsl_desc.chain;
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);

// self (read), out (read_write), index (read i32), params (uniform).
WGPUBindGroupLayoutEntry entries[4] = {};
entries[0].binding = 0;
entries[0].visibility = WGPUShaderStage_Compute;
entries[0].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
entries[1].binding = 1;
entries[1].visibility = WGPUShaderStage_Compute;
entries[1].buffer.type = WGPUBufferBindingType_Storage;
entries[2].binding = 2;
entries[2].visibility = WGPUShaderStage_Compute;
entries[2].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
entries[3].binding = 3;
entries[3].visibility = WGPUShaderStage_Compute;
entries[3].buffer.type = WGPUBufferBindingType_Uniform;

WGPUBindGroupLayoutDescriptor bgl_desc = {};
bgl_desc.entryCount = 4;
bgl_desc.entries = entries;
WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);

WGPUPipelineLayoutDescriptor pl_desc = {};
pl_desc.bindGroupLayoutCount = 1;
pl_desc.bindGroupLayouts = &bgl;
WGPUPipelineLayout pipeline_layout =
wgpuDeviceCreatePipelineLayout(device, &pl_desc);

WGPUComputePipelineDescriptor pipeline_desc = {};
pipeline_desc.layout = pipeline_layout;
pipeline_desc.compute.module = shader;
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
pipeline_desc.compute.constantCount = 1;
pipeline_desc.compute.constants = &wg_size_constant;
WGPUComputePipeline pipeline =
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);

WGPUBindGroupEntry bg_entries[4] = {};
bg_entries[0].binding = 0;
bg_entries[0].buffer = self_tensor.buffer;
bg_entries[0].size = self_tensor.nbytes;
bg_entries[1].binding = 1;
bg_entries[1].buffer = out_tensor.buffer;
bg_entries[1].size = out_tensor.nbytes;
bg_entries[2].binding = 2;
bg_entries[2].buffer = index_tensor.buffer;
bg_entries[2].size = index_tensor.nbytes;
bg_entries[3].binding = 3;
bg_entries[3].buffer = uniform_buffer;
bg_entries[3].size = sizeof(IndexParams);

WGPUBindGroupDescriptor bg_desc = {};
bg_desc.layout = bgl;
bg_desc.entryCount = 4;
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, workgroup_count});

wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
// The bind group keeps the uniform buffer alive until release.
wgpuBufferRelease(uniform_buffer);
}

} // namespace

WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(aten.index.Tensor, index_impl);
}

} // namespace executorch::backends::webgpu
22 changes: 22 additions & 0 deletions backends/webgpu/runtime/ops/index/index.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<storage, read> index: array<i32>;

struct Params {
numel: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

override wg_size: u32 = 64;

@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let out_bufi = gid.x;
if (out_bufi >= params.numel) {
return;
}

// 1D-self gather out[i]=self[index[i]] (mirrors Vulkan index_tensor_buffer.glsl).
let i = index[out_bufi];
output[out_bufi] = input[u32(i)];
}
46 changes: 46 additions & 0 deletions backends/webgpu/runtime/ops/index/index_wgsl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>

namespace executorch::backends::webgpu {

// @generated from index.wgsl - DO NOT EDIT.
// wgsl-sha256: daed48e60bfcf2b7420d277576d794137d3bff383aef4f68464c98c8a7235c8e
inline constexpr const char* kIndexWGSL = R"(
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<storage, read> index: array<i32>;

struct Params {
numel: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

override wg_size: u32 = 64;

@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let out_bufi = gid.x;
if (out_bufi >= params.numel) {
return;
}

// 1D-self gather out[i]=self[index[i]] (mirrors Vulkan index_tensor_buffer.glsl).
let i = index[out_bufi];
output[out_bufi] = input[u32(i)];
}
)";

inline constexpr uint32_t kIndexWorkgroupSizeX = 64;
inline constexpr uint32_t kIndexWorkgroupSizeY = 1;
inline constexpr uint32_t kIndexWorkgroupSizeZ = 1;

} // namespace executorch::backends::webgpu
7 changes: 7 additions & 0 deletions backends/webgpu/runtime/ops/view_copy/ViewCopy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,17 @@ void view_copy_impl(WebGPUGraph& graph, const std::vector<int>& args) {
add_flat_copy(graph, args.at(0), args.at(args.size() - 1));
}

// clone = flat copy; survives Vulkan RemoveRedundantOpsTransform in Llama 1B.
void clone_impl(WebGPUGraph& graph, const std::vector<int>& args) {
// args: [self, memory_format?, out]; out = last value-id.
add_flat_copy(graph, args.at(0), args.at(args.size() - 1));
}

} // namespace

WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(aten.view_copy.default, view_copy_impl);
WEBGPU_REGISTER_OP(aten.clone.default, clone_impl);
}

} // namespace executorch::backends::webgpu
12 changes: 11 additions & 1 deletion backends/webgpu/scripts/test_webgpu_native_ci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ DISPATCH_ORDER_DIR="/tmp/dispatch_order"
DISPATCH_ORDER_OK=1
UPDATE_CACHE_DIR="/tmp/update_cache"
UPDATE_CACHE_OK=1
INDEX_DIR="/tmp/index"
INDEX_OK=1
EMBEDDING_MODEL="/tmp/webgpu_embedding_q4gsw.pte"
EMBEDDING_INDICES="/tmp/webgpu_embedding_q4gsw_indices.bin"
EMBEDDING_GOLDEN="/tmp/webgpu_embedding_q4gsw_golden.bin"
Expand Down Expand Up @@ -104,6 +106,11 @@ export_update_cache_replay('${UPDATE_CACHE_DIR}')
export_update_cache_negative('${UPDATE_CACHE_DIR}')
" || { echo "WARN: update_cache export failed; skipping update_cache native test"; UPDATE_CACHE_OK=0; }

$PYTHON_EXECUTABLE -c "
from executorch.backends.webgpu.test.ops.index.test_index import export_all_index_models
export_all_index_models('${INDEX_DIR}')
" || { echo "WARN: index export failed; skipping index native test"; INDEX_OK=0; }

# Non-fatal: a failed sdpa export makes the required 4k/8k configs hard-fail in
# webgpu_native_test below (precise per-config error), so don't exit/mask here.
$PYTHON_EXECUTABLE -c "
Expand Down Expand Up @@ -136,7 +143,7 @@ cmake \
"${EXECUTORCH_ROOT}"

# ── Build + run every native test target that exists in this tree ────────────
TARGETS=(webgpu_native_test webgpu_dispatch_order_test webgpu_scratch_buffer_test webgpu_update_cache_test)
TARGETS=(webgpu_native_test webgpu_dispatch_order_test webgpu_scratch_buffer_test webgpu_update_cache_test webgpu_index_test)
BIN_DIR="${BUILD_DIR}/backends/webgpu"

# Which targets are defined depends on which diffs are landed (native_test +
Expand Down Expand Up @@ -201,6 +208,9 @@ fi
if [[ "${DISPATCH_ORDER_OK}" == "1" && -x "${BIN_DIR}/webgpu_dispatch_order_test" ]]; then
"${BIN_DIR}/webgpu_dispatch_order_test" "${DISPATCH_ORDER_DIR}"
fi
if [[ "${INDEX_OK}" == "1" && -x "${BIN_DIR}/webgpu_index_test" ]]; then
"${BIN_DIR}/webgpu_index_test" "${INDEX_DIR}"
fi
[[ -x "${BIN_DIR}/webgpu_scratch_buffer_test" ]] && "${BIN_DIR}/webgpu_scratch_buffer_test"

echo "=== WebGPU native tests on Dawn: all run targets passed ==="
Expand Down
13 changes: 13 additions & 0 deletions backends/webgpu/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@ python_unittest(
],
)

python_unittest(
name = "test_index",
srcs = [
"ops/index/test_index.py",
],
deps = [
"//caffe2:torch",
"//executorch/backends/vulkan/partitioner:vulkan_partitioner",
"//executorch/backends/vulkan:vulkan_preprocess",
"//executorch/exir:lib",
],
)

runtime.python_library(
name = "tester",
srcs = ["tester.py"],
Expand Down
Loading
Loading