Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/nxp/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ fbcode_target(_kind = runtime.python_library,
deps = [
":aten_passes",
"//caffe2:torch",
"//executorch/backends/transforms:quantize_fused_convbn_bias_pass",
"//pytorch/ao:torchao", # @manual
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,41 @@
_is_reshape = partial(is_op_node, target_op=torch.ops.aten.reshape)
_is_zeros_like = partial(is_op_node, target_op=torch.ops.aten.zeros_like)

_CONV_TARGETS = {
torch.ops.aten.conv1d.default,
torch.ops.aten.conv1d.padding,
torch.ops.aten.conv2d.default,
torch.ops.aten.conv2d.padding,
torch.ops.aten.conv_transpose1d.default,
torch.ops.aten.conv_transpose2d.input,
}


def _feeds_into_linear(node: Node) -> bool:
"""
BFS from node to check if it eventually feeds into a linear op (not conv).
This is required because:
- Linear-BN fusion (added by AddSimulatedLinearBatchNormFusionQATPass, NXP-specific)
- Conv-BN QAT fusion (added by TorchAO's _fuse_conv_bn_qat inside prepare_qat_pt2e)
are structurally identical. Without this check, we would incorrectly remove
Conv-BN scale factor chains, breaking Conv-BN QAT fusion when TorchAO's _fold_conv_bn_qat
is called during convert_pt2e.
"""
visited = set()
queue = list(node.users.keys())
while queue:
n = queue.pop(0)
if n in visited:
continue
visited.add(n)
if n.op == "call_function":
if n.target == torch.ops.aten.linear.default:
return True
if n.target in _CONV_TARGETS:
return False
queue.extend(n.users.keys())
return True


def _is_denorm_pattern(node: Node) -> bool:
if not _is_div(node):
Expand Down Expand Up @@ -56,6 +91,10 @@ def _remove_pattern_from_graph(graph_module: GraphModule, pattern: GraphModule):
for match in matches:
last_pattern_node = match.anchors[0]
last_matched_subgraph_node = match.nodes_map[last_pattern_node]

if not _feeds_into_linear(last_matched_subgraph_node):
continue

weight = match.placeholder_nodes[0]

last_matched_subgraph_node.replace_all_uses_with(weight)
Expand Down
5 changes: 5 additions & 0 deletions backends/nxp/quantizer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
AddSimulatedLinearBatchNormFusionQATPass,
RemoveSimulatedLinearBatchNormFusionQATPass,
)
from executorch.backends.transforms.quantize_fused_convbn_bias_pass import (
QuantizeFusedConvBnBiasAtenPass,
)
from torch import fx
from torch._ops import OpOverload
from torch.export import ExportedProgram
Expand Down Expand Up @@ -205,4 +208,6 @@ def calibrate_and_quantize(

m = convert_pt2e(m)

m = QuantizeFusedConvBnBiasAtenPass(default_zero_bias=True)(m).graph_module

return m
41 changes: 41 additions & 0 deletions backends/nxp/tests/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ fbcode_target(_kind = python_pytest,
]
)

fbcode_target(_kind = runtime.python_library,
name = "use_qat",
srcs = [
"use_qat.py",
],
deps = [
"fbsource//third-party/pypi/pytest:pytest",
],
)

fbcode_target(_kind = python_pytest,
name = "test_batch_norm_fusion",
srcs = [
Expand All @@ -68,3 +78,34 @@ fbcode_target(_kind = python_pytest,
"fbsource//third-party/pypi/numpy:numpy",
],
)

fbcode_target(_kind = python_pytest,
name = "test_qdq_clustering_conv",
srcs = [
"test_qdq_clustering_conv.py",
],
deps = [
":executorch_pipeline",
":models",
],
)

fbcode_target(_kind = python_pytest,
name = "test_integration",
srcs = [
"test_integration.py",
],
preload_deps = [
"//executorch/kernels/quantized:custom_ops_generated_lib",
],
deps = [
":executorch_pipeline",
":models",
":use_qat",
"//executorch/devtools/backend_debug:delegation_info",
"//executorch/extension/pybindings:portable_lib",
"//executorch/examples/nxp/experimental/cifar_net:cifar_net",
"//executorch/kernels/quantized:custom_ops_generated_lib",
"//executorch/kernels/quantized:quantized_ops_lib",
],
)
32 changes: 31 additions & 1 deletion backends/nxp/tests/test_batch_norm_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
neutron_target_spec,
to_quantized_edge_program,
)
from executorch.backends.nxp.tests.executors import OverrideTargetSupportCheck
from executorch.backends.nxp.tests.executors import (
graph_contains_any_of_ops,
OverrideTargetSupportCheck,
)

from executorch.backends.nxp.tests.models import ConvBNModule
from torch import nn


Expand Down Expand Up @@ -229,3 +234,28 @@ def unsupported_target(*_): # Accept all input arguments and return `False`.
node.op == "call_function" and "batch_norm" in node.target.__name__
for node in nodes
)


@pytest.mark.parametrize(
"conv_module",
["conv2d"],
)
def test_biasless_convbn_fusion_qat(
conv_module,
):
if conv_module.startswith("conv1d"):
input_shape = (1, 3, 32)
elif conv_module.startswith("conv2d"):
input_shape = (1, 3, 32, 32)
else: # conv3d
input_shape = (1, 3, 32, 32, 32)

model = ConvBNModule(conv_module, conv_bias=False, bn_affine=True)

edge_program = to_quantized_edge_program(
model, input_shape, use_qat=True, use_neutron_for_format_conversion=False
).exported_program()

assert graph_contains_any_of_ops(
edge_program.graph, [torch.ops.higher_order.executorch_call_delegate]
)
2 changes: 1 addition & 1 deletion backends/nxp/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_conv_fc_softmax__to_executorch_program(use_qat):
delegation_info = get_delegation_info(program.graph_module)
assert delegation_info.num_delegated_subgraphs == 1
assert delegation_info.num_non_delegated_nodes == 11
assert delegation_info.num_delegated_nodes == 13
assert delegation_info.num_delegated_nodes == 14

for node in program.graph.nodes:
# Make sure Convolution and AddMM are delegated
Expand Down
14 changes: 8 additions & 6 deletions backends/nxp/tests/test_qdq_clustering_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@ def test_conv2d_partitioner():
lowered_module = edge_program.exported_program().graph_module.lowered_module_0
nodes = list(lowered_module.original_module.graph.nodes)

assert len(nodes) == 9
assert len(nodes) == 13

q_x_node = nodes[3]
dq_w_node = nodes[4]
dq_x_node = nodes[5]
conv_node = nodes[6]
q_y_node = nodes[7]
q_x_node = nodes[6]
dq_w_node = nodes[7]
dq_x_node = nodes[8]
dq_bias_node = nodes[9]
conv_node = nodes[10]
q_y_node = nodes[11]

assert "cluster" not in q_x_node.meta
assert dq_w_node.meta["cluster"] == "aten_convolution_default_cluster"
assert dq_x_node.meta["cluster"] == "aten_convolution_default_cluster"
assert dq_bias_node.meta["cluster"] == "aten_convolution_default_cluster"
assert conv_node.meta["cluster"] == "aten_convolution_default_cluster"
assert q_y_node.meta["cluster"] == "aten_convolution_default_cluster"
18 changes: 18 additions & 0 deletions examples/nxp/experimental/cifar_net/BUCK
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

fbcode_target(_kind = runtime.python_library,
name = "cifar_net",
srcs = [
"cifar_net.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/examples/models:model_base",
"fbsource//third-party/pypi/numpy:numpy",
"//pytorch/vision:torchvision", # @manual
],
)
Loading