This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 7b750a8

Author: bgawrych

[REFACTOR] Refactor tests/python/mkl/test_subgraph.py to utilize Gluon API (#19675)

* Refactor tests/python/mkl/test_subgraph.py to utilize Gluon API
* Split oneDNN subgraph tests to multiple files
* Add license to files
* Fix review
* Fix imports
* Fix imports 2

1 parent 417accc commit 7b750a8

4 files changed

Lines changed: 1104 additions & 976 deletions

New file: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
import numpy as np
import unittest
import pytest
import ctypes
import copy

from mxnet.contrib import quantization
from mxnet.gluon import nn
from mxnet.test_utils import assert_almost_equal, assert_almost_equal_with_err

OP_NAME = 'op_name'
QUANTIZED_OP_NAME = 'quantized_op_name'
SG_PASS_NAME = 'MKLDNN'
QUANTIZE_SG_PASS_NAME = 'MKLDNN_QUANTIZE'
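# Maps a short op alias ('conv', 'fc') to the names of the fused oneDNN
# operator and of its quantized counterpart, as produced by the MKLDNN
# subgraph passes; the checks below look fused ops up through this table.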
config = {
    'conv': {
        OP_NAME: 'sg_mkldnn_conv',
        QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_conv'
    },
    'fc': {
        OP_NAME: 'sg_mkldnn_fully_connected',
        QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_fully_connected'
    }
}

DATA_SHAPE = [(64, 4, 10, 10), (4, 3, 24, 24), (1, 16, 32, 32)]

# Helpers
class RELU6(nn.HybridBlock):
    """Relu6 used in MobileNetV2."""

    def __init__(self, **kwargs):
        super(RELU6, self).__init__(**kwargs)

    def hybrid_forward(self, F, x):
        return F.clip(x, 0, 6, name="relu6")

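# Two Dense branches joined by concat and softmax; used by tests that
# exercise fusion and quantization at the tail of a network.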
class TailNegBlock(nn.HybridBlock):
    def __init__(self, **kwargs):
        super(TailNegBlock, self).__init__(**kwargs)
        self.fc1 = nn.Dense(10, flatten=True)
        self.fc2 = nn.Dense(10, flatten=True)

    def hybrid_forward(self, F, x1, x2):
        out_fc1 = self.fc1(x1)
        out_fc2 = self.fc2(x2)
        out = F.concat(out_fc1, out_fc2)
        out = F.softmax(out)
        return out

class CustomNormalInit(mx.init.Initializer):
    """Initializes weights with random values sampled from a normal distribution
    with a custom mean and standard deviation of `sigma`.
    """
    def __init__(self, mean=0, sigma=0.01):
        super(CustomNormalInit, self).__init__(mean=mean, sigma=sigma)
        self.mean = mean
        self.sigma = sigma

    def _init_weight(self, _, arr):
        mx.random.normal(self.mean, self.sigma, arr.shape, dtype=arr.dtype, out=arr)

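# Check that every quantize node in `qsym` uses the requested output dtype
# and that each fused quantized op carries calibration attributes
# (min_calib_range/max_calib_range).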
def check_qsym_calibrated(qsym, out_type, name='conv'):
    quantized_op_name = 'quantized_' + name
    assert ''.join(qsym.attr_dict().keys()).find(quantized_op_name) != -1
    for k, v in qsym.attr_dict().items():
        if k.find('_quantize') != -1:
            assert v['out_type'] == out_type
        if k.find(quantized_op_name) != -1:
            if quantized_op_name.startswith("quantized_sg_mkldnn_fully_connected") and 'enable_float_output' in v:
                continue
            assert 'min_calib_range' in v
            assert 'max_calib_range' in v

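# Check that all fused quantized convolutions in `qsym` share the same
# calibration range, i.e. that the scale-alignment pass unified their scales.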
def check_qsym_scale_align(qsym):
    assert ''.join(qsym.attr_dict().keys()).find('quantized_sg_mkldnn_conv') != -1
    init = False
    for k, v in qsym.attr_dict().items():
        if k.find('quantized_sg_mkldnn_conv') != -1:
            assert 'min_calib_range' in v
            assert 'max_calib_range' in v
            if not init:
                min_calib_range = v['min_calib_range']
                max_calib_range = v['max_calib_range']
                init = True
            else:
                assert min_calib_range == v['min_calib_range']
                assert max_calib_range == v['max_calib_range']


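# Quantize `net_original` with every applicable granularity, validate the
# calibration attributes of the exported symbol, and compare the quantized
# outputs against the fp32 reference within relaxed tolerances.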
def check_quantize(net_original, data_shape, out_type, name='conv',
                   check_calibration=True, check_scale_align=False):
    quantize_granularity_list = ['tensor-wise']
    if name == 'fc':
        quantize_granularity_list += ['channel-wise']

    if name in config:
        name = config[name][OP_NAME]

    net_original.initialize(init=mx.init.Normal(0.5), force_reinit=True)
    min_value = -1 if out_type != 'uint8' else 0
    data = mx.random.uniform(min_value, 1.0, shape=data_shape, dtype='float32', ctx=mx.current_context())

    outputs = net_original(data)
    for output in outputs:
        output.wait_to_read()
    ref_out = outputs

    calib_data = mx.gluon.data.DataLoader(data, batch_size=1)
    for quantize_granularity in quantize_granularity_list:
        qnet = quantization.quantize_net(net_original,
                                         ctx=mx.current_context(),
                                         exclude_layers=None,
                                         exclude_operators=None,
                                         quantized_dtype=out_type,
                                         calib_mode='naive',
                                         calib_data=calib_data,
                                         num_calib_batches=1,
                                         quantize_mode='full',
                                         quantize_granularity=quantize_granularity)
        qsym, _ = qnet.export(None)
        if check_calibration:
            check_qsym_calibrated(qsym, out_type, name=name)
        if check_scale_align:
            check_qsym_scale_align(qsym)

        quantized_out = qnet(data)
        for i in range(len(ref_out)):
            min_range = mx.nd.min(ref_out[i]).asscalar()
            max_range = mx.nd.max(ref_out[i]).asscalar()
            atol = 0.1 * max(abs(min_range), abs(max_range))
            assert_almost_equal_with_err(quantized_out[i].asnumpy(), ref_out[i].asnumpy(),
                                         rtol=0.1, atol=atol, etol=0.2)


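# Apply the MKLDNN fusion pass to `net_original`, assert that the expected
# fused ops (with the attributes given in `attrs_dict`) appear in the
# optimized symbol, compare fused vs. unfused outputs, and optionally run
# the quantization checks for each requested output type.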
def check_fusion(net_original, data_shape, attrs_dict, check_fp32_fusion=True, check_quantization=True,
                 out_types=['uint8', 'int8', 'auto'], dedup_subgraph=True):
    net_original.initialize()
    net_original.hybridize(static_alloc=False, static_shape=False)
    data = mx.random.uniform(shape=data_shape, dtype='float32', ctx=mx.current_context())
    net_original(data)
    net_fusion = copy.copy(net_original)
    sym, params = net_original.export(None)

    if check_fp32_fusion:
        data_min = -1.0
        data_max = 1.0
        if ''.join(sym.get_internals().list_outputs()).find('sqrt') != -1:
            check_quantization = False
            data_min = 0

        sym_sg = sym.optimize_for(SG_PASS_NAME, dedup_subgraph=dedup_subgraph, skip_infer=True)
        for name, attrs in attrs_dict.items():
            if name in config:
                op_name = config[name][OP_NAME]
            else:
                op_name = name
            assert ''.join(sym_sg.get_internals().list_outputs()).find(op_name) != -1
            if len(attrs):
                found = False
                for k, v in sym_sg.attr_dict().items():
                    if k.find(op_name) != -1:
                        found = True
                        for attr_name, attr_value in attrs.items():
                            assert v[attr_name].lower() == attr_value.lower()
                assert found

        data = mx.nd.random.uniform(shape=data_shape, low=data_min, high=data_max)
        out_unfused = net_original(data)

        net_fusion.optimize_for(data, backend=SG_PASS_NAME)
        out_fused = net_fusion(data)

        assert_almost_equal(out_unfused.asnumpy(), out_fused.asnumpy(), rtol=1e-3, atol=1e-1)

    if check_quantization:
        # fp32 to int8
        for out_type in out_types:
            check_quantize(net_original, data_shape, out_type, name=name)

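# Negative-fusion check: after the MKLDNN pass, the fused op must carry the
# attributes listed in `attrs_name` but none of those in `excluded_attrs`.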
def check_neg_fusion(net_original, attrs_name=None, excluded_attrs=None,
                     data_shapes=(4, 4, 10, 10), name='conv'):
    op_name = config[name][OP_NAME]

    data_nd = mx.nd.random.uniform(shape=data_shapes)
    net_original.initialize()
    net_original.hybridize()
    net_original(data_nd)

    sym, _ = net_original.export(None)
    sym_sg = sym.optimize_for(SG_PASS_NAME, dedup_subgraph=True, skip_infer=True)

    attrs_dict = sym_sg.attr_dict()
    for k, v in attrs_dict.items():
        if k.find(op_name) != -1:
            for attr in attrs_name:
                assert v[attr] == 'true'
            for exc_attr in excluded_attrs:
                assert exc_attr not in v.keys()
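
For orientation, a test in one of the split files might drive these helpers roughly as follows. This is a hypothetical sketch: the ConvActBlock name and the 'with_act' attribute expectation are illustrative assumptions, not part of this commit.

import mxnet as mx
from mxnet.gluon import nn

# check_fusion and config from the helper module above are assumed importable.

class ConvActBlock(nn.HybridBlock):
    """Conv2D followed by ReLU, a pattern the conv fusion pass targets."""
    def __init__(self, **kwargs):
        super(ConvActBlock, self).__init__(**kwargs)
        self.conv = nn.Conv2D(channels=16, kernel_size=(3, 3))
        self.act = nn.Activation('relu')

    def hybrid_forward(self, F, x):
        return self.act(self.conv(x))

def test_conv_act_fusion():
    net = ConvActBlock()
    # Expect one fused 'sg_mkldnn_conv' node; 'with_act' is an assumed
    # attribute name for the fused activation (illustrative only).
    check_fusion(net, (4, 4, 10, 10), {'conv': {'with_act': 'true'}})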
