XNNPACK: Lift constant mul scalars for partitioning#20515
Conversation
XNNPACK supports the tensor overload for multiply, but plain aten.mul.Scalar is not selected by the XNNPACK partitioner. Adds a narrow scalar-lifting pass that rewrites aten.mul.Scalar to aten.mul.Tensor by registering the scalar as a small constant buffer. This avoids introducing an aten.full op while allowing the existing multiply lowering path to partition the op. Keep SDPA scale multipliers as aten.mul.Scalar so ConvertToSDPAPass can still recover the attention scale before replacing the pattern. Add test coverage for that guard. Allow the XNNPACK tester to pass transform passes through to to_edge_transform_and_lower. This keeps op tests on the same path as existing XNNPACK model tests that already use explicit transform passes. For DeIT Tiny, this removes 24 portable aten.mul.Scalar nodes and reduces delegate count from 62 to 50. In current local timing checks the latency impact is modest: about 1% faster on both Android SME2 and the aarch64 XNNPACK/KleidiAI NEON-class host runner. These are modest uplifts but may introduce more opportunities for improvements. Signed-off-by: Måns Nilsson <mans.nilsson@arm.com> Change-Id: I83b6ad53925edb72afdf0077b5dbb99b5d9c4648
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20515
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 4 New Failures, 4 Unrelated Failures, 1 Unclassified FailureAs of commit 2861e10 with merge base 82a605d ( NEW FAILURES - The following jobs have failed:
UNCLASSIFIED FAILURE - DrCI could not classify the following job because the workflow did not run on the merge base. The failure may be pre-existing on trunk or introduced by this PR:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@claude review this PR |
|
Claude finished @mansnils's task in 9m 47s —— View job PR Review: XNNPACK — Lift constant mul scalars for partitioningReviewed the diff against Correctness / design
Tester change
Minor
Tests
Nothing here is blocking on its own; #1 and #2 are the items I'd most want addressed (or explicitly reasoned about in a comment) before merge. |
There was a problem hiding this comment.
Pull request overview
This PR improves XNNPACK partitioning coverage by rewriting constant aten.mul.Scalar occurrences into aten.mul.Tensor with a lifted constant buffer, enabling existing multiply lowering to pick them up without introducing runtime aten.full. It also extends the XNNPACK test harness to allow passing explicit transform passes into to_edge_transform_and_lower, and adds tests to protect SDPA scale handling.
Changes:
- Add
LiftConstantScalarOperandsPassto lift constant scalar operands (currently focused onmul.Scalar) into small tensor buffers to enable XNNPACK partitioning. - Enable XNNPACK
Tester.to_edge_transform_and_lower(...)to threadtransform_passesthrough toexecutorch.exir.to_edge_transform_and_lower. - Add unit and op-test coverage for scalar lifting and for the SDPA scale guard behavior.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| backends/xnnpack/utils/configs.py | Adds the scalar-lifting pass to the default XNNPACK transform pass pipeline. |
| backends/xnnpack/test/tester/tester.py | Extends the XNNPACK tester stage to forward explicit transform_passes into to_edge_transform_and_lower. |
| backends/xnnpack/test/passes/test_lift_constant_scalar_operands_pass.py | New unit tests validating scalar lifting and guarding SDPA scale multipliers. |
| backends/xnnpack/test/ops/test_multiply.py | Adds an op test to ensure mul-scalar gets delegated via the transform-pass path. |
| backends/xnnpack/_passes/lift_constant_scalar_operands_pass.py | Implements the scalar-to-tensor lifting pass with an SDPA-specific preservation guard. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def __init__( | ||
| self, | ||
| partitioners: Optional[List[Partitioner]] = None, | ||
| edge_compile_config: Optional[EdgeCompileConfig] = None, | ||
| transform_passes: Optional[List[PassType]] = None, | ||
| ): |
| def to_edge_transform_and_lower( | ||
| self, | ||
| to_edge_and_lower_stage: Optional[BaseStages.ToEdgeTransformAndLower] = None, | ||
| generate_etrecord: bool = False, | ||
| *, | ||
| partitioners: Optional[List[Partitioner]] = None, | ||
| edge_compile_config: Optional[EdgeCompileConfig] = None, | ||
| transform_passes: Optional[List[PassType]] = None, | ||
| ): | ||
| if to_edge_and_lower_stage is None: | ||
| to_edge_and_lower_stage = ToEdgeTransformAndLower( | ||
| partitioners=partitioners, | ||
| edge_compile_config=edge_compile_config, | ||
| transform_passes=transform_passes, | ||
| ) | ||
| else: | ||
| if partitioners is not None: | ||
| to_edge_and_lower_stage.partitioners = partitioners | ||
| if edge_compile_config is not None: | ||
| to_edge_and_lower_stage.edge_compile_conf = edge_compile_config | ||
| return super().to_edge_transform_and_lower( | ||
| to_edge_and_lower_stage, | ||
| generate_etrecord=generate_etrecord, | ||
| ) |
| from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config | ||
| from executorch.exir import EdgeCompileConfig | ||
| from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower | ||
| from executorch.exir.backend.partitioner import Partitioner | ||
| from torch._export.pass_base import PassType | ||
| from torch.export import ExportedProgram | ||
| from torchao.quantization.pt2e.quantizer import Quantizer |
| torch.export.export(module, (torch.randn(2, 3),), strict=True), | ||
| compile_config=get_xnnpack_edge_compile_config(skip_dim_order=True), | ||
| ) | ||
| return edge.transform([LiftConstantScalarOperandsPass()]).exported_program() |
XNNPACK supports the tensor overload for multiply, but plain aten.mul.Scalar is not selected by the XNNPACK partitioner.
Adds a narrow scalar-lifting pass that rewrites aten.mul.Scalar to aten.mul.Tensor by registering the scalar as a small constant buffer. This avoids introducing an aten.full op while allowing the existing multiply lowering path to partition the op.
Keep SDPA scale multipliers as aten.mul.Scalar so ConvertToSDPAPass can still recover the attention scale before replacing the pattern. Add test coverage for that guard.
Allow the XNNPACK tester to pass transform passes through to to_edge_transform_and_lower. This keeps op tests on the same path as existing XNNPACK model tests that already use explicit transform passes.
For DeIT Tiny, this removes 24 portable aten.mul.Scalar nodes and reduces delegate count from 62 to 50. In current local timing checks the latency impact is modest: about 1% faster on both Android SME2 and the aarch64 XNNPACK/KleidiAI NEON-class host runner. These are modest uplifts but may introduce more opportunities for improvements.
cc @GregoryComer @digantdesai @cbilgin @freddan80 @per @zingo @oscarandersson8218 @Sebastian-Larsson @robell @rascani