From e924d074ba33732c91dc376d02e336fdc29cae1d Mon Sep 17 00:00:00 2001 From: Jake Stevens Date: Thu, 25 Jun 2026 12:10:39 -0700 Subject: [PATCH] Allow context-binary lowering to use edge dialect ops (#20518) Summary: Unblocks the QNN context-binary path from lowering through `to_edge` with `_use_edge_ops=True` (the default). Previously it was pinned to `EdgeCompileConfig(_use_edge_ops=False)` purely to keep the `qaisw` context-loader custom op's original name, because loader detection was name-based. Loader detection now goes through a single `is_context_loader_target()` helper that matches the op namespace (so it works on the edge-dialect wrapper), replacing the three name-dependent checks (`eval`, raw `.namespace`, substring). Differential Revision: D109598309 --- backends/qualcomm/builders/qnn_constants.py | 30 +++++ .../qualcomm/partition/qnn_partitioner.py | 4 +- backends/qualcomm/qnn_preprocess.py | 19 ++- backends/qualcomm/tests/test_passes.py | 116 +++++++++++++++++- backends/qualcomm/utils/utils.py | 11 +- 5 files changed, 162 insertions(+), 18 deletions(-) diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index d1f0d3fff00..4a05b217334 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -7,6 +7,10 @@ from dataclasses import dataclass from enum import IntEnum, unique +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.operator.convert import parse_qualified_opname, unwrap_op_overload +from torch._ops import OpOverload + QNN_OP_PACKAGE_NAME_QTI_AISW = "qti.aisw" # Below constants should be same as those in QNN headers. @@ -57,6 +61,32 @@ class OpContextLoader: meta_ctx_bin: str = "qnn_context_binary" +ContextLoaderTarget = EdgeOpOverload | OpOverload + + +def is_context_loader_target( + target: ContextLoaderTarget, + op_name: str | None = None, +) -> bool: + namespace, name = parse_qualified_opname( + str(unwrap_op_overload(target)._schema.name) + ) + if namespace != OpContextLoader.namespace: + return False + if op_name is None: + return True + return name == op_name + + +def is_context_loader_node(node: object, op_name: str | None = None) -> bool: + if getattr(node, "op", None) != "call_function": + return False + target = getattr(node, "target", None) + if not isinstance(target, (EdgeOpOverload, OpOverload)): + return False + return is_context_loader_target(target, op_name) + + @dataclass(init=False, frozen=True) class OpConv2d: op_name: str = "Conv2d" diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index ce48b8bd949..2e76d8c13a2 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -10,7 +10,7 @@ import torch from executorch.backends.qualcomm.builders import node_visitor_manager -from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader +from executorch.backends.qualcomm.builders.qnn_constants import is_context_loader_node from executorch.backends.qualcomm.qnn_preprocess import QnnBackend from executorch.backends.qualcomm.serialization.qc_schema_serialize import ( flatbuffer_to_option, @@ -95,7 +95,7 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool: if ( node.target in allow_list_operator # bypass if custom op appears - or OpContextLoader.namespace == node.target.namespace + or is_context_loader_node(node) # bypass dequantize op for parameters & buffers or node.meta.get(QCOM_BYPASS_NODE, False) ): diff --git a/backends/qualcomm/qnn_preprocess.py b/backends/qualcomm/qnn_preprocess.py index cbe96b5954a..0bdf36c1e00 100644 --- a/backends/qualcomm/qnn_preprocess.py +++ b/backends/qualcomm/qnn_preprocess.py @@ -13,7 +13,10 @@ get_qnn_pass_manager_cls, ) from executorch.backends.qualcomm.builders.node_visitor_manager import get_node_visitors -from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader +from executorch.backends.qualcomm.builders.qnn_constants import ( + is_context_loader_node, + OpContextLoader, +) from executorch.backends.qualcomm.partition.utils import generate_qnn_executorch_option from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, @@ -89,16 +92,12 @@ def _build_op_wrappers( f"For {node}, {node.op}:{node.target.__name__} " "is not supported in Qnn Delegate" ) - try: - context_loader_target = eval( - f"torch.ops.{OpContextLoader.namespace}.{node.target.__name__}", - globals().update(torch.__dict__), - ) - assert node.target == context_loader_target, err_msg - # if graph has context binary loader node, return directly + if ( + is_context_loader_node(node) + and OpContextLoader.meta_ctx_bin in node.meta + ): return node.meta[OpContextLoader.meta_ctx_bin] - except: - raise RuntimeError(err_msg) + raise RuntimeError(err_msg) elif node.op in [ "get_attr", diff --git a/backends/qualcomm/tests/test_passes.py b/backends/qualcomm/tests/test_passes.py index 1124b01d613..b1fbc5f7511 100644 --- a/backends/qualcomm/tests/test_passes.py +++ b/backends/qualcomm/tests/test_passes.py @@ -13,6 +13,13 @@ from executorch.backends.qualcomm._passes.qnn_pass_manager import ( get_qnn_pass_manager_cls, ) +from executorch.backends.qualcomm.builders.qnn_constants import ( + is_context_loader_node, + is_context_loader_target, + OpContextLoader, +) +from executorch.backends.qualcomm.partition.qnn_partitioner import QnnOperatorSupport +from executorch.backends.qualcomm.qnn_preprocess import QnnBackend from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import ( QcomChipset, @@ -28,13 +35,120 @@ generate_qnn_executorch_compiler_spec, to_edge_transform_and_lower_to_qnn, ) -from executorch.exir import to_edge +from executorch.exir import EdgeCompileConfig, to_edge from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY from executorch.exir.dialects._ops import ops as exir_ops +from torch.library import Library from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e class TestPasses(unittest.TestCase): + def test_context_loader_edge_op_is_delegated(self): + op_name = "ctx_loader_delegation" + graph_name = "forward" + ctx_bin = b"qnn_context_binary" + custom_op = Library(OpContextLoader.namespace, "FRAGMENT") + self.addCleanup(custom_op._destroy) + custom_op.define(f"{op_name}(Tensor[] inputs) -> Any") + + @torch.library.impl( + custom_op, op_name, dispatch_key="CompositeExplicitAutograd" + ) + def op_impl(inputs): + return (torch.zeros((1, 2), device="meta", dtype=inputs[0].dtype),) + + class Model(torch.nn.Module): + def forward(self, x): + return getattr( + getattr(torch.ops, OpContextLoader.namespace), op_name + ).default((x,)) + + exported_program = torch.export.export( + Model(), (torch.ones(1, 2),), strict=True + ) + edge_program_manager = to_edge( + {graph_name: exported_program}, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + + context_loader_nodes = [ + node + for node in edge_program_manager._edge_programs[graph_name].graph.nodes + if is_context_loader_node(node, op_name) + ] + self.assertEqual(1, len(context_loader_nodes)) + self.assertTrue(is_context_loader_node(context_loader_nodes[0])) + context_loader_nodes[0].meta[OpContextLoader.meta_ctx_bin] = ctx_bin + self.assertEqual( + ctx_bin, + context_loader_nodes[0].meta[OpContextLoader.meta_ctx_bin], + ) + + support = QnnOperatorSupport.__new__(QnnOperatorSupport) + support.phase = "QnnPartitioner" + self.assertTrue(support.is_node_supported(None, context_loader_nodes[0])) + + def test_is_context_loader_target_predicate(self): + op_name = "ctx_loader_predicate" + custom_op = Library(OpContextLoader.namespace, "FRAGMENT") + self.addCleanup(custom_op._destroy) + custom_op.define(f"{op_name}(Tensor[] inputs) -> Any") + + # Plain OpOverload in the context-loader namespace must match (the + # _op unwrap must not break the non-edge-dialect target case). + qaisw_op = getattr( + getattr(torch.ops, OpContextLoader.namespace), op_name + ).default + self.assertTrue(is_context_loader_target(qaisw_op, op_name)) + self.assertFalse(is_context_loader_target(qaisw_op, "different_op")) + + # Ops in other namespaces must not match, including an edge op + # (unwrapped via _op) whose namespace is not the loader's. + self.assertFalse(is_context_loader_target(torch.ops.aten.add.default)) + self.assertFalse(is_context_loader_target(exir_ops.edge.aten.add.Tensor)) + + def test_build_op_wrappers_returns_context_binary(self): + op_name = "ctx_loader_build" + graph_name = "forward" + ctx_bin = b"qnn_context_binary" + custom_op = Library(OpContextLoader.namespace, "FRAGMENT") + self.addCleanup(custom_op._destroy) + custom_op.define(f"{op_name}(Tensor[] inputs) -> Any") + + @torch.library.impl( + custom_op, op_name, dispatch_key="CompositeExplicitAutograd" + ) + def op_impl(inputs): + return (torch.zeros((1, 2), device="meta", dtype=inputs[0].dtype),) + + class Model(torch.nn.Module): + def forward(self, x): + return getattr( + getattr(torch.ops, OpContextLoader.namespace), op_name + ).default((x,)) + + exported_program = torch.export.export( + Model(), (torch.ones(1, 2),), strict=True + ) + edge_program = to_edge( + {graph_name: exported_program}, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + )._edge_programs[graph_name] + for node in edge_program.graph.nodes: + if is_context_loader_node(node, op_name): + node.meta[OpContextLoader.meta_ctx_bin] = ctx_bin + + # For a graph whose only op is the context-binary loader, _build_op_wrappers + # returns the stamped context binary directly, before any QNN compilation. + result = QnnBackend._build_op_wrappers( + edge_program, + enable_tensor_dump=False, + op_package_infos=[], + use_mha2sha=False, + backend_type=QnnExecuTorchBackendType.kHtpBackend, + ) + self.assertEqual(ctx_bin, result) + def _build_quantized_graph(self): """Build a quantized graph through AnnotateQuantAttrs + FoldQDQ.""" diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 16a071f8cf0..04ee7107141 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -26,7 +26,10 @@ QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP, ) -from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader +from executorch.backends.qualcomm.builders.qnn_constants import ( + is_context_loader_node, + OpContextLoader, +) from executorch.backends.qualcomm.partition.qnn_partitioner import ( generate_qnn_executorch_option, get_skip_decomp_table, @@ -959,15 +962,13 @@ def preprocess_binary(ctx_bin, compiler_specs): # temporarily remove the first parameter name. edge_prog_mgr = to_edge( {graph_name: bundle_prog["exported_program"]}, - # do not alter name for custom op - compile_config=EdgeCompileConfig(_use_edge_ops=False), + compile_config=EdgeCompileConfig(_check_ir_validity=False), ) # update meta with context binary for n in edge_prog_mgr._edge_programs[graph_name].graph.nodes: - if n.op == "call_function" and OpContextLoader.namespace in str(n.target): + if is_context_loader_node(n, op_name): n.meta[OpContextLoader.meta_ctx_bin] = ctx_bin - break bundle_prog["edge_program_manager"] = edge_prog_mgr.to_backend( QnnPartitioner(compiler_specs)