Skip to content

Commit 3101e19

Browse files
mcremon-metameta-codesync[bot]
authored and committed
Split regular and depthwise quantized conv1d (#18807)
Summary: Pull Request resolved: #18807 As titled. This makes it more readable, easier and will improve code size. Differential Revision: D99509365
1 parent 3a72c4f commit 3101e19

8 files changed

Lines changed: 390 additions & 2 deletions

File tree

backends/cadence/aot/functions.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,16 @@
399399
- arg_meta: null
400400
kernel_name: impl::generic::quantized_conv1d_nlc_per_tensor_out
401401

402+
# Depthwise conv1d variants, split out from the regular quantized conv1d ops.
# Both delegate to generic kernels; layout (ncl vs nlc) selects the entry point.
- func: cadence::quantized_depthwise_conv1d_ncl.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: impl::generic::quantized_depthwise_conv1d_ncl_per_tensor_out

- func: cadence::quantized_depthwise_conv1d_nlc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: impl::generic::quantized_depthwise_conv1d_nlc_per_tensor_out
411+
402412
- func: cadence::quantized_conv2d_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
403413
kernels:
404414
- arg_meta: null

backends/cadence/aot/ops_registrations.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,18 @@ def register_fake(
262262
lib.define(
263263
"quantized_conv1d_nlc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
264264
)
265+
# Depthwise conv1d schemas (NCL and NLC layouts), each with a functional and
# an out-variant. Dedicated ops for groups == in_channels, split out from the
# regular quantized conv1d ops.
lib.define(
    "quantized_depthwise_conv1d_ncl.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
lib.define(
    "quantized_depthwise_conv1d_ncl.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_depthwise_conv1d_nlc.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
lib.define(
    "quantized_depthwise_conv1d_nlc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
265277
lib.define(
266278
"quantized_conv2d_nchw(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)"
267279
)
@@ -1256,6 +1268,78 @@ def quantized_conv1d_nlc_per_tensor_meta(
12561268
return input.new_empty(output_size, dtype=input.dtype)
12571269

12581270

1271+
@register_fake("cadence::quantized_depthwise_conv1d_ncl.per_tensor")
1272+
def quantized_depthwise_conv1d_ncl_per_tensor_meta(
1273+
input: torch.Tensor,
1274+
weight: torch.Tensor,
1275+
bias: torch.Tensor,
1276+
stride: Tuple[int],
1277+
padding: Tuple[int],
1278+
dilation: Tuple[int],
1279+
groups: int,
1280+
in_zero_point: int,
1281+
weight_zero_point: int,
1282+
bias_scale: float,
1283+
output_scale: float,
1284+
output_zero_point: int,
1285+
out_multiplier: int,
1286+
out_shift: int,
1287+
) -> torch.Tensor:
1288+
# NCL format: input is [N, C, L], weight is [OC, IC/groups, K]
1289+
out_channels, _, kernel_size = weight.shape
1290+
1291+
in_size = input.shape
1292+
assert len(in_size) == 3
1293+
1294+
output_size = get_conv1d_output_size(
1295+
in_size,
1296+
out_channels,
1297+
stride[-1],
1298+
padding[-1],
1299+
dilation[-1],
1300+
kernel_size,
1301+
False,
1302+
)
1303+
1304+
return input.new_empty(output_size, dtype=input.dtype)
1305+
1306+
1307+
@register_fake("cadence::quantized_depthwise_conv1d_nlc.per_tensor")
1308+
def quantized_depthwise_conv1d_nlc_per_tensor_meta(
1309+
input: torch.Tensor,
1310+
weight: torch.Tensor,
1311+
bias: torch.Tensor,
1312+
stride: Tuple[int],
1313+
padding: Tuple[int],
1314+
dilation: Tuple[int],
1315+
groups: int,
1316+
in_zero_point: int,
1317+
weight_zero_point: int,
1318+
bias_scale: float,
1319+
output_scale: float,
1320+
output_zero_point: int,
1321+
out_multiplier: int,
1322+
out_shift: int,
1323+
) -> torch.Tensor:
1324+
# NLC format: input is [N, L, C], weight is [OC, K, IC/groups]
1325+
out_channels, kernel_size, _ = weight.shape
1326+
1327+
in_size = input.shape
1328+
assert len(in_size) == 3
1329+
1330+
output_size = get_conv1d_output_size(
1331+
in_size,
1332+
out_channels,
1333+
stride[-1],
1334+
padding[-1],
1335+
dilation[-1],
1336+
kernel_size,
1337+
True,
1338+
)
1339+
1340+
return input.new_empty(output_size, dtype=input.dtype)
1341+
1342+
12591343
@register_fake("cadence::quantized_conv2d_nchw")
12601344
def quantized_conv2d_nchw_meta(
12611345
input: torch.Tensor,

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
find_sequential_partitions_aten,
4343
quantize_tensor_multiplier,
4444
)
45+
from executorch.backends.cadence.aot.utils import is_depthwise_conv
4546
from executorch.exir.pass_base import ExportPass
4647
from torch import fx
4748
from torch.fx import GraphModule
@@ -758,8 +759,23 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
758759
op_node,
759760
)
760761

762+
# Determine the replacement op, routing depthwise conv1d
763+
# to the dedicated depthwise operator.
764+
replacement_op = pattern.replacement_op()
765+
if (
766+
replacement_op
767+
== torch.ops.cadence.quantized_conv1d_ncl.per_tensor
768+
):
769+
groups = kwargs.get("groups", 1)
770+
# NCL format: input shape is [N, C, L]
771+
in_channels = args[0].meta["val"].shape[1]
772+
if is_depthwise_conv(groups, in_channels):
773+
replacement_op = (
774+
torch.ops.cadence.quantized_depthwise_conv1d_ncl.per_tensor
775+
)
776+
761777
fused = graph_module.graph.call_function(
762-
pattern.replacement_op(),
778+
replacement_op,
763779
args,
764780
kwargs,
765781
)

backends/cadence/aot/ref_implementations.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,6 +1136,122 @@ def quantized_conv1d_nlc(
11361136
)
11371137

11381138

1139+
@impl_tracked(m, "quantized_depthwise_conv1d_ncl.per_tensor")
1140+
def quantized_depthwise_conv1d_ncl_per_tensor(
1141+
input_tensor: torch.Tensor,
1142+
weight: torch.Tensor,
1143+
bias: torch.Tensor,
1144+
stride: tuple[int],
1145+
padding: tuple[int],
1146+
dilation: tuple[int],
1147+
groups: int,
1148+
in_zero_point: int,
1149+
weight_zero_point: int,
1150+
bias_scale: float,
1151+
output_scale: float,
1152+
output_zero_point: int,
1153+
out_multiplier: int,
1154+
out_shift: int,
1155+
) -> torch.Tensor:
1156+
"""
1157+
Quantized depthwise 1D convolution in NCL (channels-first) format.
1158+
1159+
This op only handles depthwise convolutions (groups == in_channels, groups > 1).
1160+
Regular convolutions must use quantized_conv1d_ncl instead.
1161+
1162+
Args:
1163+
- input_tensor (Tensor): [N, C, L] format
1164+
- weight (Tensor): [OC, 1, K] format (IC/groups == 1 for depthwise)
1165+
- bias (Tensor): [OC]
1166+
- stride, padding, dilation, groups: convolution parameters
1167+
- in_zero_point, weight_zero_point, bias_scale: quantization params
1168+
- output_scale, output_zero_point: output quantization params
1169+
- out_multiplier, out_shift: unused
1170+
"""
1171+
assert is_depthwise_conv(
1172+
groups, input_tensor.shape[1]
1173+
), f"quantized_depthwise_conv1d_ncl requires depthwise conv (groups == in_channels), got groups={groups}, in_channels={input_tensor.shape[1]}"
1174+
1175+
return quantized_conv_per_tensor(
1176+
input_tensor,
1177+
weight,
1178+
bias,
1179+
stride,
1180+
padding,
1181+
dilation,
1182+
groups,
1183+
in_zero_point,
1184+
weight_zero_point,
1185+
bias_scale,
1186+
output_scale,
1187+
output_zero_point,
1188+
out_multiplier,
1189+
out_shift,
1190+
)
1191+
1192+
1193+
@impl_tracked(m, "quantized_depthwise_conv1d_nlc.per_tensor")
1194+
def quantized_depthwise_conv1d_nlc_per_tensor(
1195+
input_tensor: torch.Tensor,
1196+
weight: torch.Tensor,
1197+
bias: torch.Tensor,
1198+
stride: tuple[int],
1199+
padding: tuple[int],
1200+
dilation: tuple[int],
1201+
groups: int,
1202+
in_zero_point: int,
1203+
weight_zero_point: int,
1204+
bias_scale: float,
1205+
output_scale: float,
1206+
output_zero_point: int,
1207+
out_multiplier: int,
1208+
out_shift: int,
1209+
) -> torch.Tensor:
1210+
"""
1211+
Quantized depthwise 1D convolution in NLC (channels-last) format.
1212+
1213+
This op only handles depthwise convolutions (groups == in_channels, groups > 1).
1214+
Regular convolutions must use quantized_conv1d_nlc instead.
1215+
1216+
Args:
1217+
- input_tensor (Tensor): [N, L, C] format
1218+
- weight (Tensor): [OC, K, 1] format (IC/groups == 1 for depthwise)
1219+
- bias (Tensor): [OC]
1220+
- stride, padding, dilation, groups: convolution parameters
1221+
- in_zero_point, weight_zero_point, bias_scale: quantization params
1222+
- output_scale, output_zero_point: output quantization params
1223+
- out_multiplier, out_shift: unused
1224+
"""
1225+
assert is_depthwise_conv(
1226+
groups, input_tensor.shape[-1]
1227+
), f"quantized_depthwise_conv1d_nlc requires depthwise conv (groups == in_channels), got groups={groups}, in_channels={input_tensor.shape[-1]}"
1228+
1229+
# Convert NLC to NCL for processing
1230+
input_ncl = input_tensor.permute(0, 2, 1).contiguous()
1231+
# Convert weight from [OC, K, IC/groups] to [OC, IC/groups, K]
1232+
weight_ncl = weight.permute(0, 2, 1).contiguous()
1233+
1234+
result_ncl = quantized_conv_per_tensor(
1235+
input_ncl,
1236+
weight_ncl,
1237+
bias,
1238+
stride,
1239+
padding,
1240+
dilation,
1241+
groups,
1242+
in_zero_point,
1243+
weight_zero_point,
1244+
bias_scale,
1245+
output_scale,
1246+
output_zero_point,
1247+
out_multiplier,
1248+
out_shift,
1249+
)
1250+
1251+
# Convert result back to NLC format
1252+
return result_ncl.permute(0, 2, 1).contiguous()
1253+
1254+
11391255
@impl_tracked(m, "quantized_conv2d_nchw")
11401256
def quantized_conv2d_nchw(
11411257
input_tensor: torch.Tensor,

backends/cadence/aot/replace_ops.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,7 @@ def targets(self) -> list[EdgeOpOverload]:
10301030
exir_ops.edge.cadence.conv2d.default,
10311031
exir_ops.edge.cadence.conv3d.default,
10321032
exir_ops.edge.cadence.quantized_conv1d_ncl.per_tensor,
1033+
exir_ops.edge.cadence.quantized_depthwise_conv1d_ncl.per_tensor,
10331034
exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
10341035
]
10351036

@@ -1114,6 +1115,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
11141115
assert isinstance(node.target, EdgeOpOverload)
11151116
quantized_op = node.target in {
11161117
exir_ops.edge.cadence.quantized_conv1d_ncl.per_tensor,
1118+
exir_ops.edge.cadence.quantized_depthwise_conv1d_ncl.per_tensor,
11171119
exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
11181120
}
11191121

@@ -1132,7 +1134,15 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
11321134
new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor
11331135
else:
11341136
assert len(input_shape) == 3
1135-
new_op = exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor
1137+
if (
1138+
node.target
1139+
== exir_ops.edge.cadence.quantized_depthwise_conv1d_ncl.per_tensor
1140+
):
1141+
new_op = (
1142+
exir_ops.edge.cadence.quantized_depthwise_conv1d_nlc.per_tensor
1143+
)
1144+
else:
1145+
new_op = exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor
11361146
else:
11371147
new_op = node.target
11381148

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/cadence/generic/operators/op_quantized_conv1d_ncl.h>

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>

namespace impl {
namespace generic {
namespace native {

using ::executorch::aten::IntArrayRef;
using ::executorch::aten::Tensor;
using ::executorch::runtime::KernelRuntimeContext;

// Depthwise conv1d in NCL layout: a thin forwarding wrapper. The regular
// conv1d NCL kernel already handles grouped (depthwise) convolution
// correctly via its ocpg/icpg decomposition, so this simply delegates to
// it. A distinct entry point exists so that depthwise and regular conv1d
// are cleanly separated at the graph level, enabling independent
// optimization.
// NOTE(review): unlike the Python reference implementation, this kernel
// does not assert groups == in_channels — presumably the AOT lowering
// guarantees only depthwise graphs reach this op; confirm against callers.
Tensor& quantized_depthwise_conv1d_ncl_per_tensor_out(
    KernelRuntimeContext& ctx,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    int64_t input_zero_point,
    int64_t weight_zero_point,
    double bias_scale,
    double output_scale,
    int64_t output_zero_point,
    int64_t out_multiplier,
    int64_t out_shift,
    Tensor& out) {
  // Forward every argument unchanged to the regular NCL kernel.
  return quantized_conv1d_ncl_per_tensor_out(
      ctx,
      input,
      weight,
      bias,
      stride,
      padding,
      dilation,
      groups,
      input_zero_point,
      weight_zero_point,
      bias_scale,
      output_scale,
      output_zero_point,
      out_multiplier,
      out_shift,
      out);
}

} // namespace native
} // namespace generic
} // namespace impl

0 commit comments

Comments
 (0)