Skip to content

Commit 8ed6e85

Browse files
authored
Qualcomm AI Engine Direct - Adding QNN backend support for reflection_pad1/2d core ATen ops (#18963)
1 parent 66e4656 commit 8ed6e85

11 files changed

Lines changed: 193 additions & 15 deletions

File tree

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .decompose_log_variants import DecomposeLogVariants
2727
from .decompose_maxpool3d import DecomposeMaxPool3d
2828
from .decompose_minmaxdim import DecomposeMinMaxDim
29+
from .decompose_pad import DecomposePad
2930
from .decompose_reciprocal import DecomposeReciprocal
3031
from .decompose_remainder import DecomposeRemainder
3132
from .decompose_roll import DecomposeRoll
@@ -80,6 +81,7 @@
8081
DecomposeLogVariants,
8182
DecomposeMaxPool3d,
8283
DecomposeMinMaxDim,
84+
DecomposePad,
8385
DecomposeReciprocal,
8486
DecomposeRemainder,
8587
DecomposeRoll,
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
10+
from executorch.exir.pass_base import ExportPass, PassResult
11+
12+
13+
class DecomposePad(ExportPass):
    """
    Convert aten.pad.default with non-constant modes to specific pad ops.

    After torch.export, nn.ReflectionPad2d becomes aten.pad.default with
    mode='reflect'. This pass converts it to aten.reflection_pad2d.default,
    which the QNN pad builder handles directly.

    Supported:
    - mode='reflect', 4 padding values -> reflection_pad2d
      (QNN MIRROR_REFLECT, max rank 4).

    Not supported by QNN (max rank 4 for non-constant schemes):
    - mode='reflect', 6 padding values (3d) -> reflection_pad3d
      (QNN MIRROR_REFLECT max rank is 4).
    - mode='replicate' -> QNN EDGE scheme produces incorrect results for FP32
      inputs for replication_pad2d.

    Note: reflection_pad1d is handled by PyTorch's built-in decomposition of
    aten.pad.default (mode='reflect', 2 padding values) -> reflection_pad1d,
    combined with the skip decomp table entry for reflection_pad1d.
    """

    # aten.pad.default in both the core-ATen and edge dialects.
    _PAD_TARGETS = {
        torch.ops.aten.pad.default,
        exir_ops.edge.aten.pad.default,
    }

    # (mode, number of padding values, is_edge_dialect) -> specialized pad op.
    _PAD_OPS = {
        ("reflect", 4, False): torch.ops.aten.reflection_pad2d.default,
        ("reflect", 4, True): exir_ops.edge.aten.reflection_pad2d.default,
    }

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Rewrite supported pad.default nodes in place.

        Returns a PassResult whose `modified` flag accurately reflects whether
        any node was rewritten; recompile() is only invoked when needed.
        """
        graph = graph_module.graph
        modified = False
        for node in list(graph.nodes):
            if node.op != "call_function" or node.target not in self._PAD_TARGETS:
                continue
            # Signature: pad(input, padding, mode='constant', value=None).
            mode = node.args[2] if len(node.args) > 2 else "constant"

            padding = node.args[1]
            is_edge = isinstance(node.target, EdgeOpOverload)
            target_op = self._PAD_OPS.get((mode, len(padding), is_edge))
            if target_op is None:
                # Unsupported mode/rank combination; leave for default handling.
                continue

            # The specialized reflection ops take only (input, padding).
            node.target = target_op
            node.args = (node.args[0], list(padding))
            modified = True

        # Previously this pass always recompiled and reported modified=True,
        # even when no pad node was touched; report the real outcome instead.
        if modified:
            graph_module.recompile()
        return PassResult(graph_module, modified)

backends/qualcomm/_passes/layout_transform.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ class LayoutTransform(ExportPass):
113113
exir_ops.edge.aten.neg.default,
114114
exir_ops.edge.aten.pow.Tensor_Scalar,
115115
exir_ops.edge.aten.prelu.default,
116+
exir_ops.edge.aten.reflection_pad1d.default,
117+
exir_ops.edge.aten.reflection_pad2d.default,
116118
exir_ops.edge.aten.repeat.default,
117119
exir_ops.edge.aten.relu.default,
118120
exir_ops.edge.aten.round.default,

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
DecomposeLogVariants,
3232
DecomposeMaxPool3d,
3333
DecomposeMinMaxDim,
34+
DecomposePad,
3435
DecomposeReciprocal,
3536
DecomposeRemainder,
3637
DecomposeRoll,
@@ -107,6 +108,7 @@ def get_capture_program_passes():
107108
(DecomposeLogVariants, True),
108109
(DecomposeMaxPool3d, True),
109110
(DecomposeMinMaxDim, True),
111+
(DecomposePad, True),
110112
(DecomposeRemainder, True),
111113
(DecomposeTrunc, True),
112114
(ExpandBroadcastTensorShape, True),
@@ -227,6 +229,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
227229
self.add_pass(DecomposeBinaryAlpha())
228230
self.add_pass(DecomposeCDist())
229231
self.add_pass(DecomposeMaxPool3d(quantization_capture=True))
232+
self.add_pass(DecomposePad())
230233
self.add_pass(DecomposeScaledDotProductAttention())
231234
self.add_pass(DecomposeRoll())
232235
self.add_pass(DecomposeSilu())
@@ -254,6 +257,7 @@ def transform_for_export_pipeline(
254257
):
255258
self.add_pass(DecomposeBinaryAlpha())
256259
self.add_pass(DecomposeCDist())
260+
self.add_pass(DecomposePad())
257261
self.add_pass(DecomposeScaledDotProductAttention())
258262
self.add_pass(DecomposeRoll())
259263
self.add_pass(DecomposeThreshold())

backends/qualcomm/_passes/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def get_passes_dependency_for_capture_program():
7171
DecomposeLinalgVectorNorm,
7272
DecomposeLogVariants,
7373
DecomposeMaxPool3d,
74+
DecomposePad,
7475
DecomposeRemainder,
7576
DecomposeTrunc,
7677
ExpandBroadcastTensorShape,
@@ -102,6 +103,7 @@ def get_passes_dependency_for_capture_program():
102103
DecomposeLinalgVectorNorm: [RemoveRedundancy],
103104
DecomposeLogVariants: [RemoveRedundancy],
104105
DecomposeMaxPool3d: [RemoveRedundancy],
106+
DecomposePad: [RemoveRedundancy],
105107
DecomposeRemainder: [RemoveRedundancy],
106108
DecomposeTrunc: [RemoveRedundancy],
107109
ExpandBroadcastTensorShape: [FoldQDQ],

backends/qualcomm/builders/op_pad.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,17 @@
1818

1919
@register_node_visitor
2020
class Pad(NodeVisitor):
21-
target = ["aten.constant_pad_nd.default"]
21+
target = [
22+
"aten.constant_pad_nd.default",
23+
"aten.reflection_pad1d.default",
24+
"aten.reflection_pad2d.default",
25+
]
26+
27+
_SCHEME_MAP = {
28+
"aten.constant_pad_nd.default": OpPad.Scheme.CONSTANT,
29+
"aten.reflection_pad1d.default": OpPad.Scheme.MIRROR_REFLECT,
30+
"aten.reflection_pad2d.default": OpPad.Scheme.MIRROR_REFLECT,
31+
}
2232

2333
def __init__(self, *args) -> None:
2434
super().__init__(*args)
@@ -37,7 +47,6 @@ def define_node(
3747
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
3848
nodes_to_wrappers,
3949
)
40-
pad_input_tensors = [pad_inp_tensor_wrapper]
4150

4251
output_tensor = self.get_tensor(node, node)
4352
output_tensor_wrapper = self.define_tensor(
@@ -47,7 +56,6 @@ def define_node(
4756
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
4857
nodes_to_wrappers,
4958
)
50-
pad_output_tensors = [output_tensor_wrapper]
5159

5260
pad_amount_shape = [input_tensor.dim(), 2]
5361
# pytorch padding start from the last index
@@ -62,28 +70,30 @@ def define_node(
6270

6371
if QCOM_AXIS_ORDER in node.meta:
6472
pad_amount = pad_amount[list(node.meta[QCOM_AXIS_ORDER])]
65-
pad_amount_val = node.args[2]
6673

74+
scheme = self._SCHEME_MAP[node.target.__name__]
6775
pad_op = PyQnnManager.PyQnnOpWrapper(
6876
node.name,
6977
QNN_OP_PACKAGE_NAME_QTI_AISW,
7078
OpPad.op_name,
7179
)
72-
pad_op.AddInputTensors(pad_input_tensors)
73-
pad_op.AddOutputTensors(pad_output_tensors)
80+
pad_op.AddInputTensors([pad_inp_tensor_wrapper])
81+
pad_op.AddOutputTensors([output_tensor_wrapper])
7482

75-
# For now, we only support constant (0) padding due to torch implementation
7683
pad_op.AddScalarParam(
7784
OpPad.param_scheme,
7885
PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
79-
{QCOM_DATA: np.uint32(OpPad.Scheme.CONSTANT)},
86+
{QCOM_DATA: np.uint32(scheme)},
8087
)
8188

82-
pad_op.AddScalarParam(
83-
OpPad.param_pad_constant_value,
84-
QNN_TENSOR_TYPE_MAP[type(pad_amount_val)],
85-
{QCOM_DATA: pad_amount_val},
86-
)
89+
# pad_constant_value is only applicable for the CONSTANT scheme, so the PAD op omits it for MIRROR_REFLECT
90+
if scheme == OpPad.Scheme.CONSTANT:
91+
pad_amount_val = node.args[2]
92+
pad_op.AddScalarParam(
93+
OpPad.param_pad_constant_value,
94+
QNN_TENSOR_TYPE_MAP[type(pad_amount_val)],
95+
{QCOM_DATA: pad_amount_val},
96+
)
8797

8898
pad_op.AddTensorParam(
8999
OpPad.param_pad_amount,

backends/qualcomm/partition/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
6464
torch.ops.aten.pixel_shuffle.default,
6565
torch.ops.aten.pixel_unshuffle.default,
6666
torch.ops.aten.prelu.default,
67+
torch.ops.aten.reflection_pad1d.default,
68+
torch.ops.aten.reflection_pad2d.default,
6769
torch.ops.aten.rms_norm.default,
6870
torch.ops.aten._safe_softmax.default,
6971
torch.ops.aten.stack.default,

backends/qualcomm/quantizer/annotators/htp_rules.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1171,7 +1171,14 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
11711171
annotate_binary(node, quantization_config)
11721172

11731173

1174-
@register_annotator([torch.ops.aten.pad.default], QnnConstants.OpPad.op_name)
1174+
@register_annotator(
    [
        torch.ops.aten.pad.default,
        torch.ops.aten.reflection_pad1d.default,
        torch.ops.aten.reflection_pad2d.default,
    ],
    QnnConstants.OpPad.op_name,
)
class Pad(GeneralOpDef):
    """Annotator for the pad family of ops; inherits default annotation behavior from GeneralOpDef."""
11771184

backends/qualcomm/quantizer/annotators/lpai_rules.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,14 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
673673
annotate_binary(node, quantization_config)
674674

675675

676-
@register_annotator([torch.ops.aten.pad.default], QnnConstants.OpPad.op_name)
676+
@register_annotator(
    [
        torch.ops.aten.pad.default,
        torch.ops.aten.reflection_pad1d.default,
        torch.ops.aten.reflection_pad2d.default,
    ],
    QnnConstants.OpPad.op_name,
)
class Pad(GeneralOpDef):
    """Annotator for the pad family of ops; inherits default annotation behavior from GeneralOpDef."""
679686

backends/qualcomm/tests/models.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,6 +1863,33 @@ def forward(self, x):
18631863
return torch.reciprocal(x)
18641864

18651865

1866+
class ReflectionPad1d(torch.nn.Module):
    """Test model: reflection-pads the last dimension by 2 on each side."""

    def __init__(self):
        super().__init__()
        # Symmetric reflection padding of width 2 on the last dimension.
        self.pad = torch.nn.ReflectionPad1d(2)

    def forward(self, x):
        padded = self.pad(x)
        return padded
1873+
1874+
1875+
class ReflectionPad2d(torch.nn.Module):
    """Test model: reflection-pads the last two dimensions by 2 on every side."""

    def __init__(self):
        super().__init__()
        # Symmetric reflection padding of width 2 on H and W.
        self.pad = torch.nn.ReflectionPad2d(2)

    def forward(self, x):
        padded = self.pad(x)
        return padded
1882+
1883+
1884+
class ReflectionPad2dAsymmetric(torch.nn.Module):
    """Test model: asymmetric reflection padding (left=1, right=2, top=3, bottom=1)."""

    def __init__(self):
        super().__init__()
        # Different pad width per side exercises the asymmetric-padding path.
        self.pad = torch.nn.ReflectionPad2d((1, 2, 3, 1))

    def forward(self, x):
        padded = self.pad(x)
        return padded
1891+
1892+
18661893
class Relu(torch.nn.Module):
18671894
def __init__(self):
18681895
super().__init__()

0 commit comments

Comments
 (0)