diff --git a/backends/transforms/quantize_fused_convbn_bias_pass.py b/backends/transforms/quantize_fused_convbn_bias_pass.py
new file mode 100644
index 00000000000..f1c599e05ba
--- /dev/null
+++ b/backends/transforms/quantize_fused_convbn_bias_pass.py
@@ -0,0 +1,409 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from torch import fx
+from torch._export.utils import (
+    get_buffer,
+    get_lifted_tensor_constant,
+    get_param,
+    is_lifted_tensor_constant,
+    is_param,
+)
+from torch._guards import detect_fake_mode
+from torch.export.exported_program import InputKind, InputSpec, TensorArgument
+from torch.fx.passes.infra.pass_base import PassBase, PassResult
+
+
+# --- Shared meta helpers ---
+
+
+def _to_meta_val(graph, tensor, attach_constant=False):
+    """Return the ``meta["val"]`` value describing ``tensor`` for ``graph``.
+
+    The fake mode is detected from the ``meta["val"]`` entries of the graph's
+    placeholder nodes; placeholders without a ``"val"`` entry are ignored.
+
+    Args:
+        graph: The fx graph whose placeholders are inspected.
+        tensor: The real tensor to describe.
+        attach_constant: If True, store ``tensor`` on the fake value's
+            ``constant`` attribute.
+
+    Returns:
+        A fake tensor created with ``static_shapes=True`` when a fake mode is
+        detected, otherwise ``tensor`` itself.
+    """
+    fake_mode = detect_fake_mode(
+        tuple(
+            n.meta["val"]
+            for n in graph.nodes
+            if n.op == "placeholder" and "val" in n.meta
+        )
+    )
+    if fake_mode is None:
+        # No fake mode to convert into; fall back to the real tensor.
+        return tensor
+    val = fake_mode.from_tensor(tensor, static_shapes=True)
+    if attach_constant:
+        val.constant = tensor
+    return val
+
+
+# --- ExportedProgram param helpers ---
+
+
+def _set_param_ep(exported_program, node_or_name, tensor, insert_before=None):
+    """Set or create a parameter in an exported program.
+
+    If node_or_name is a Node, updates the existing parameter or constant value.
+    If node_or_name is a string, creates a new parameter placeholder.
+
+    Args:
+        exported_program: The ExportedProgram whose state_dict/constants,
+            graph signature and graph are updated in place.
+        node_or_name: Existing placeholder node to update, or the name of the
+            parameter to create.
+        tensor: The new tensor value.
+        insert_before: Unused; a new placeholder is always inserted before the
+            first existing placeholder. Kept for signature parity with
+            ``_set_param_gm``.
+
+    Returns:
+        The updated or newly created placeholder node.
+
+    Raises:
+        ValueError: If node_or_name is a Node that is neither a parameter nor
+            a lifted tensor constant.
+    """
+    graph = exported_program.graph_module.graph
+    signature = exported_program.graph_signature
+
+    if isinstance(node_or_name, fx.Node):
+        node = node_or_name
+        if node.name in signature.inputs_to_parameters:
+            name = signature.inputs_to_parameters[node.name]
+            exported_program.state_dict[name] = torch.nn.Parameter(
+                tensor, requires_grad=False
+            )
+        elif node.name in signature.inputs_to_lifted_tensor_constants:
+            name = signature.inputs_to_lifted_tensor_constants[node.name]
+            exported_program.constants[name] = tensor
+        else:
+            raise ValueError(
+                f"Node {node.name} is not a parameter or lifted tensor constant"
+            )
+        node.meta["val"] = _to_meta_val(graph, tensor, attach_constant=True)
+        return node
+
+    # Create a new parameter from string name
+    name = node_or_name
+    input_name = f"arg_{name}"
+    # inserting_before(None) inserts at the start of the graph, which covers a
+    # graph that has no placeholder yet.
+    first_placeholder = next((n for n in graph.nodes if n.op == "placeholder"), None)
+    with graph.inserting_before(first_placeholder):
+        new_placeholder = graph.placeholder(input_name)
+    # Index 0 mirrors the position of the new placeholder in the graph.
+    signature.input_specs.insert(
+        0,
+        InputSpec(
+            kind=InputKind.PARAMETER,
+            arg=TensorArgument(name=input_name),
+            target=name,
+            persistent=None,
+        ),
+    )
+    exported_program.state_dict[name] = torch.nn.Parameter(tensor, requires_grad=False)
+    new_placeholder.meta["val"] = _to_meta_val(graph, tensor, attach_constant=True)
+    return new_placeholder
+
+
+def _get_bias_tensor_ep(exported_program, bias_node):
+    """Extract bias tensor from parameter or lifted constant in an ExportedProgram.
+
+    Returns None if bias_node is neither a parameter nor a lifted tensor constant.
+    """
+    if is_param(exported_program, bias_node):
+        return get_param(exported_program, bias_node)
+    if is_lifted_tensor_constant(exported_program, bias_node):
+        return get_lifted_tensor_constant(exported_program, bias_node)
+    return None
+
+
+# --- GraphModule param helpers ---
+
+
+def _get_tensor_from_node(graph_module, node):
+    """Get tensor from a get_attr node on a GraphModule.
+
+    Returns None if node is None, is not a get_attr node, or its dotted target
+    cannot be resolved on graph_module.
+    """
+    if node is None or node.op != "get_attr":
+        return None
+    attr = graph_module
+    for atom in node.target.split("."):
+        if not hasattr(attr, atom):
+            return None
+        attr = getattr(attr, atom)
+    return attr
+
+
+def _set_param_gm(graph_module, node_or_name, tensor, insert_before=None):
+    """Set or create a parameter on a GraphModule using get_attr nodes.
+
+    If node_or_name is a Node, updates the existing parameter tensor.
+    If node_or_name is a string, creates a new get_attr node.
+
+    Args:
+        graph_module: The GraphModule that owns the parameter and the graph.
+        node_or_name: Existing get_attr node to update, or the name of the
+            parameter to register.
+        tensor: The new value; stored as a Parameter with requires_grad=False.
+        insert_before: Node before which a newly created get_attr node is
+            inserted. Only used when node_or_name is a string.
+
+    Returns:
+        The updated or newly created get_attr node.
+    """
+    if isinstance(node_or_name, fx.Node):
+        node = node_or_name
+        target_atoms = node.target.split(".")
+        parent = graph_module
+        for atom in target_atoms[:-1]:
+            parent = getattr(parent, atom)
+        setattr(
+            parent,
+            target_atoms[-1],
+            torch.nn.Parameter(tensor, requires_grad=False),
+        )
+        # Only refresh the metadata if the node already carried it.
+        if "val" in node.meta:
+            node.meta["val"] = _to_meta_val(graph_module.graph, tensor)
+        return node
+
+    # Create new get_attr node
+    name = node_or_name
+    graph_module.register_parameter(
+        name, torch.nn.Parameter(tensor, requires_grad=False)
+    )
+    with graph_module.graph.inserting_before(insert_before):
+        new_node = graph_module.graph.get_attr(name)
+    new_node.meta["val"] = _to_meta_val(graph_module.graph, tensor)
+    return new_node
+
+
+# --- Shared core logic ---
+
+
+def _quantize_fused_conv_bias(
+    graph_module,
+    conv_targets,
+    unsqueeze_targets,
+    dq_per_tensor,
+    dq_per_channel,
+    get_bias_tensor,
+    set_param,
+    get_weight_scale_tensor,
+    default_zero_bias=False,
+):
+    """Core logic for quantizing biases introduced by BatchNorm fusion/QAT.
+
+    BatchNorm fusion or QAT introduces a bias to conv layers that originally had
+    bias=False. Since the bias is added after the quantizer runs, it lacks proper
+    quantize->dequantize nodes. This function adds them.
+
+    Args:
+        graph_module: The graph module to transform.
+        conv_targets: Tuple of conv op targets to match.
+        unsqueeze_targets: Tuple of unsqueeze op targets to unwrap.
+        dq_per_tensor: The dequantize_per_tensor op for this dialect.
+        dq_per_channel: The dequantize_per_channel op for this dialect.
+        get_bias_tensor: Callable(node) -> Optional[Tensor].
+        set_param: Callable(node_or_name, tensor, insert_before=None) -> Node.
+        get_weight_scale_tensor: Callable(node) -> Tensor.
+        default_zero_bias: If True, create zero bias for conv nodes without bias.
+
+    Returns:
+        True if any modifications were made.
+
+    Raises:
+        AssertionError: If a matched conv's input (after unwrapping one
+            unsqueeze) is not produced by dq_per_tensor.
+    """
+    modified = False
+    for node in graph_module.graph.nodes:
+        if node.target not in conv_targets:
+            continue
+
+        input_dequant = node.args[0]
+        weight_dequant = node.args[1]
+        bias_node = node.args[2] if len(node.args) > 2 else None
+
+        if bias_node is None:
+            if not default_zero_bias:
+                continue
+            # Dimension 1 of the conv output value is used as the bias length.
+            channel = node.meta["val"].shape[1]
+            bias_node = set_param(
+                node.name + "_default_zero_bias",
+                torch.zeros(channel),
+                insert_before=node,
+            )
+            args = list(node.args)
+            if len(args) < 3:
+                args.append(bias_node)
+            else:
+                args[2] = bias_node
+            node.args = tuple(args)
+
+        bias = get_bias_tensor(bias_node)
+        # An int32 bias is skipped; presumably it has already been quantized.
+        if bias is None or bias.dtype == torch.int32:
+            continue
+
+        if input_dequant.target in unsqueeze_targets:
+            input_dequant = input_dequant.args[0]
+
+        assert (
+            input_dequant.target == dq_per_tensor
+        ), f"Expected dequantize_per_tensor, got {input_dequant.target}"
+
+        # Captured before the bias node is re-set with the int32 tensor below.
+        bias_val = bias_node.meta.get("val")
+        dequant_val = (
+            bias_val.to(torch.float32)
+            if bias_val is not None
+            else torch.empty(bias.shape, dtype=torch.float32)
+        )
+
+        if isinstance(weight_dequant.args[1], torch.fx.node.Node):
+            # Per-channel: the weight scale is a tensor referenced by a node.
+            weight_scale = get_weight_scale_tensor(weight_dequant.args[1])
+            bias_scale = input_dequant.args[1] * weight_scale
+
+            bias_zp = torch.zeros(bias_scale.shape, dtype=torch.int32)
+            qbias = torch.ops.quantized_decomposed.quantize_per_channel.default(
+                bias,
+                bias_scale,
+                bias_zp,
+                0,
+                -(2**31),
+                2**31 - 1,
+                torch.int32,
+            )
+            set_param(bias_node, qbias)
+
+            scale_node = set_param(
+                node.name + "_bias_scale", bias_scale, insert_before=node
+            )
+            zp_node = set_param(
+                node.name + "_bias_zero_point", bias_zp, insert_before=node
+            )
+
+            with graph_module.graph.inserting_before(node):
+                bias_dequant = graph_module.graph.call_function(
+                    dq_per_channel,
+                    (
+                        bias_node,
+                        scale_node,
+                        zp_node,
+                        0,
+                        -(2**31),
+                        2**31 - 1,
+                        torch.int32,
+                    ),
+                )
+                bias_dequant.meta["val"] = dequant_val
+                node.replace_input_with(bias_node, bias_dequant)
+        else:
+            # Per-tensor: the weight scale is stored inline in the dequant args.
+            weight_scale = weight_dequant.args[1]
+            bias_scale = input_dequant.args[1] * weight_scale
+
+            qbias = torch.ops.quantized_decomposed.quantize_per_tensor.default(
+                bias, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32
+            )
+            set_param(bias_node, qbias)
+
+            with graph_module.graph.inserting_before(node):
+                bias_dequant = graph_module.graph.call_function(
+                    dq_per_tensor,
+                    (bias_node, bias_scale, 0, -(2**31), 2**31 - 1, torch.int32),
+                )
+                bias_dequant.meta["val"] = dequant_val
+                node.replace_input_with(bias_node, bias_dequant)
+
+        modified = True
+
+    graph_module.recompile()
+    return modified
+
+
+class QuantizeFusedConvBnBiasAtenPass(PassBase):
+    """Quantize biases introduced by BatchNorm fusion/QAT on aten dialect graphs.
+
+    Operates on a GraphModule. If the graph_module came from an ExportedProgram
+    (params are placeholder nodes), pass the exported_program so params can be
+    resolved. If operating on a plain GraphModule (params are get_attr nodes),
+    exported_program can be omitted.
+    """
+
+    def __init__(self, exported_program=None, default_zero_bias=False) -> None:
+        super().__init__()
+        self.exported_program = exported_program
+        self.default_zero_bias = default_zero_bias
+
+    def call(self, graph_module: fx.GraphModule) -> PassResult:
+        """Run the pass on graph_module and report whether it was modified."""
+        ep = self.exported_program
+        if ep is not None:
+
+            def get_bias(node):
+                return _get_bias_tensor_ep(ep, node)
+
+            def set_param(n, t, insert_before=None):
+                # New placeholders always go to the front of the graph, so
+                # insert_before is not forwarded.
+                return _set_param_ep(ep, n, t)
+
+            def get_scale(node):
+                # NOTE(review): only buffers are resolved here; confirm that
+                # per-channel scales are never exported as params/constants.
+                return get_buffer(ep, node)
+
+        else:
+
+            def get_bias(node):
+                return _get_tensor_from_node(graph_module, node)
+
+            def set_param(n, t, insert_before=None):
+                return _set_param_gm(graph_module, n, t, insert_before)
+
+            def get_scale(node):
+                return _get_tensor_from_node(graph_module, node)
+
+        modified = _quantize_fused_conv_bias(
+            graph_module,
+            conv_targets=(
+                torch.ops.aten.convolution.default,
+                torch.ops.aten.conv2d.default,
+                torch.ops.aten.conv_transpose2d.input,
+            ),
+            unsqueeze_targets=(
+                torch.ops.aten.unsqueeze_copy.default,
+                torch.ops.aten.unsqueeze.default,
+            ),
+            dq_per_tensor=torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            dq_per_channel=torch.ops.quantized_decomposed.dequantize_per_channel.default,
+            get_bias_tensor=get_bias,
+            set_param=set_param,
+            get_weight_scale_tensor=get_scale,
+            default_zero_bias=self.default_zero_bias,
+        )
+        return PassResult(graph_module, modified)
diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl
index 3cda58f6426..463c89e43b2 100644
--- a/backends/transforms/targets.bzl
+++ b/backends/transforms/targets.bzl
@@ -204,6 +204,31 @@ def define_common_targets():
         ],
     )
 
+    runtime.python_library(
+        name = "quantize_fused_convbn_bias_pass",
+        srcs = ["quantize_fused_convbn_bias_pass.py"],
+        visibility = ["PUBLIC"],
+        deps = [
+            "//caffe2:torch",
+        ],
+    )
+
+    runtime.python_test(
+        name = "test_quantize_fused_convbn_bias_pass",
+        srcs = [
+            "test/test_quantize_fused_convbn_bias_pass.py",
+        ],
+        deps = [
+            "//caffe2:torch",
+            ":quantize_fused_convbn_bias_pass",
+            "//executorch/backends/arm/quantizer:lib",
+            "//executorch/backends/arm/test:common",
+            "//executorch/backends/arm/tosa:tosa",
+            "//executorch/kernels/quantized:custom_ops_generated_lib",
+            "fbsource//third-party/pypi/pytest:pytest",
+        ],
+    )
+
     runtime.python_test(
         name = "test_duplicate_dynamic_quant_chain",
         srcs = [
diff --git a/backends/transforms/test/test_quantize_fused_convbn_bias_pass.py b/backends/transforms/test/test_quantize_fused_convbn_bias_pass.py
new file mode 100644
index 00000000000..f8d0269630b
--- /dev/null
+++ b/backends/transforms/test/test_quantize_fused_convbn_bias_pass.py
@@ -0,0 +1,186 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+from typing import Tuple
+from unittest.mock import MagicMock
+
+# Stub modules that are transitively imported by arm_quantizer but never
+# exercised by these tests.
+for _mod in ("tosa_serializer", "tosa", "tosa.TosaGraph"):
+    if _mod not in sys.modules:
+        try:
+            __import__(_mod)
+        except ModuleNotFoundError:
+            sys.modules[_mod] = MagicMock()
+
+import pytest
+import torch
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.tosa import TosaSpecification
+from executorch.backends.transforms.quantize_fused_convbn_bias_pass import (
+    QuantizeFusedConvBnBiasAtenPass,
+)
+from torch import nn
+from torch.export import export
+from torchao.quantization.pt2e import move_exported_model_to_eval
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
+
+
+input_t = Tuple[torch.Tensor]
+
+
+class ConvBnNoBias(nn.Module):
+    """Conv2d(bias=False) + BatchNorm2d; QAT fusion adds a bias. ``per_channel`` is unused."""
+
+    def __init__(self, per_channel: bool = True) -> None:
+        super().__init__()
+        self.conv = nn.Conv2d(3, 16, kernel_size=3, bias=False)
+        self.bn = nn.BatchNorm2d(16)
+
+    def get_inputs(self) -> input_t:
+        return (torch.randn(1, 3, 32, 32),)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.bn(self.conv(x))
+
+
+class ConvBnReluNoBias(nn.Module):
+    """Conv2d with bias=False, followed by BatchNorm2d and ReLU."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = nn.Conv2d(3, 16, kernel_size=3, bias=False)
+        self.bn = nn.BatchNorm2d(16)
+        self.relu = nn.ReLU()
+
+    def get_inputs(self) -> input_t:
+        return (torch.randn(1, 3, 32, 32),)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.relu(self.bn(self.conv(x)))
+
+
+class Conv1dBnNoBias(nn.Module):
+    """Conv1d with bias=False followed by BatchNorm1d; not used by the tests visible here."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = nn.Conv1d(3, 8, kernel_size=3, bias=False)
+        self.bn = nn.BatchNorm1d(8)
+
+    def get_inputs(self) -> input_t:
+        return (torch.randn(2, 3, 16),)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.bn(self.conv(x))
+
+
+class ConvNoBnNoBias(nn.Module):
+    """Conv2d with bias=False and no BatchNorm. Pass should skip this."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = nn.Conv2d(3, 16, kernel_size=3, bias=False)
+
+    def get_inputs(self) -> input_t:
+        return (torch.randn(1, 3, 32, 32),)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.conv(x)
+
+
+# --- Shared helpers ---
+
+
+def _qat_prepare_convert(model, per_channel):
+    """Export, QAT-prepare, run one forward pass, then convert_pt2e; returns the result."""
+    quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
+    quantizer.set_global(
+        get_symmetric_quantization_config(is_qat=True, is_per_channel=per_channel)
+    )
+    example_input = model.get_inputs()
+    exported = export(model, example_input, strict=True).module()
+    prepared = prepare_qat_pt2e(exported, quantizer)
+    prepared(*example_input)
+    move_exported_model_to_eval(prepared)
+    converted = convert_pt2e(prepared)
+    return converted
+
+
+def _assert_bias_dequantized(graph, conv_targets, dequant_targets):
+    """Assert each matched conv has a dequantized bias and at least one conv matched."""
+    conv_count = 0
+    for node in graph.nodes:
+        if node.target not in conv_targets:
+            continue
+        conv_count += 1
+        bias = node.args[2]
+        assert bias is not None, "Bias should not be None after pass"
+        assert (
+            bias.target in dequant_targets
+        ), f"Bias should be dequantized, got {bias.target}"
+    assert conv_count > 0, "Expected at least one convolution node"
+
+
+# --- Direct aten pass tests ---
+
+_aten_conv_targets = (
+    torch.ops.aten.convolution.default,
+    torch.ops.aten.conv2d.default,
+)
+_aten_dequant_targets = (
+    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+    torch.ops.quantized_decomposed.dequantize_per_channel.default,
+)
+
+_aten_direct_models = [
+    pytest.param((ConvBnNoBias, True), id="conv2d_bn_per_channel"),
+    pytest.param((ConvBnNoBias, False), id="conv2d_bn_per_tensor"),
+    pytest.param((ConvBnReluNoBias, True), id="conv2d_bn_relu_per_channel"),
+    pytest.param((ConvBnReluNoBias, False), id="conv2d_bn_relu_per_tensor"),
+]
+
+
+@pytest.mark.parametrize("test_data", _aten_direct_models)
+def test_aten_pass_direct(test_data) -> None:
+    """QuantizeFusedConvBnBiasAtenPass on GraphModule (get_attr nodes, no EP)."""
+    model_cls, per_channel = test_data
+    gm = _qat_prepare_convert(model_cls(), per_channel)
+    QuantizeFusedConvBnBiasAtenPass()(gm)
+    _assert_bias_dequantized(gm.graph, _aten_conv_targets, _aten_dequant_targets)
+
+
+@pytest.mark.parametrize("test_data", _aten_direct_models)
+def test_aten_pass_with_exported_program(test_data) -> None:
+    """QuantizeFusedConvBnBiasAtenPass on graph_module from EP (placeholder nodes)."""
+    model_cls, per_channel = test_data
+    model = model_cls()
+    gm = _qat_prepare_convert(model, per_channel)
+    ep = export(gm, model.get_inputs(), strict=True)
+    QuantizeFusedConvBnBiasAtenPass(ep)(ep.graph_module)
+    _assert_bias_dequantized(
+        ep.graph_module.graph, _aten_conv_targets, _aten_dequant_targets
+    )
+
+
+def test_aten_pass_idempotent() -> None:
+    """Running the pass twice still leaves every conv bias behind a dequantize node."""
+    model = ConvBnNoBias()
+    gm = _qat_prepare_convert(model, per_channel=True)
+    QuantizeFusedConvBnBiasAtenPass()(gm)
+    QuantizeFusedConvBnBiasAtenPass()(gm)
+    _assert_bias_dequantized(gm.graph, _aten_conv_targets, _aten_dequant_targets)
+
+
+def test_aten_pass_biasless_conv_no_bn() -> None:
+    """Pass skips biasless conv without following BatchNorm."""
+    model = ConvNoBnNoBias()
+    gm = _qat_prepare_convert(model, per_channel=True)
+    QuantizeFusedConvBnBiasAtenPass()(gm)