Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/quantizer/arm_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,6 +1220,7 @@ def set_global(
quantization_config, node_finder, self.pattern_matcher
)
self.global_config = quantization_config
self.shared_qspec_quantizer.global_config = quantization_config
return self

def set_node_target(
Expand Down
41 changes: 38 additions & 3 deletions backends/arm/quantizer/arm_quantizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ class SharedQspecQuantizer(Quantizer, QuantizerReporterUser):
def __init__(self, targets: Optional[list[Callable[..., object]]] = None) -> None:
super().__init__()
QuantizerReporterUser.__init__(self)
self.global_config: Optional[QuantizationConfig] = None
if targets is None:
self.targets = self.SHARED_QSPEC_OPS_DEFAULT
self.support_config_path = (
Expand Down Expand Up @@ -551,10 +552,24 @@ def _append_input_qspec(
return
adjacent_qspecs.append(input_qspec)

def _get_shared_clique(self, root_node: Node) -> tuple[set[Node], list[Any]]:
def _is_quantized_io_boundary(self, node: Node) -> bool:
"""Return True if node is a model input/output annotated by the
quantizer.

Such a node sits on the quantized interface but its qspec is often
filtered out of shared-cluster propagation: a uint8 IO qspec is skipped
by _skip_shared_qspec_from_io, and an input-state placeholder may carry
an annotation with no output_qspec. Its presence still signals that the
cluster is on the quantized data path.

"""
return node.op in ("placeholder", "output") and self._is_annotated(node)

def _get_shared_clique(self, root_node: Node) -> tuple[set[Node], list[Any], bool]:
shared_nodes = set()
bfs_queue = [root_node]
adjacent_qspecs: list[Any] = []
touches_quantized_io = False

while bfs_queue:
node = bfs_queue.pop(0)
Expand All @@ -563,12 +578,14 @@ def _get_shared_clique(self, root_node: Node) -> tuple[set[Node], list[Any]]:
for input_node in node.all_input_nodes:
self._maybe_enqueue_shared_node(input_node, shared_nodes, bfs_queue)
self._append_output_qspec(input_node, adjacent_qspecs)
touches_quantized_io |= self._is_quantized_io_boundary(input_node)

for output_node in node.users.keys():
self._maybe_enqueue_shared_node(output_node, shared_nodes, bfs_queue)
self._append_input_qspec(output_node, node, adjacent_qspecs)
touches_quantized_io |= self._is_quantized_io_boundary(output_node)

return shared_nodes, adjacent_qspecs
return shared_nodes, adjacent_qspecs, touches_quantized_io

def _should_skip_while_shared_qspec(self, node: Node) -> bool:
return node.target == torch.ops.higher_order.while_loop and bool(
Expand Down Expand Up @@ -623,7 +640,25 @@ def _annotate_shared_cluster(self, root_node: Node) -> None:
)
return

shared_nodes, adjacent_qspecs = self._get_shared_clique(root_node)
shared_nodes, adjacent_qspecs, touches_quantized_io = self._get_shared_clique(
root_node
)

# If there is no neighbor qspec to propagate but the cluster sits on the
# quantized I/O boundary (e.g. a state-passthrough cat whose only neighbors
# are a uint8 model input skipped by _skip_shared_qspec_from_io and an
# input-state placeholder with no output_qspec), initiate quantization from
# the global config rather than leaving the cluster in float. Without this,
# such clusters fall off the integer delegate onto CPU.
if (
len(adjacent_qspecs) == 0
and touches_quantized_io
and self.global_config is not None
):
global_input_qspec = self.global_config.get_input_act_qspec()
if global_input_qspec is not None:
adjacent_qspecs = [global_input_qspec]

node_order = {node: index for index, node in enumerate(root_node.graph.nodes)}
ordered_nodes = sorted(shared_nodes, key=lambda node: node_order.get(node, 0))

Expand Down
54 changes: 54 additions & 0 deletions backends/arm/test/quantizer/test_uint8_io_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
import torch

from executorch.backends.arm.quantizer import (
get_symmetric_quantization_config,
get_uint8_io_quantization_config,
TOSAQuantizer,
)
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import QuantizationPipeline
from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY


class SimpleMLP(torch.nn.Module):
Expand All @@ -24,6 +27,21 @@ def forward(self, x):
return self.fc2(self.relu(self.fc1(x)))


class CloneAtIoBoundary(torch.nn.Module):
    """Model made of a single ``clone``: a cluster with no arithmetic in it.

    The only annotated neighbours of that cluster are the uint8-annotated IO
    nodes (the input placeholder and the graph output).

    Under set_global(int8) + set_io(uint8), the placeholder and the output
    node both carry uint8 qspecs, which _skip_shared_qspec_from_io filters
    out, so adjacent_qspecs ends up empty. Prior to the IO-boundary fallback
    in SharedQspecQuantizer, the cluster therefore stayed in float.

    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Kept as torch.clone so the graph contains exactly one clone op.
        cloned = torch.clone(x)
        return cloned


def test_uint8_io_quantization_config_tosa_INT_applies_to_io():
model = SimpleMLP().eval()
test_data = (torch.rand(1, 4),)
Expand All @@ -40,3 +58,39 @@ def test_uint8_io_quantization_config_tosa_INT_applies_to_io():
output_qspecs={io_config.output_activation: 1},
)
pipeline.run()


def test_io_boundary_shared_cluster_is_quantized():
    """Regression test for shared clusters that touch only uint8 IO nodes.

    A zero-arithmetic cluster whose only neighbours are uint8-annotated IO
    nodes has to receive the global int8 qspec instead of staying in float.

    _skip_shared_qspec_from_io drops the uint8 qspec of IO nodes, so for such
    a cluster adjacent_qspecs is empty. SharedQspecQuantizer recognises the
    IO boundary through _is_quantized_io_boundary and then falls back to
    global_config.get_input_act_qspec().
    """
    module = CloneAtIoBoundary().eval()
    example_inputs = (torch.rand(1, 4),)

    # Global int8 config, with the IO nodes overridden to uint8.
    tosa_quantizer = TOSAQuantizer(
        common.get_tosa_compile_spec("TOSA-1.0+INT"),
        use_composable_quantizer=True,
    )
    tosa_quantizer.set_global(get_symmetric_quantization_config())
    tosa_quantizer.set_io(get_uint8_io_quantization_config())

    exported_program = torch.export.export(module, example_inputs, strict=True)
    prepared_module = prepare_pt2e(exported_program.module(), tosa_quantizer)

    # Collect every aten.clone call in the prepared graph.
    clone_nodes = []
    for candidate in prepared_module.graph.nodes:
        if candidate.op != "call_function":
            continue
        if candidate.target == torch.ops.aten.clone.default:
            clone_nodes.append(candidate)

    assert len(clone_nodes) == 1, f"Expected 1 clone node, got {len(clone_nodes)}"
    clone_node = clone_nodes[0]

    assert (
        Q_ANNOTATION_KEY in clone_node.meta
    ), "clone node was not annotated — IO-boundary cluster stayed in float"
    annotation = clone_node.meta[Q_ANNOTATION_KEY]
    assert (
        annotation.output_qspec is not None
    ), "clone node has no output_qspec — IO-boundary cluster stayed in float"
Loading