|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | + |
| 18 | +import mxnet as mx |
| 19 | +import numpy as np |
| 20 | +import unittest |
| 21 | +import pytest |
| 22 | +import ctypes |
| 23 | +import copy |
| 24 | + |
| 25 | + |
| 26 | +from mxnet.contrib import quantization |
| 27 | +from mxnet.gluon import nn |
| 28 | +from mxnet.test_utils import assert_almost_equal, assert_almost_equal_with_err |
| 29 | + |
| 30 | +OP_NAME='op_name' |
| 31 | +QUANTIZED_OP_NAME='quantized_op_name' |
| 32 | +SG_PASS_NAME='MKLDNN' |
| 33 | +QUANTIZE_SG_PASS_NAME='MKLDNN_QUANTIZE' |
| 34 | +config = { |
| 35 | + 'conv': { |
| 36 | + OP_NAME: 'sg_mkldnn_conv', |
| 37 | + QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_conv' |
| 38 | + }, |
| 39 | + 'fc': { |
| 40 | + OP_NAME: 'sg_mkldnn_fully_connected', |
| 41 | + QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_fully_connected' |
| 42 | + } |
| 43 | +} |
| 44 | + |
| 45 | +DATA_SHAPE=[(64, 4, 10, 10), (4, 3, 24, 24), (1, 16, 32, 32)] |
| 46 | + |
| 47 | +# Helpers |
| 48 | +class RELU6(nn.HybridBlock): |
| 49 | + """Relu6 used in MobileNetV2.""" |
| 50 | + |
| 51 | + def __init__(self, **kwargs): |
| 52 | + super(RELU6, self).__init__(**kwargs) |
| 53 | + |
| 54 | + def hybrid_forward(self, F, x): |
| 55 | + return F.clip(x, 0, 6, name="relu6") |
| 56 | + |
| 57 | +class TailNegBlock(nn.HybridBlock): |
| 58 | + def __init__(self, **kwargs): |
| 59 | + super(TailNegBlock, self).__init__(**kwargs) |
| 60 | + self.fc1 = nn.Dense(10, flatten=True) |
| 61 | + self.fc2 = nn.Dense(10, flatten=True) |
| 62 | + |
| 63 | + def hybrid_forward(self, F, x1, x2): |
| 64 | + out_fc1 = self.fc1(x1) |
| 65 | + out_fc2 = self.fc2(x2) |
| 66 | + out = F.concat(out_fc1, out_fc2) |
| 67 | + out = F.softmax(out) |
| 68 | + return out |
| 69 | + |
| 70 | +class CustomNormalInit(mx.init.Initializer): |
| 71 | + """Initializes weights with random values sampled from a normal distribution |
| 72 | + with a custom mean and standard deviation of `sigma`. |
| 73 | + """ |
| 74 | + def __init__(self, mean=0, sigma=0.01): |
| 75 | + super(CustomNormalInit, self).__init__(mean=mean, sigma=sigma) |
| 76 | + self.mean = mean |
| 77 | + self.sigma = sigma |
| 78 | + |
| 79 | + def _init_weight(self, _, arr): |
| 80 | + mx.random.normal(self.mean, self.sigma, arr.shape, dtype=arr.dtype, out=arr) |
| 81 | + |
| 82 | + |
| 83 | +def check_qsym_calibrated(qsym, out_type, name='conv'): |
| 84 | + quantized_op_name = 'quantized_' + name |
| 85 | + assert ''.join(qsym.attr_dict().keys()).find(quantized_op_name) != -1 |
| 86 | + for k, v in qsym.attr_dict().items(): |
| 87 | + if k.find('_quantize') != -1: |
| 88 | + assert v['out_type'] == out_type |
| 89 | + if k.find(quantized_op_name) != -1: |
| 90 | + if quantized_op_name.startswith("quantized_sg_mkldnn_fully_connected") and 'enable_float_output' in v: |
| 91 | + continue |
| 92 | + assert 'min_calib_range' in v |
| 93 | + assert 'max_calib_range' in v |
| 94 | + |
| 95 | +def check_qsym_scale_align(qsym): |
| 96 | + assert ''.join(qsym.attr_dict().keys()).find('quantized_sg_mkldnn_conv') != -1 |
| 97 | + init = False |
| 98 | + for k, v in qsym.attr_dict().items(): |
| 99 | + if k.find('quantized_sg_mkldnn_conv') != -1: |
| 100 | + assert 'min_calib_range' in v |
| 101 | + assert 'max_calib_range' in v |
| 102 | + if not init: |
| 103 | + min_calib_range = v['min_calib_range'] |
| 104 | + max_calib_range = v['max_calib_range'] |
| 105 | + init = True |
| 106 | + else: |
| 107 | + assert min_calib_range == v['min_calib_range'] |
| 108 | + assert max_calib_range == v['max_calib_range'] |
| 109 | + |
| 110 | + |
| 111 | +def check_quantize(net_original, data_shape, out_type, name='conv', |
| 112 | + check_calibration=True, check_scale_align=False): |
| 113 | + quantize_granularity_list = ['tensor-wise'] |
| 114 | + if name == 'fc': |
| 115 | + quantize_granularity_list += ['channel-wise'] |
| 116 | + |
| 117 | + if name in config: |
| 118 | + name = config[name][OP_NAME] |
| 119 | + |
| 120 | + net_original.initialize(init=mx.init.Normal(0.5), force_reinit=True) |
| 121 | + min_value = -1 if out_type != 'uint8' else 0 |
| 122 | + data = mx.random.uniform(min_value, 1.0, shape=data_shape, dtype='float32', ctx=mx.current_context()) |
| 123 | + |
| 124 | + outputs = net_original(data) |
| 125 | + for output in outputs: |
| 126 | + output.wait_to_read() |
| 127 | + ref_out = outputs |
| 128 | + |
| 129 | + calib_data = mx.gluon.data.DataLoader(data, batch_size=1) |
| 130 | + for quantize_granularity in quantize_granularity_list: |
| 131 | + qnet = quantization.quantize_net(net_original, |
| 132 | + ctx=mx.current_context(), |
| 133 | + exclude_layers=None, |
| 134 | + exclude_operators=None, |
| 135 | + quantized_dtype=out_type, |
| 136 | + calib_mode='naive', |
| 137 | + calib_data=calib_data, |
| 138 | + num_calib_batches=1, |
| 139 | + quantize_mode='full', |
| 140 | + quantize_granularity=quantize_granularity) |
| 141 | + qsym, _ = qnet.export(None) |
| 142 | + if check_calibration: |
| 143 | + check_qsym_calibrated(qsym, out_type, name=name) |
| 144 | + if check_scale_align: |
| 145 | + check_qsym_scale_align(qsym) |
| 146 | + |
| 147 | + quantized_out = qnet(data) |
| 148 | + for i in range(len(ref_out)): |
| 149 | + min_range = mx.nd.min(ref_out[i]).asscalar() |
| 150 | + max_range = mx.nd.max(ref_out[i]).asscalar() |
| 151 | + atol = 0.1 * max(abs(min_range), abs(max_range)) |
| 152 | + assert_almost_equal_with_err(quantized_out.asnumpy(), ref_out.asnumpy(), rtol=0.1, atol=atol, etol=0.2) |
| 153 | + |
| 154 | + |
| 155 | +def check_fusion(net_original, data_shape, attrs_dict, check_fp32_fusion=True, check_quantization=True, |
| 156 | + out_types=['uint8', 'int8', 'auto'], dedup_subgraph=True): |
| 157 | + net_original.initialize() |
| 158 | + net_original.hybridize(static_alloc=False, static_shape=False) |
| 159 | + data = mx.random.uniform(shape=data_shape, dtype='float32', ctx=mx.current_context()) |
| 160 | + net_original(data) |
| 161 | + net_fusion = copy.copy(net_original) |
| 162 | + sym, params = net_original.export(None) |
| 163 | + |
| 164 | + if check_fp32_fusion: |
| 165 | + data_min = -1.0 |
| 166 | + data_max = 1.0 |
| 167 | + if ''.join(sym.get_internals().list_outputs()).find('sqrt') != -1: |
| 168 | + check_quantization = False |
| 169 | + data_min = 0 |
| 170 | + |
| 171 | + sym_sg = sym.optimize_for(SG_PASS_NAME, dedup_subgraph=dedup_subgraph, skip_infer=True) |
| 172 | + for name, attrs in attrs_dict.items(): |
| 173 | + if name in config: |
| 174 | + op_name = config[name][OP_NAME] |
| 175 | + else: |
| 176 | + op_name = name |
| 177 | + assert ''.join(sym_sg.get_internals().list_outputs()).find(op_name) != -1 |
| 178 | + if len(attrs): |
| 179 | + found = False |
| 180 | + for k, v in sym_sg.attr_dict().items(): |
| 181 | + if k.find(op_name) != -1: |
| 182 | + found = True |
| 183 | + for attr_name, attr_value in attrs.items(): |
| 184 | + assert v[attr_name].lower() == attr_value.lower() |
| 185 | + assert found |
| 186 | + |
| 187 | + data = mx.nd.random.uniform(shape=data_shape, low=data_min, high=data_max) |
| 188 | + out_unfused = net_original(data) |
| 189 | + |
| 190 | + net_fusion.optimize_for(data, backend=SG_PASS_NAME) |
| 191 | + out_fused = net_fusion(data) |
| 192 | + |
| 193 | + assert_almost_equal(out_unfused.asnumpy(), out_fused.asnumpy(), rtol=1e-3, atol=1e-1) |
| 194 | + |
| 195 | + if check_quantization: |
| 196 | + # fp32 to int8 |
| 197 | + for out_type in out_types: |
| 198 | + check_quantize(net_original, data_shape, out_type, name=name) |
| 199 | + |
| 200 | +def check_neg_fusion(net_original, attrs_name=None, excluded_attrs=None, |
| 201 | + data_shapes=(4,4,10,10), name='conv'): |
| 202 | + op_name = config[name][OP_NAME] |
| 203 | + |
| 204 | + data_nd = mx.nd.random.uniform(shape=data_shapes) |
| 205 | + net_original.initialize() |
| 206 | + net_original.hybridize() |
| 207 | + net_original(data_nd) |
| 208 | + |
| 209 | + sym, _ = net_original.export(None) |
| 210 | + sym_sg = sym.optimize_for(SG_PASS_NAME, dedup_subgraph=True, skip_infer=True) |
| 211 | + |
| 212 | + attrs_dict = sym_sg.attr_dict() |
| 213 | + for k, v in attrs_dict.items(): |
| 214 | + if k.find(op_name) != -1: |
| 215 | + for attr in attrs_name: |
| 216 | + assert v[attr] == 'true' |
| 217 | + for exc_attr in excluded_attrs: |
| 218 | + assert exc_attr not in v.keys() |
0 commit comments