xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/eval_peephole.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/jit_log.h>
2 #include <torch/csrc/jit/passes/onnx/eval_peephole.h>
3 #include <torch/csrc/jit/passes/onnx/helper.h>
4 #include <torch/torch.h>
5 
6 #include <c10/util/irange.h>
7 #include <algorithm>
8 
9 namespace torch::jit {
10 
11 namespace onnx {
12 using namespace ::c10::onnx;
13 }
14 
getValues(Node * node,const ValueToParamPairMap & valsToParamsMap)15 std::vector<at::Tensor> getValues(
16     Node* node,
17     const ValueToParamPairMap& valsToParamsMap) {
18   size_t numInputs = node->inputs().size();
19   std::vector<at::Tensor> inputTensorValues;
20   inputTensorValues.reserve(numInputs);
21   for (auto val : node->inputs()) {
22     if (val->node()->kind() == prim::Param) {
23       auto itr = valsToParamsMap.find(val);
24       if (itr == valsToParamsMap.end()) {
25         continue;
26       }
27       inputTensorValues.push_back(itr->second.second.toTensor());
28     } else if (val->node()->kind() == onnx::Constant) {
29       inputTensorValues.push_back(val->node()->t(attr::value));
30     } else {
31       continue;
32     }
33   }
34   return inputTensorValues;
35 }
36 
// This pass fuses Conv and BatchNorm into Conv node
// Conv and BatchNorm can be fused only if inputs for BatchNorm node:
// scale, bias, mean and var are all tensors of same shape (C) and
// if the size of the first dimension (dim 0) is the same between Conv
// input weight and BatchNorm input scale.
static void fuseConvBatchNorm(Block* b, ValueToParamPairMap& valsToParamsMap) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    // Recurse into nested blocks (e.g. If/Loop bodies) first.
    for (auto* child_block : it->blocks()) {
      fuseConvBatchNorm(child_block, valsToParamsMap);
    }
    if (it->kind() == onnx::Conv) {
      auto oldConv = *it;
      // Fuse only when the Conv output feeds exactly one consumer;
      // otherwise the pre-BN activation is still observable elsewhere.
      if (oldConv->outputs().at(0)->uses().size() != 1) {
        continue;
      }
      auto bnNode = oldConv->outputs().at(0)->uses()[0].user;
      if (bnNode->kind() != onnx::BatchNormalization) {
        continue;
      }

      // A training-mode BN has extra outputs (running mean/var etc.);
      // only a single-output, eval-mode BN can be folded away.
      if (oldConv->outputs().size() !=
          bnNode->outputs().size()) { // BN layer is not in eval mode
        continue;
      }

      auto epsilon = bnNode->f(attr::epsilon);
      // Constant inputs of Conv: [weight] or [weight, bias]. The data
      // input comes from another node, so getValues does not return it.
      // Bail out if the weight is missing, or if Conv declares a bias
      // input (3 inputs) whose tensor could not be resolved.
      auto convInputVals = getValues(oldConv, valsToParamsMap);
      if (convInputVals.empty() ||
          (oldConv->inputs().size() == 3 && convInputVals.size() != 2)) {
        continue;
      }

      // Constant inputs of BN: scale, B, mean, var — all four required.
      auto bnInputVals = getValues(bnNode, valsToParamsMap);
      if (bnInputVals.size() != 4) {
        continue;
      }

      // See
      // https://github.com/onnx/onnx/blob/master/docs/Operators.md#BatchNormalization
      // Clone so the arithmetic below never mutates the stored parameters.
      auto bnScale = bnInputVals[0].clone();
      auto bnB = bnInputVals[1].clone();
      auto bnMean = bnInputVals[2].clone();
      auto bnVar = bnInputVals[3].clone();
      // See https://github.com/onnx/onnx/blob/master/docs/Operators.md#Conv
      auto convW = convInputVals[0].clone();
      at::Tensor convB;

      // Validate dtypes and shapes: all tensors floating point, BN params
      // 1-D of a common length C, weight rank > 2 with out-channels == C.
      if (!bnScale.is_floating_point() || !bnB.is_floating_point() ||
          !bnMean.is_floating_point() || !bnVar.is_floating_point() ||
          !convW.is_floating_point() || bnScale.dim() != 1 || bnB.dim() != 1 ||
          bnMean.dim() != 1 || bnVar.dim() != 1 ||
          !(bnScale.size(0) == bnB.size(0)) ||
          !(bnB.size(0) == bnMean.size(0)) ||
          !(bnMean.size(0) == bnVar.size(0)) || !(convW.dim() > 2) ||
          !(convW.size(0) == bnScale.size(0))) {
        continue;
      }

      // Fold BN into an affine transform per channel:
      //   bnScale := scale / sqrt(var + epsilon)
      bnVar = bnVar.add(epsilon);
      bnVar = bnVar.sqrt();
      bnScale = bnScale.div(bnVar);

      // Calculate weight: scale each output-channel slice of the kernel.
      for (const auto i : c10::irange(convW.size(0))) {
        convW[i] = convW[i].mul(bnScale[i]);
      }

      // Calculate bias:
      //   with Conv bias:    newB = (convB - mean) * bnScale + B
      //   without Conv bias: newB = B - mean * bnScale
      if (oldConv->inputs().size() == 3) {
        convB = convInputVals[1].clone();
        convB = convB.sub(bnMean);
        convB = convB.mul(bnScale);
        convB = convB.add(bnB);
      } else {
        bnMean = bnMean.mul(bnScale);
        bnB = bnB.sub(bnMean);
        convB = bnB;
      }

      // Build the replacement Conv (always with a bias input now) and
      // place it before the BN node it supersedes.
      Node* newConv = b->owningGraph()->create(onnx::Conv, 1);
      newConv->outputs().at(0)->copyMetadata(bnNode->outputs().at(0));

      newConv->copyAttributes(*oldConv);
      newConv->insertBefore(bnNode);
      newConv->addInput(oldConv->inputs().at(0));
      newConv->copyMetadata(oldConv);

      // Register the fused weight as a fresh graph input / parameter.
      auto newConvW = b->owningGraph()->addInput();
      valsToParamsMap.insert(
          {newConvW, std::make_pair(newConvW->debugName(), convW)});
      newConvW->inferTypeFrom(convW);
      newConv->addInput(newConvW);

      // Likewise for the fused bias.
      auto newConvB = b->owningGraph()->addInput();
      valsToParamsMap.insert(
          {newConvB, std::make_pair(newConvB->debugName(), convB)});
      newConvB->inferTypeFrom(convB);
      newConv->addInput(newConvB);

      // Rewire consumers of the BN output to the new Conv, then delete
      // the BN node and (via the iterator, safely) the old Conv.
      bnNode->outputs().at(0)->replaceAllUsesWith(newConv->outputs().at(0));
      bnNode->destroy();
      it.destroyCurrent();
    }
  }
}
142 
EvalPeepholeONNX(Block * b,ParamMap & paramsDict)143 void EvalPeepholeONNX(Block* b, ParamMap& paramsDict) {
144   auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
145   fuseConvBatchNorm(b, valsToParamsMap);
146   buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict);
147 }
148 
// Graph-level entry point: runs the pass on the graph's top-level block
// and dumps the resulting graph when JIT logging is enabled.
void EvalPeepholeONNX(std::shared_ptr<Graph>& g, ParamMap& paramsDict) {
  EvalPeepholeONNX(g->block(), paramsDict);
  GRAPH_DUMP("After EvalPeepholeONNX:", g);
}
153 
154 } // namespace torch::jit
155