#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/eval_peephole.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/torch.h>

#include <c10/util/irange.h>
#include <algorithm>

namespace torch::jit {

namespace onnx {
using namespace ::c10::onnx;
}

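// Collects the tensors feeding a node's inputs: values produced by
// prim::Param are looked up in valsToParamsMap, and onnx::Constant inputs
// contribute their value attribute. Inputs of any other kind are skipped,
// so the result may contain fewer entries than node->inputs().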
std::vector<at::Tensor> getValues(
    Node* node,
    const ValueToParamPairMap& valsToParamsMap) {
  size_t numInputs = node->inputs().size();
  std::vector<at::Tensor> inputTensorValues;
  inputTensorValues.reserve(numInputs);
  for (auto val : node->inputs()) {
    if (val->node()->kind() == prim::Param) {
      auto itr = valsToParamsMap.find(val);
      if (itr == valsToParamsMap.end()) {
        continue;
      }
      inputTensorValues.push_back(itr->second.second.toTensor());
    } else if (val->node()->kind() == onnx::Constant) {
      inputTensorValues.push_back(val->node()->t(attr::value));
    } else {
      continue;
    }
  }
  return inputTensorValues;
}

// This pass fuses a Conv node with the BatchNormalization node that consumes
// its output, producing a single Conv node. The two can be fused only if the
// BatchNormalization inputs scale, bias (B), mean, and var are all 1-D
// tensors of the same shape (C), and the first dimension (dim 0) of the Conv
// weight has the same size as the BatchNormalization scale.
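//
// The folding follows from the eval-mode BatchNormalization formula. With the
// per-channel scale
//   s = scale / sqrt(var + epsilon)
// we get
//   BN(Conv(x)) = s * (W * x + b - mean) + B
//               = (s * W) * x + (s * (b - mean) + B)
// so the fused Conv uses weight W' = s * W (scaled along dim 0) and bias
// b' = s * (b - mean) + B, and the BatchNormalization node can be dropped.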
static void fuseConvBatchNorm(Block* b, ValueToParamPairMap& valsToParamsMap) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      fuseConvBatchNorm(child_block, valsToParamsMap);
    }
    if (it->kind() == onnx::Conv) {
      auto oldConv = *it;
      if (oldConv->outputs().at(0)->uses().size() != 1) {
        continue;
      }
      auto bnNode = oldConv->outputs().at(0)->uses()[0].user;
      if (bnNode->kind() != onnx::BatchNormalization) {
        continue;
      }

      if (oldConv->outputs().size() !=
          bnNode->outputs().size()) { // A training-mode BN node has extra
                                      // outputs, so it is not in eval mode.
        continue;
      }

      auto epsilon = bnNode->f(attr::epsilon);
      auto convInputVals = getValues(oldConv, valsToParamsMap);
      if (convInputVals.empty() ||
          (oldConv->inputs().size() == 3 && convInputVals.size() != 2)) {
        continue;
      }

      auto bnInputVals = getValues(bnNode, valsToParamsMap);
      if (bnInputVals.size() != 4) {
        continue;
      }

      // See
      // https://github.com/onnx/onnx/blob/master/docs/Operators.md#BatchNormalization
      auto bnScale = bnInputVals[0].clone();
      auto bnB = bnInputVals[1].clone();
      auto bnMean = bnInputVals[2].clone();
      auto bnVar = bnInputVals[3].clone();
      // See https://github.com/onnx/onnx/blob/master/docs/Operators.md#Conv
      auto convW = convInputVals[0].clone();
      at::Tensor convB;

      if (!bnScale.is_floating_point() || !bnB.is_floating_point() ||
          !bnMean.is_floating_point() || !bnVar.is_floating_point() ||
          !convW.is_floating_point() || bnScale.dim() != 1 || bnB.dim() != 1 ||
          bnMean.dim() != 1 || bnVar.dim() != 1 ||
          !(bnScale.size(0) == bnB.size(0)) ||
          !(bnB.size(0) == bnMean.size(0)) ||
          !(bnMean.size(0) == bnVar.size(0)) || !(convW.dim() > 2) ||
          !(convW.size(0) == bnScale.size(0))) {
        continue;
      }

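      // Turn bnScale into the per-channel scale s = scale / sqrt(var + epsilon)
      // from the formula above.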
      bnVar = bnVar.add(epsilon);
      bnVar = bnVar.sqrt();
      bnScale = bnScale.div(bnVar);

      // Calculate the fused weight: scale each output channel of the Conv
      // weight, W'[i] = W[i] * s[i].
      for (const auto i : c10::irange(convW.size(0))) {
        convW[i] = convW[i].mul(bnScale[i]);
      }

      // Calculate the fused bias: b' = s * (b - mean) + B. If the Conv has no
      // bias input, b is implicitly zero and b' reduces to B - s * mean.
      if (oldConv->inputs().size() == 3) {
        convB = convInputVals[1].clone();
        convB = convB.sub(bnMean);
        convB = convB.mul(bnScale);
        convB = convB.add(bnB);
      } else {
        bnMean = bnMean.mul(bnScale);
        bnB = bnB.sub(bnMean);
        convB = bnB;
      }

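      // Rewrite the graph: create a new Conv that takes the original data
      // input plus the fused weight and bias (registered as fresh graph
      // inputs in valsToParamsMap), splice it in before the BN node, then
      // retire both the BN node and the old Conv.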
      Node* newConv = b->owningGraph()->create(onnx::Conv, 1);
      newConv->outputs().at(0)->copyMetadata(bnNode->outputs().at(0));

      newConv->copyAttributes(*oldConv);
      newConv->insertBefore(bnNode);
      newConv->addInput(oldConv->inputs().at(0));
      newConv->copyMetadata(oldConv);

      auto newConvW = b->owningGraph()->addInput();
      valsToParamsMap.insert(
          {newConvW, std::make_pair(newConvW->debugName(), convW)});
      newConvW->inferTypeFrom(convW);
      newConv->addInput(newConvW);

      auto newConvB = b->owningGraph()->addInput();
      valsToParamsMap.insert(
          {newConvB, std::make_pair(newConvB->debugName(), convB)});
      newConvB->inferTypeFrom(convB);
      newConv->addInput(newConvB);

      bnNode->outputs().at(0)->replaceAllUsesWith(newConv->outputs().at(0));
      bnNode->destroy();
      it.destroyCurrent();
    }
  }
}

// Applies eval-mode peephole optimizations, currently Conv +
// BatchNormalization fusion, to the block, and updates paramsDict with the
// fused tensors.
void EvalPeepholeONNX(Block* b, ParamMap& paramsDict) {
  auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
  fuseConvBatchNorm(b, valsToParamsMap);
  buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict);
}

void EvalPeepholeONNX(std::shared_ptr<Graph>& g, ParamMap& paramsDict) {
  EvalPeepholeONNX(g->block(), paramsDict);
  GRAPH_DUMP("After EvalPeepholeONNX:", g);
}

} // namespace torch::jit