From 52cc9bf80d56d939f524cf55d4924e3c88f96090 Mon Sep 17 00:00:00 2001 From: Pablo Marquez Tello Date: Fri, 10 Apr 2026 13:53:27 +0100 Subject: [PATCH] fix: Enable assembly kernel for int8 MAX pooling. Change-Id: I96cdf94dc29348a01b8f47aba1b6afd6d82fbcb9 Signed-off-by: Pablo Marquez Tello --- .../CpuPool2dAssemblyWrapperKernel.cpp | 13 +++++---- .../experimental/operators/CpuPool2d.cpp | 28 +++++++++++++++++-- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp b/src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp index 9a5438ab5a..498bddec07 100644 --- a/src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp +++ b/src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2025 Arm Limited. + * Copyright (c) 2021-2026 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -145,9 +145,10 @@ CpuPool2dAssemblyWrapperKernel::validate(const ITensorInfo *src, const ITensorIn if (src->data_type() == DataType::QASYMM8) { const bool has_padding = info.pad_stride_info.has_padding(); - ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !info.exclude_padding && has_padding, - "Assembly kernels do not support padding for QASYMM8 with same src/dst quantization info"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(info.pool_type == PoolingType::AVG && !info.exclude_padding && + has_padding, + "Assembly kernels do not support padded AVG pooling for QASYMM8 with " + "same src/dst quantization info"); } } } @@ -160,8 +161,8 @@ CpuPool2dAssemblyWrapperKernel::validate(const ITensorInfo *src, const ITensorIn // If dst is not configured, the quantization info are the same const bool has_padding = info.pad_stride_info.has_padding(); ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !info.exclude_padding && has_padding, - "Assembly kernels do not support padding for QASYMM8 with same src/dst quantization info"); + info.pool_type == PoolingType::AVG && !info.exclude_padding && has_padding, + "Assembly kernels do not support padded AVG pooling for QASYMM8 with same src/dst quantization info"); } } return Status{}; diff --git a/tests/validation/runtime/experimental/operators/CpuPool2d.cpp b/tests/validation/runtime/experimental/operators/CpuPool2d.cpp index 257029ab73..fd47aaf288 100644 --- a/tests/validation/runtime/experimental/operators/CpuPool2d.cpp +++ b/tests/validation/runtime/experimental/operators/CpuPool2d.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021, 2023-2025 Arm Limited. + * Copyright (c) 2017-2021, 2023-2026 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -72,6 +72,17 @@ const auto SmokePoolingDatasetQASYMM8 = combine(datasets::SmallNoneUnitShapes(), make("InputQuantInfo", {QuantizationInfo(0.2f, 10)}), make("OutputQuantInfo", {QuantizationInfo(0.2f, 10)})); +const auto SmokePoolingDatasetQASYMM8PaddedMax = + combine(make("Shape", {TensorShape(7U, 5U, 3U), TensorShape(8U, 7U, 5U)}), + make("PoolingType", {PoolingType::MAX}), + make("PoolingSize", {Size2D(3, 3)}), + make("PadStride", {PadStrideInfo(2, 2, 1, 1)}), + make("ExcludePadding", {false}), + make("DataType", DataType::QASYMM8), + make("DataLayout", {DataLayout::NHWC}), + make("InputQuantInfo", {QuantizationInfo(0.25f, 11)}), + make("OutputQuantInfo", {QuantizationInfo(0.25f, 11)})); + /** Tolerance for float operations */ constexpr AbsoluteTolerance tolerance_f32(0.000001f); constexpr AbsoluteTolerance tolerance_qasymm8( @@ -89,11 +100,12 @@ TEST_SUITE(CpuPool2d) TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Invalid pad/size combination TensorInfo(TensorShape(15U, 13U, 5U), 1, DataType::F32), // Non-rectangular Global Pooling TensorInfo(TensorShape(13U, 13U, 5U), 1, DataType::F32), // Invalid output Global Pooling - TensorInfo(TensorShape(13U, 13U, 5U), 1, DataType::QASYMM8), // Invalid exclude_padding = false with quantized type, no actual padding and NHWC + TensorInfo(TensorShape(13U, 13U, 5U), 1, DataType::QASYMM8), // Quantized NHWC without padding remains valid TensorInfo(TensorShape(13U, 13U, 5U), 1, DataType::F32), TensorInfo(TensorShape(1U, 16U, 1U), 1, DataType::F32), TensorInfo(TensorShape(112, 112, 64,1), 1, DataType::F32, DataLayout::NHWC), // Mismatching number of channels TensorInfo(TensorShape(112, 112, 64,1), 1, DataType::F32, DataLayout::NHWC), // Mismatching width + TensorInfo(TensorShape(5U, 13U, 13U, 1U), 1, DataType::QASYMM8, DataLayout::NHWC), // Padded NHWC QASYMM8 MAX with matching qinfo }), make("OutputInfo",{ TensorInfo(TensorShape(25U, 11U, 2U), 1, DataType::F16), TensorInfo(TensorShape(25U, 10U, 2U), 1, DataType::F32), @@ -106,6 +118,7 @@ TEST_SUITE(CpuPool2d) TensorInfo(TensorShape(1U, 15U, 1U), 1, DataType::F32), TensorInfo(TensorShape(56, 56, 64,1), 1, DataType::F32, DataLayout::NHWC), TensorInfo(TensorShape(56, 51, 64,1), 1, DataType::F32, DataLayout::NHWC), + TensorInfo(TensorShape(5U, 7U, 7U, 1U), 1, DataType::QASYMM8, DataLayout::NHWC), }), make("PoolInfo", { PoolingLayerInfo(PoolingType::AVG, 3, DataLayout::NCHW, PadStrideInfo(1, 1, 0, 0)), PoolingLayerInfo(PoolingType::AVG, 3, DataLayout::NCHW, PadStrideInfo(1, 1, 0, 0)), @@ -118,8 +131,9 @@ TEST_SUITE(CpuPool2d) PoolingLayerInfo(PoolingType::MAX, 2, DataLayout::NHWC, PadStrideInfo(1, 1, 0, 0), false), PoolingLayerInfo(PoolingType::MAX,3,DataLayout::NHWC,PadStrideInfo(2,2,1,1)), PoolingLayerInfo(PoolingType::MAX,3,DataLayout::NHWC,PadStrideInfo(2,2,1,1)), + PoolingLayerInfo(PoolingType::MAX, 3, DataLayout::NHWC, PadStrideInfo(2, 2, 1, 1), false), }), - make("Expected", { false, false, false, false, true, false, true, false, false, false, false})), + make("Expected", { false, false, false, false, true, false, true, false, false, false, false, true})), input_info, output_info, pool_info, expected) { bool is_valid = bool(arm_compute::experimental::op::CpuPool2d::validate(&input_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), pool_info)); @@ -152,6 +166,14 @@ FIXTURE_DATA_TEST_CASE(SmokeQASYMM8, { validate(Accessor(_target), _reference, tolerance_qasymm8); } + +FIXTURE_DATA_TEST_CASE(SmokeQASYMM8PaddedMax, + CpuPool2dQuantizedFixture, + framework::DatasetMode::PRECOMMIT, + SmokePoolingDatasetQASYMM8PaddedMax) +{ + validate(Accessor(_target), _reference, tolerance_qasymm8); +} TEST_SUITE_END() // QASYMM8 TEST_SUITE_END() // CpuPool2d