diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 3508410509c..0080d77ab69 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -1220,6 +1220,7 @@ def set_global( quantization_config, node_finder, self.pattern_matcher ) self.global_config = quantization_config + self.shared_qspec_quantizer.global_config = quantization_config return self def set_node_target( diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index d4c2dfebdee..a59ccff87b1 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -480,6 +480,7 @@ class SharedQspecQuantizer(Quantizer, QuantizerReporterUser): def __init__(self, targets: Optional[list[Callable[..., object]]] = None) -> None: super().__init__() QuantizerReporterUser.__init__(self) + self.global_config: Optional[QuantizationConfig] = None if targets is None: self.targets = self.SHARED_QSPEC_OPS_DEFAULT self.support_config_path = ( @@ -551,10 +552,24 @@ def _append_input_qspec( return adjacent_qspecs.append(input_qspec) - def _get_shared_clique(self, root_node: Node) -> tuple[set[Node], list[Any]]: + def _is_quantized_io_boundary(self, node: Node) -> bool: + """Return True if node is a model input/output annotated by the + quantizer. + + Such a node sits on the quantized interface but its qspec is often + filtered out of shared-cluster propagation: a uint8 IO qspec is skipped + by _skip_shared_qspec_from_io, and an input-state placeholder may carry + an annotation with no output_qspec. Its presence still signals that the + cluster is on the quantized data path. 
+ + """ + return node.op in ("placeholder", "output") and self._is_annotated(node) + + def _get_shared_clique(self, root_node: Node) -> tuple[set[Node], list[Any], bool]: shared_nodes = set() bfs_queue = [root_node] adjacent_qspecs: list[Any] = [] + touches_quantized_io = False while bfs_queue: node = bfs_queue.pop(0) @@ -563,12 +578,14 @@ def _get_shared_clique(self, root_node: Node) -> tuple[set[Node], list[Any]]: for input_node in node.all_input_nodes: self._maybe_enqueue_shared_node(input_node, shared_nodes, bfs_queue) self._append_output_qspec(input_node, adjacent_qspecs) + touches_quantized_io |= self._is_quantized_io_boundary(input_node) for output_node in node.users.keys(): self._maybe_enqueue_shared_node(output_node, shared_nodes, bfs_queue) self._append_input_qspec(output_node, node, adjacent_qspecs) + touches_quantized_io |= self._is_quantized_io_boundary(output_node) - return shared_nodes, adjacent_qspecs + return shared_nodes, adjacent_qspecs, touches_quantized_io def _should_skip_while_shared_qspec(self, node: Node) -> bool: return node.target == torch.ops.higher_order.while_loop and bool( @@ -623,7 +640,25 @@ def _annotate_shared_cluster(self, root_node: Node) -> None: ) return - shared_nodes, adjacent_qspecs = self._get_shared_clique(root_node) + shared_nodes, adjacent_qspecs, touches_quantized_io = self._get_shared_clique( + root_node + ) + + # If there is no neighbor qspec to propagate but the cluster sits on the + # quantized I/O boundary (e.g. a state-passthrough cat whose only neighbors + # are a uint8 model input skipped by _skip_shared_qspec_from_io and an + # input-state placeholder with no output_qspec), initiate quantization from + # the global config rather than leaving the cluster in float. Without this, + # such clusters fall off the integer delegate onto CPU. 
+ if ( + not adjacent_qspecs + and touches_quantized_io + and self.global_config is not None + ): + global_input_qspec = self.global_config.get_input_act_qspec() + if global_input_qspec is not None: + adjacent_qspecs = [global_input_qspec] + node_order = {node: index for index, node in enumerate(root_node.graph.nodes)} ordered_nodes = sorted(shared_nodes, key=lambda node: node_order.get(node, 0)) diff --git a/backends/arm/test/quantizer/test_uint8_io_quantization.py b/backends/arm/test/quantizer/test_uint8_io_quantization.py index 7461ca85a6f..3b839dc01c0 100644 --- a/backends/arm/test/quantizer/test_uint8_io_quantization.py +++ b/backends/arm/test/quantizer/test_uint8_io_quantization.py @@ -6,11 +6,14 @@ import torch from executorch.backends.arm.quantizer import ( + get_symmetric_quantization_config, get_uint8_io_quantization_config, TOSAQuantizer, ) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import QuantizationPipeline +from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e +from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY class SimpleMLP(torch.nn.Module): @@ -24,6 +27,21 @@ def forward(self, x): return self.fc2(self.relu(self.fc1(x))) +class CloneAtIoBoundary(torch.nn.Module): + """Zero-arithmetic cluster whose only adjacent annotated neighbours are + uint8-annotated IO nodes (input placeholder + graph output). + + With set_global(int8) + set_io(uint8), both the placeholder and the output + node carry uint8 qspecs that _skip_shared_qspec_from_io filters out, leaving + adjacent_qspecs empty. Before the IO-boundary fallback fix in + SharedQspecQuantizer, this caused the cluster to stay in float. 
+ + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.clone(x) + + def test_uint8_io_quantization_config_tosa_INT_applies_to_io(): model = SimpleMLP().eval() test_data = (torch.rand(1, 4),) @@ -40,3 +58,39 @@ def test_uint8_io_quantization_config_tosa_INT_applies_to_io(): output_qspecs={io_config.output_activation: 1}, ) pipeline.run() + + +def test_io_boundary_shared_cluster_is_quantized(): + """Regression: a zero-arithmetic cluster adjacent only to uint8-annotated IO + nodes must be annotated with the global int8 qspec, not left in float. + + _skip_shared_qspec_from_io filters the uint8 qspec from IO nodes, so when + the cluster's only neighbours are such nodes adjacent_qspecs ends up empty. + The fix in SharedQspecQuantizer detects the IO-boundary via + _is_quantized_io_boundary and falls back to global_config.get_input_act_qspec(). + """ + model = CloneAtIoBoundary().eval() + test_data = (torch.rand(1, 4),) + compile_spec = common.get_tosa_compile_spec("TOSA-1.0+INT") + + quantizer = TOSAQuantizer(compile_spec, use_composable_quantizer=True) + quantizer.set_global(get_symmetric_quantization_config()) + quantizer.set_io(get_uint8_io_quantization_config()) + + exported = torch.export.export(model, test_data, strict=True) + prepared = prepare_pt2e(exported.module(), quantizer) + + clone_nodes = [ + n + for n in prepared.graph.nodes + if n.op == "call_function" and n.target == torch.ops.aten.clone.default + ] + assert len(clone_nodes) == 1, f"Expected 1 clone node, got {len(clone_nodes)}" + clone_node = clone_nodes[0] + + assert ( + Q_ANNOTATION_KEY in clone_node.meta + ), "clone node was not annotated — IO-boundary cluster stayed in float" + assert ( + clone_node.meta[Q_ANNOTATION_KEY].output_qspec is not None + ), "clone node has no output_qspec — IO-boundary cluster stayed in float"