1818
1919@register_node_visitor
2020class Pad (NodeVisitor ):
21- target = ["aten.constant_pad_nd.default" ]
21+ target = [
22+ "aten.constant_pad_nd.default" ,
23+ "aten.reflection_pad1d.default" ,
24+ "aten.reflection_pad2d.default" ,
25+ ]
26+
27+ _SCHEME_MAP = {
28+ "aten.constant_pad_nd.default" : OpPad .Scheme .CONSTANT ,
29+ "aten.reflection_pad1d.default" : OpPad .Scheme .MIRROR_REFLECT ,
30+ "aten.reflection_pad2d.default" : OpPad .Scheme .MIRROR_REFLECT ,
31+ }
2232
2333 def __init__ (self , * args ) -> None :
2434 super ().__init__ (* args )
@@ -37,7 +47,6 @@ def define_node(
3747 PyQnnManager .Qnn_TensorType_t .QNN_TENSOR_TYPE_NATIVE ,
3848 nodes_to_wrappers ,
3949 )
40- pad_input_tensors = [pad_inp_tensor_wrapper ]
4150
4251 output_tensor = self .get_tensor (node , node )
4352 output_tensor_wrapper = self .define_tensor (
@@ -47,7 +56,6 @@ def define_node(
4756 PyQnnManager .Qnn_TensorType_t .QNN_TENSOR_TYPE_NATIVE ,
4857 nodes_to_wrappers ,
4958 )
50- pad_output_tensors = [output_tensor_wrapper ]
5159
5260 pad_amount_shape = [input_tensor .dim (), 2 ]
5361 # pytorch padding start from the last index
@@ -62,28 +70,30 @@ def define_node(
6270
6371 if QCOM_AXIS_ORDER in node .meta :
6472 pad_amount = pad_amount [list (node .meta [QCOM_AXIS_ORDER ])]
65- pad_amount_val = node .args [2 ]
6673
74+ scheme = self ._SCHEME_MAP [node .target .__name__ ]
6775 pad_op = PyQnnManager .PyQnnOpWrapper (
6876 node .name ,
6977 QNN_OP_PACKAGE_NAME_QTI_AISW ,
7078 OpPad .op_name ,
7179 )
72- pad_op .AddInputTensors (pad_input_tensors )
73- pad_op .AddOutputTensors (pad_output_tensors )
80+ pad_op .AddInputTensors ([ pad_inp_tensor_wrapper ] )
81+ pad_op .AddOutputTensors ([ output_tensor_wrapper ] )
7482
75- # For now, we only support constant (0) padding due to torch implementation
7683 pad_op .AddScalarParam (
7784 OpPad .param_scheme ,
7885 PyQnnManager .Qnn_DataType_t .QNN_DATATYPE_UINT_32 ,
79- {QCOM_DATA : np .uint32 (OpPad . Scheme . CONSTANT )},
86+ {QCOM_DATA : np .uint32 (scheme )},
8087 )
8188
82- pad_op .AddScalarParam (
83- OpPad .param_pad_constant_value ,
84- QNN_TENSOR_TYPE_MAP [type (pad_amount_val )],
85- {QCOM_DATA : pad_amount_val },
86- )
89+ # pad_constant_value is only applicable for CONSTANT scheme, meaning the PAD op
90+ if scheme == OpPad .Scheme .CONSTANT :
91+ pad_amount_val = node .args [2 ]
92+ pad_op .AddScalarParam (
93+ OpPad .param_pad_constant_value ,
94+ QNN_TENSOR_TYPE_MAP [type (pad_amount_val )],
95+ {QCOM_DATA : pad_amount_val },
96+ )
8797
8898 pad_op .AddTensorParam (
8999 OpPad .param_pad_amount ,
0 commit comments