diff --git a/backends/nxp/BUCK b/backends/nxp/BUCK index a701a661b61..03fedf7ec7a 100644 --- a/backends/nxp/BUCK +++ b/backends/nxp/BUCK @@ -56,6 +56,7 @@ fbcode_target(_kind = runtime.python_library, deps = [ ":aten_passes", "//caffe2:torch", + "//executorch/backends/transforms:quantize_fused_convbn_bias_pass", "//pytorch/ao:torchao", # @manual ], ) diff --git a/backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/remove_simulated_linear_bn_fusion_qat_pass.py b/backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/remove_simulated_linear_bn_fusion_qat_pass.py index bf9b3ad86ac..061862116e6 100644 --- a/backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/remove_simulated_linear_bn_fusion_qat_pass.py +++ b/backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/remove_simulated_linear_bn_fusion_qat_pass.py @@ -25,6 +25,41 @@ _is_reshape = partial(is_op_node, target_op=torch.ops.aten.reshape) _is_zeros_like = partial(is_op_node, target_op=torch.ops.aten.zeros_like) +_CONV_TARGETS = { + torch.ops.aten.conv1d.default, + torch.ops.aten.conv1d.padding, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv2d.padding, + torch.ops.aten.conv_transpose1d.default, + torch.ops.aten.conv_transpose2d.input, +} + + +def _feeds_into_linear(node: Node) -> bool: + """ + BFS from node to check if it eventually feeds into a linear op (not conv). + This is required because: + - Linear-BN fusion (added by AddSimulatedLinearBatchNormFusionQATPass, NXP-specific) + - Conv-BN QAT fusion (added by TorchAO's _fuse_conv_bn_qat inside prepare_qat_pt2e) + are structurally identical. Without this check, we would incorrectly remove + Conv-BN scale factor chains, breaking Conv-BN QAT fusion when TorchAO's _fold_conv_bn_qat + is called during convert_pt2e. + """ + visited = set() + queue = list(node.users.keys()) + while queue: + n = queue.pop(0) + if n in visited: + continue + visited.add(n) + if n.op == "call_function": + if n.target == torch.ops.aten.linear.default: + return True + if n.target in _CONV_TARGETS: + return False + queue.extend(n.users.keys()) + return True + def _is_denorm_pattern(node: Node) -> bool: if not _is_div(node): @@ -56,6 +91,10 @@ def _remove_pattern_from_graph(graph_module: GraphModule, pattern: GraphModule): for match in matches: last_pattern_node = match.anchors[0] last_matched_subgraph_node = match.nodes_map[last_pattern_node] + + if not _feeds_into_linear(last_matched_subgraph_node): + continue + weight = match.placeholder_nodes[0] last_matched_subgraph_node.replace_all_uses_with(weight) diff --git a/backends/nxp/quantizer/utils.py b/backends/nxp/quantizer/utils.py index aaa6afc4626..da2448fb773 100644 --- a/backends/nxp/quantizer/utils.py +++ b/backends/nxp/quantizer/utils.py @@ -20,6 +20,9 @@ AddSimulatedLinearBatchNormFusionQATPass, RemoveSimulatedLinearBatchNormFusionQATPass, ) +from executorch.backends.transforms.quantize_fused_convbn_bias_pass import ( + QuantizeFusedConvBnBiasAtenPass, +) from torch import fx from torch._ops import OpOverload from torch.export import ExportedProgram @@ -205,4 +208,6 @@ def calibrate_and_quantize( m = convert_pt2e(m) + m = QuantizeFusedConvBnBiasAtenPass(default_zero_bias=True)(m).graph_module + return m diff --git a/backends/nxp/tests/BUCK b/backends/nxp/tests/BUCK index 66ec9ba1f9b..21a740909e1 100644 --- a/backends/nxp/tests/BUCK +++ b/backends/nxp/tests/BUCK @@ -54,6 +54,16 @@ fbcode_target(_kind = python_pytest, ] ) +fbcode_target(_kind = runtime.python_library, + name = "use_qat", + srcs = [ + "use_qat.py", + ], + deps = [ + "fbsource//third-party/pypi/pytest:pytest", + ], +) + fbcode_target(_kind = python_pytest, name = "test_batch_norm_fusion", srcs = [ @@ -68,3 +78,34 @@ fbcode_target(_kind = python_pytest, "fbsource//third-party/pypi/numpy:numpy", ], ) + +fbcode_target(_kind = python_pytest, + name = "test_qdq_clustering_conv", + srcs = [ + "test_qdq_clustering_conv.py", + ], + deps = [ + ":executorch_pipeline", + ":models", + ], +) + +fbcode_target(_kind = python_pytest, + name = "test_integration", + srcs = [ + "test_integration.py", + ], + preload_deps = [ + "//executorch/kernels/quantized:custom_ops_generated_lib", + ], + deps = [ + ":executorch_pipeline", + ":models", + ":use_qat", + "//executorch/devtools/backend_debug:delegation_info", + "//executorch/extension/pybindings:portable_lib", + "//executorch/examples/nxp/experimental/cifar_net:cifar_net", + "//executorch/kernels/quantized:custom_ops_generated_lib", + "//executorch/kernels/quantized:quantized_ops_lib", + ], +) diff --git a/backends/nxp/tests/test_batch_norm_fusion.py b/backends/nxp/tests/test_batch_norm_fusion.py index 9a879534d5f..b0294bfa702 100644 --- a/backends/nxp/tests/test_batch_norm_fusion.py +++ b/backends/nxp/tests/test_batch_norm_fusion.py @@ -22,7 +22,12 @@ neutron_target_spec, to_quantized_edge_program, ) -from executorch.backends.nxp.tests.executors import OverrideTargetSupportCheck +from executorch.backends.nxp.tests.executors import ( + graph_contains_any_of_ops, + OverrideTargetSupportCheck, +) + +from executorch.backends.nxp.tests.models import ConvBNModule from torch import nn @@ -229,3 +234,28 @@ def unsupported_target(*_): # Accept all input arguments and return `False`. node.op == "call_function" and "batch_norm" in node.target.__name__ for node in nodes ) + + +@pytest.mark.parametrize( + "conv_module", + ["conv2d"], +) +def test_biasless_convbn_fusion_qat( + conv_module, +): + if conv_module.startswith("conv1d"): + input_shape = (1, 3, 32) + elif conv_module.startswith("conv2d"): + input_shape = (1, 3, 32, 32) + else: # conv3d + input_shape = (1, 3, 32, 32, 32) + + model = ConvBNModule(conv_module, conv_bias=False, bn_affine=True) + + edge_program = to_quantized_edge_program( + model, input_shape, use_qat=True, use_neutron_for_format_conversion=False + ).exported_program() + + assert graph_contains_any_of_ops( + edge_program.graph, [torch.ops.higher_order.executorch_call_delegate] + ) diff --git a/backends/nxp/tests/test_integration.py b/backends/nxp/tests/test_integration.py index fe157b44c48..e8d2a6faf26 100644 --- a/backends/nxp/tests/test_integration.py +++ b/backends/nxp/tests/test_integration.py @@ -29,7 +29,7 @@ def test_conv_fc_softmax__to_executorch_program(use_qat): delegation_info = get_delegation_info(program.graph_module) assert delegation_info.num_delegated_subgraphs == 1 assert delegation_info.num_non_delegated_nodes == 11 - assert delegation_info.num_delegated_nodes == 13 + assert delegation_info.num_delegated_nodes == 14 for node in program.graph.nodes: # Make sure Convolution and AddMM are delegated diff --git a/backends/nxp/tests/test_qdq_clustering_conv.py b/backends/nxp/tests/test_qdq_clustering_conv.py index ffae931dbb4..21e330a48ae 100644 --- a/backends/nxp/tests/test_qdq_clustering_conv.py +++ b/backends/nxp/tests/test_qdq_clustering_conv.py @@ -16,16 +16,18 @@ def test_conv2d_partitioner(): lowered_module = edge_program.exported_program().graph_module.lowered_module_0 nodes = list(lowered_module.original_module.graph.nodes) - assert len(nodes) == 9 + assert len(nodes) == 13 - q_x_node = nodes[3] - dq_w_node = nodes[4] - dq_x_node = nodes[5] - conv_node = nodes[6] - q_y_node = nodes[7] + q_x_node = nodes[6] + dq_w_node = nodes[7] + dq_x_node = nodes[8] + dq_bias_node = nodes[9] + conv_node = nodes[10] + q_y_node = nodes[11] assert "cluster" not in q_x_node.meta assert dq_w_node.meta["cluster"] == "aten_convolution_default_cluster" assert dq_x_node.meta["cluster"] == "aten_convolution_default_cluster" + assert dq_bias_node.meta["cluster"] == "aten_convolution_default_cluster" assert conv_node.meta["cluster"] == "aten_convolution_default_cluster" assert q_y_node.meta["cluster"] == "aten_convolution_default_cluster" diff --git a/examples/nxp/experimental/cifar_net/BUCK b/examples/nxp/experimental/cifar_net/BUCK new file mode 100644 index 00000000000..0c78176ef93 --- /dev/null +++ b/examples/nxp/experimental/cifar_net/BUCK @@ -0,0 +1,18 @@ +load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target") +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +fbcode_target(_kind = runtime.python_library, + name = "cifar_net", + srcs = [ + "cifar_net.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/examples/models:model_base", + "fbsource//third-party/pypi/numpy:numpy", + "//pytorch/vision:torchvision", # @manual + ], +)