Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 5cbcbce

Browse files
bgawrychBartlomiej Gawrych
andauthored
Remove identity operators from oneDNN optimized graph (#20712)
* Remove identity operators from inference graph * Add new line at EOF * review fixes * Small refactor & review * remove commented fragment Co-authored-by: Bartlomiej Gawrych <barlomiej.gawrych@intel.com>
1 parent ebc88e7 commit 5cbcbce

4 files changed

Lines changed: 195 additions & 1 deletion

File tree

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file dnnl_identity_property.h
22+
* \brief Graph property for removing identity operators
23+
*/
24+
25+
#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_IDENTITY_PROPERTY_H_
26+
#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_IDENTITY_PROPERTY_H_
27+
#if MXNET_USE_ONEDNN == 1
28+
29+
#include <map>
30+
#include <string>
31+
#include <vector>
32+
33+
#include "../common.h"
34+
#include "../../nn/dropout-inl.h"
35+
#include "dnnl_subgraph_base-inl.h"
36+
37+
namespace mxnet {
38+
namespace op {
39+
40+
class SgDNNLIdentitySelector : public SubgraphSelectorV2 {
41+
private:
42+
std::vector<const BiDirectedNode*> matched_list_;
43+
44+
public:
45+
bool Select(const BiDirectedNode& seed_node,
46+
const std::shared_ptr<NodeAttr>& node_attr) override {
47+
bool status = false;
48+
if (seed_node.node->op() == Op::Get("_npi_copy")) {
49+
status = true;
50+
}
51+
52+
if (seed_node.node->op() == Op::Get("Dropout")) {
53+
auto const& dropout_param = nnvm::get<DropoutParam>(seed_node.node->attrs.parsed);
54+
if (dropout_param.mode == dropout::kTraining) {
55+
status = true;
56+
}
57+
}
58+
59+
if (status) {
60+
matched_list_.clear();
61+
matched_list_.emplace_back(&seed_node);
62+
return true;
63+
}
64+
return false;
65+
}
66+
67+
bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& input_node) override {
68+
if (input_node.node->is_variable()) {
69+
return false;
70+
} else if (input_node.node->op()) {
71+
matched_list_.emplace_back(&input_node);
72+
return true;
73+
}
74+
return false;
75+
}
76+
77+
bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& output_node) override {
78+
return false;
79+
}
80+
81+
std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
82+
// candidates should contain only two nodes - custom node and identity node
83+
if (candidates.size() == 2 && candidates.size() == matched_list_.size()) {
84+
return candidates;
85+
} else {
86+
return std::vector<BiDirectedNode*>(0);
87+
}
88+
}
89+
90+
void Reset() override {
91+
CHECK_GE(matched_list_.size(), 1);
92+
auto new_selector = SgDNNLIdentitySelector();
93+
new_selector.Select(*matched_list_[0], nullptr);
94+
*this = new_selector;
95+
}
96+
};
97+
98+
inline bool IsIdentityNode(const nnvm::ObjectPtr node) {
99+
return node->op() && (node->op() == Op::Get("_npi_copy") || node->op() == Op::Get("Dropout"));
100+
}
101+
102+
class SgDNNLIdentityProperty : public SubgraphProperty {
103+
public:
104+
SgDNNLIdentityProperty() {}
105+
106+
static SubgraphPropertyPtr Create() {
107+
static const std::string& name = "DNNL Identity optimization pass";
108+
auto property = std::make_shared<SgDNNLIdentityProperty>();
109+
property->SetAttr<std::string>("property_name", name);
110+
property->SetAttr<bool>("inference_only", true);
111+
return property;
112+
}
113+
114+
nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
115+
const int subgraph_id = 0) const override {
116+
nnvm::NodeEntry identity_node_entry;
117+
for (auto entry : sym.outputs) {
118+
if (IsIdentityNode(entry.node)) {
119+
identity_node_entry = entry;
120+
}
121+
}
122+
123+
auto last_node = identity_node_entry.node;
124+
nnvm::Symbol new_sym;
125+
new_sym.outputs.emplace_back(last_node);
126+
127+
nnvm::ObjectPtr org_node;
128+
DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr& node) {
129+
if (!IsIdentityNode(node)) {
130+
org_node = node;
131+
}
132+
});
133+
134+
// Create copy of original node
135+
nnvm::ObjectPtr n = nnvm::Node::Create();
136+
n->attrs = org_node->attrs;
137+
CHECK(n->op());
138+
n->op()->attr_parser(&(n->attrs));
139+
return n;
140+
}
141+
142+
void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
143+
std::vector<nnvm::NodeEntry*>* output_entries) const override {
144+
// output of identity must be connected as output of operator before identity
145+
// e.g. for: /--index 0--> custom_op
146+
// (n) slice
147+
// \--index 1--> Dropout --index 0--> OUT_NODE
148+
// for OUT_NODE index 0 must be changed to index 1
149+
for (int i = 0; i < output_entries->size(); ++i) {
150+
auto out_node = output_entries->at(i)->node;
151+
if (IsIdentityNode(out_node)) {
152+
output_entries->at(i)->index = out_node->inputs[0].index;
153+
}
154+
output_entries->at(i)->node = n;
155+
}
156+
}
157+
158+
SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
159+
auto selector = std::make_shared<SgDNNLIdentitySelector>();
160+
return selector;
161+
}
162+
};
163+
164+
} // namespace op
165+
} // namespace mxnet
166+
167+
#endif // if MXNET_USE_ONEDNN == 1
168+
#endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_IDENTITY_PROPERTY_H_

src/operator/subgraph/dnnl/dnnl_subgraph_base-inl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ static inline bool SupportDNNLAttr(const std::shared_ptr<NodeAttr>& node_attr) {
3131
return (node_attr->dispatch_mode == DispatchMode::kFComputeEx) &&
3232
(node_attr->itype[0] == mshadow::kFloat32 ||
3333
node_attr->itype[0] == mshadow::kBfloat16) &&
34-
(ndim == 1 || ndim == 2 || ndim == 4 || ndim == 5);
34+
(ndim >= 1 && ndim <= 5);
3535
} else {
3636
return true;
3737
}

src/operator/subgraph/dnnl/dnnl_subgraph_property.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "dnnl_bn_relu_property.h"
2424
#include "dnnl_conv_property.h"
2525
#include "dnnl_fc_property.h"
26+
#include "dnnl_identity_property.h"
2627
#include "dnnl_post_quantize_align_scale_property.h"
2728
#include "dnnl_post_quantize_property.h"
2829
#include "dnnl_transformer_qk_property.h"
@@ -35,6 +36,7 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(ONEDNN)
3536
.set_attr("enable", DNNLEnvSet())
3637
.set_attr("context", Context::CPU());
3738

39+
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLIdentityProperty);
3840
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLConvProperty);
3941
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLFCProperty);
4042
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBNReLUProperty);
@@ -44,6 +46,7 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBatchDotProperty);
4446

4547
MXNET_REGISTER_SUBGRAPH_BACKEND(ONEDNN_QUANTIZE).set_attr("context", Context::CPU());
4648

49+
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLIdentityProperty);
4750
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLConvProperty).set_attr("quantize", true);
4851
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLFCProperty).set_attr("quantize", true);
4952
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLTransformerQKProperty);

tests/python/dnnl/subgraphs/test_fc_subgraph.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,26 @@ def forward(self, x):
200200
attrs = {'fc': {}}
201201
net = MultiOutputFC()
202202
check_fusion(net, data_shape, attrs, check_quantization=flatten)
203+
204+
205+
@mx.util.use_np
@pytest.mark.parametrize('identity_node', ['dropout', 'copy'])
def test_fc_identity_eltwise(identity_node):
    """Run check_fusion on Dense -> identity op -> relu.

    The identity op is either ``mx.np.copy`` or ``mx.npx.dropout``, chosen by the
    ``identity_node`` parameter. The expected attributes passed to ``check_fusion``
    are ``{'fc': {'with_eltwise': 'true'}}``.
    """
    class FCIdentityEltwise(nn.HybridBlock):
        """Dense layer followed by an identity-like op and a relu activation."""

        def __init__(self, identity_node, **kwargs):
            super(FCIdentityEltwise, self).__init__(**kwargs)
            self.fc = nn.Dense(units=64, use_bias=False, weight_initializer=None, flatten=True)
            self.identity_node = identity_node

        def forward(self, x):
            dense_out = self.fc(x)
            # Insert the identity-like operator between the dense layer and the activation.
            if self.identity_node == 'copy':
                passthrough = mx.np.copy(dense_out)
            else:
                passthrough = mx.npx.dropout(dense_out)
            return mx.npx.activation(passthrough, act_type='relu')

    net = FCIdentityEltwise(identity_node)
    data_shape = (64, 4, 10, 10)
    attrs = {'fc': {'with_eltwise': 'true'}}
    check_fusion(net, data_shape, attrs, check_quantization=False)

0 commit comments

Comments
 (0)