// tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {
namespace {
// Returns an error unless the tensor is a 1-D vector of the expected width.
Status ErrorIfNotVector(const Tensor& input, const string& input_name,
                        int expected_width) {
  if ((input.shape().dims() != 1) ||
      (input.shape().dim_size(0) != expected_width)) {
    return errors::InvalidArgument(
        input_name,
        " input to batch norm has bad shape: ", input.shape().DebugString());
  }
  return OkStatus();
}

Status GetScaleAndOffsetValues(const NodeMatch& match,
                               std::vector<float>* scale_values,
                               std::vector<float>* offset_values) {
  // Find all the nodes we expect in the subgraph.
  const NodeDef& batch_norm_node = match.node;
  // BatchNormWithGlobalNormalization and FusedBatchNorm ops only differ
  // by input order and attribute names.
  CHECK(batch_norm_node.op() == "BatchNormWithGlobalNormalization" ||
        batch_norm_node.op() == "FusedBatchNorm");
  const bool is_fused = batch_norm_node.op() == "FusedBatchNorm";
  const int mean_idx = is_fused ? 3 : 1;
  const int var_idx = is_fused ? 4 : 2;
  const int beta_idx = is_fused ? 2 : 3;
  const int gamma_idx = is_fused ? 1 : 4;
  const string epsilon_attr = is_fused ? "epsilon" : "variance_epsilon";
  // FusedBatchNorm always scales after normalization.
  const bool scale_after_normalization =
      is_fused || batch_norm_node.attr().at("scale_after_normalization").b();
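  // For reference, the two ops' input orders, which the indices above encode:
  //   BatchNormWithGlobalNormalization: t, mean, variance, beta, gamma
  //   FusedBatchNorm:                   x, scale (gamma), offset (beta),
  //                                     mean, variance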

  const NodeDef& mean_node = match.inputs[mean_idx].node;
  CHECK_EQ("Const", mean_node.op());
  const NodeDef& variance_node = match.inputs[var_idx].node;
  CHECK_EQ("Const", variance_node.op());
  const NodeDef& beta_node = match.inputs[beta_idx].node;
  CHECK_EQ("Const", beta_node.op());
  const NodeDef& gamma_node = match.inputs[gamma_idx].node;
  CHECK_EQ("Const", gamma_node.op());

  // We have a set of vectors that we want to combine into a vector of
  // scale values and offset values.
  Tensor mean = GetNodeTensorAttr(mean_node, "value");
  Tensor variance = GetNodeTensorAttr(variance_node, "value");
  Tensor beta = GetNodeTensorAttr(beta_node, "value");
  Tensor gamma = GetNodeTensorAttr(gamma_node, "value");
  const float variance_epsilon = batch_norm_node.attr().at(epsilon_attr).f();

  // Make sure all the inputs really are vectors with the same shape.
  const int64_t num_cols = mean.shape().dim_size(0);
  TF_RETURN_IF_ERROR(ErrorIfNotVector(variance, "Variance", num_cols));
  TF_RETURN_IF_ERROR(ErrorIfNotVector(beta, "Beta", num_cols));
  TF_RETURN_IF_ERROR(ErrorIfNotVector(gamma, "Gamma", num_cols));

  scale_values->resize(num_cols);
  offset_values->resize(num_cols);

  // Calculate the scale and offset values to apply.
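  // Batch normalization computes, per channel,
  //   y = gamma * (x - mean) / sqrt(variance + epsilon) + beta
  // (with gamma omitted when scale_after_normalization is false). This
  // rearranges into the affine form y = scale * x + offset, where
  //   scale  = gamma / sqrt(variance + epsilon)
  //   offset = beta - mean * scale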
  if (scale_after_normalization) {
    for (int i = 0; i < num_cols; ++i) {
      (*scale_values)[i] =
          (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon)) *
          gamma.flat<float>()(i);
    }
  } else {
    for (int i = 0; i < num_cols; ++i) {
      (*scale_values)[i] =
          (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon));
    }
  }
  for (int i = 0; i < num_cols; ++i) {
    (*offset_values)[i] =
        (-mean.flat<float>()(i) * (*scale_values)[i]) + beta.flat<float>()(i);
  }
  return OkStatus();
}

Status FuseScaleOffsetToConvWeights(const std::vector<float>& scale_values,
                                    const std::vector<float>& offset_values,
                                    const NodeMatch& conv_node_match,
                                    const string& conv_output_name,
                                    std::vector<NodeDef>* new_nodes) {
  const NodeDef& conv_node = conv_node_match.node;
  // CHECK_EQ("Conv2D", conv_node.op());
  const NodeDef& input_node = conv_node_match.inputs[0].node;
  const NodeDef& weights_node = conv_node_match.inputs[1].node;
  CHECK_EQ("Const", weights_node.op());

  Tensor weights = GetNodeTensorAttr(weights_node, "value");
  int64_t weights_cols;
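  // The filter layout determines where the output-channel count lives:
  //   Conv2D:                [height, width, in_channels, out_channels]
  //   DepthwiseConv2dNative: [height, width, in_channels, channel_multiplier],
  //                          producing in_channels * channel_multiplier
  //                          output channels.
  // Any other op falls through to dim_size(1) (e.g. an [in, out] MatMul
  // weight matrix).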
  if (conv_node.op() == "Conv2D") {
    weights_cols = weights.shape().dim_size(3);
  } else if (conv_node.op() == "DepthwiseConv2dNative") {
    weights_cols = weights.shape().dim_size(2) * weights.shape().dim_size(3);
  } else {
    weights_cols = weights.shape().dim_size(1);
  }
  CHECK_EQ(weights_cols, scale_values.size());

  // Multiply the original weights by the scale vector.
  auto weights_vector = weights.flat<float>();
  Tensor scaled_weights(DT_FLOAT, weights.shape());
  auto scaled_weights_vector = scaled_weights.flat<float>();
  for (int64_t row = 0; row < weights_vector.dimension(0); ++row) {
    scaled_weights_vector(row) =
        weights_vector(row) * scale_values[row % weights_cols];
  }
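  // This works because convolution is linear in the weights: scaling every
  // weight that feeds output channel c by scale[c] is equivalent to scaling
  // the convolution's output channel c. Since the filter is stored row-major
  // with the output channel(s) as the innermost dimension(s), the flat index
  // modulo weights_cols recovers the output channel.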
  // Figure out the remaining bias to add on.
  Tensor bias_offset(DT_FLOAT, {weights_cols});
  auto bias_offset_vector = bias_offset.flat<float>();
  for (int64_t col = 0; col < weights_cols; ++col) {
    bias_offset_vector(col) = offset_values[col];
  }

  // Construct the new nodes.
  NodeDef scaled_weights_node;
  scaled_weights_node.set_op("Const");
  scaled_weights_node.set_name(weights_node.name());
  SetNodeAttr("dtype", DT_FLOAT, &scaled_weights_node);
  SetNodeTensorAttr<float>("value", scaled_weights, &scaled_weights_node);
  new_nodes->push_back(scaled_weights_node);

  // The input and convolution can be copied straight over, since the
  // name of the scaled weights constant is the same as the original.
  new_nodes->push_back(input_node);
  new_nodes->push_back(conv_node);

  NodeDef bias_offset_node;
  bias_offset_node.set_op("Const");
  bias_offset_node.set_name(conv_node.name() + "_bn_offset");
  SetNodeAttr("dtype", DT_FLOAT, &bias_offset_node);
  SetNodeTensorAttr<float>("value", bias_offset, &bias_offset_node);
  new_nodes->push_back(bias_offset_node);

  NodeDef bias_add_node;
  bias_add_node.set_op("BiasAdd");
  bias_add_node.set_name(conv_output_name);
  if (conv_node.attr().count("data_format")) {
    CopyNodeAttr(conv_node, "data_format", "data_format", &bias_add_node);
  }
  CopyNodeAttr(conv_node, "T", "T", &bias_add_node);
  AddNodeInput(conv_node.name(), &bias_add_node);
  AddNodeInput(bias_offset_node.name(), &bias_add_node);
  new_nodes->push_back(bias_add_node);
  return OkStatus();
}

Status FuseBatchNormWithConv(const NodeMatch& match,
                             std::vector<NodeDef>* new_nodes) {
  // Calculate the scale and offset values to apply.
  std::vector<float> scale_values;
  std::vector<float> offset_values;
  TF_RETURN_IF_ERROR(
      GetScaleAndOffsetValues(match, &scale_values, &offset_values));

  // Fuse the conv weights, and give the final output node the batch norm
  // node's name so downstream consumers are unaffected.
  const NodeDef& batch_norm_node = match.node;
  TF_RETURN_IF_ERROR(
      FuseScaleOffsetToConvWeights(scale_values, offset_values, match.inputs[0],
                                   batch_norm_node.name(), new_nodes));
  return OkStatus();
}
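// In graph terms, the rewrite above turns
//   input -> Conv2D(weights) -> BatchNorm(mean, variance, beta, gamma)
// into
//   input -> Conv2D(weights * scale) -> BiasAdd(offset)
// with the BiasAdd taking over the batch norm node's name.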

Status FuseBatchNormWithBatchToSpace(const NodeMatch& match,
                                     std::vector<NodeDef>* new_nodes) {
  // Calculate the scale and offset values to apply.
  std::vector<float> scale_values;
  std::vector<float> offset_values;
  TF_RETURN_IF_ERROR(
      GetScaleAndOffsetValues(match, &scale_values, &offset_values));

  // Fuse the conv weights, inserting the BiasAdd between the conv and the
  // BatchToSpaceND node.
  const NodeDef& batch_norm_node = match.node;
  const NodeMatch& batch_to_space_node_match = match.inputs[0];
  const NodeMatch& conv_node_match = batch_to_space_node_match.inputs[0];
  const NodeDef& batch_to_space_node = batch_to_space_node_match.node;
  const NodeDef& conv_node = conv_node_match.node;

  string biasadd_name = conv_node.name() + "/biasadd";
  TF_RETURN_IF_ERROR(FuseScaleOffsetToConvWeights(
      scale_values, offset_values, conv_node_match, biasadd_name, new_nodes));

  NodeDef new_batch_to_space_node = batch_to_space_node;
  // Reuse the batch norm node's name for the final output.
  new_batch_to_space_node.set_name(batch_norm_node.name());
  new_batch_to_space_node.set_input(0, biasadd_name);
  new_nodes->push_back(batch_to_space_node_match.inputs[1].node);
  new_nodes->push_back(batch_to_space_node_match.inputs[2].node);
  new_nodes->push_back(new_batch_to_space_node);
  return OkStatus();
}
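// Placing the BiasAdd before BatchToSpaceND is safe because BatchToSpaceND
// only rearranges the batch and spatial dimensions; it never touches the
// channel dimension, so a per-channel bias commutes with it.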

Status FuseBatchNormWithConvConcat(const NodeMatch& match,
                                   std::vector<NodeDef>* new_nodes) {
  // Calculate the scale and offset values to apply.
  std::vector<float> scale_values;
  std::vector<float> offset_values;
  TF_RETURN_IF_ERROR(
      GetScaleAndOffsetValues(match, &scale_values, &offset_values));

  // Find all the nodes we expect in the subgraph.
  const NodeDef& batch_norm_node = match.node;
  const NodeMatch& concat_node_match = match.inputs[0];
  NodeDef concat_node = concat_node_match.node;
  // Note: only ConcatV2 is handled here; for the older Concat op the axis is
  // the first input, not the third.
  CHECK_EQ("ConcatV2", concat_node.op());

  // First process the axis.
  NodeDef axis_node = concat_node_match.inputs[2].node;
  CHECK_EQ("Const", axis_node.op());
  Tensor axis = GetNodeTensorAttr(axis_node, "value");
  int32_t axis_scalar = (axis.scalar<int32>())();

  // By default, both conv0 and conv1 share the same scale and offset values.
  std::vector<float> scale0(scale_values);
  std::vector<float> offset0(offset_values);
  std::vector<float> scale1(scale_values);
  std::vector<float> offset1(offset_values);
  if (axis_scalar == 3) {
    // If the concat axis is 3 (channels), split the scale and offset into two
    // halves, one per convolution.
    const NodeDef& weights0_node = concat_node_match.inputs[0].inputs[1].node;
    Tensor weights0 = GetNodeTensorAttr(weights0_node, "value");
    const int64_t split_cols = weights0.shape().dim_size(3);
    // Only keep the first half for scale0/offset0.
    scale0.erase(scale0.begin() + split_cols, scale0.end());
    offset0.erase(offset0.begin() + split_cols, offset0.end());
    // Only keep the second half for scale1/offset1.
    scale1.erase(scale1.begin(), scale1.begin() + split_cols);
    offset1.erase(offset1.begin(), offset1.begin() + split_cols);
  }
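  // Rationale: with axis == 3 the convolutions are concatenated along
  // channels, so each conv only sees its own slice of the per-channel
  // scale/offset vectors. For any other axis, each branch's channel count
  // matches the full vectors, so both branches share them.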

  // Fuse the weights for input0 of the concat.
  const string concat0_output_name = concat_node.name() + "_bn_in0";
  TF_RETURN_IF_ERROR(
      FuseScaleOffsetToConvWeights(scale0, offset0, concat_node_match.inputs[0],
                                   concat0_output_name, new_nodes));

  // Fuse the weights for input1 of the concat.
  const string concat1_output_name = concat_node.name() + "_bn_in1";
  TF_RETURN_IF_ERROR(
      FuseScaleOffsetToConvWeights(scale1, offset1, concat_node_match.inputs[1],
                                   concat1_output_name, new_nodes));

  // Push the axis node.
  new_nodes->push_back(concat_node_match.inputs[2].node);

  // Give the final output op the batch norm node's name.
  concat_node.set_name(batch_norm_node.name());
  concat_node.set_input(0, concat0_output_name);
  concat_node.set_input(1, concat1_output_name);
  new_nodes->push_back(concat_node);
  return OkStatus();
}
}  // namespace

// Finds monolithic batch norm ops (as used in early versions of TensorFlow)
// and converts them into premultiplied weight inputs to convolutions.
Status FoldOldBatchNorms(const GraphDef& input_graph_def,
                         const TransformFuncContext& context,
                         GraphDef* output_graph_def) {
  GraphDef current_graph_def = input_graph_def;
  // We have to do several passes to catch all the old BN nodes, since many of
  // them may share inputs and so be excluded from replacement in one pass.
  bool did_graph_change;
  do {
    did_graph_change = false;
    GraphDef replaced_graph_def;
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def,  // clang-format off
      {"BatchNormWithGlobalNormalization|FusedBatchNorm",    // batch_norm_node
        {
          {"Conv2D|DepthwiseConv2dNative",    // conv_node
            {
              {"*"},                          // input_node
              {"Const"},                      // weights_node
            }
          },
          {"Const"},                          // mean_node
          {"Const"},                          // variance_node
          {"Const"},                          // beta_node
          {"Const"},                          // gamma_node
        }
      },  // clang-format on
        [&did_graph_change](const NodeMatch& match,
                            const std::set<string>& input_nodes,
                            const std::set<string>& output_nodes,
                            std::vector<NodeDef>* new_nodes) {
          TF_RETURN_IF_ERROR(FuseBatchNormWithConv(match, new_nodes));
          did_graph_change = true;
          return OkStatus();
        },
        {}, &replaced_graph_def));
    current_graph_def = replaced_graph_def;
  } while (did_graph_change);

  do {
    did_graph_change = false;
    GraphDef replaced_graph_def;
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def,  // clang-format off
        {"BatchNormWithGlobalNormalization|FusedBatchNorm",    // batch_norm_node
         {
             {"BatchToSpaceND",                  // batch_to_space_node
              {
                  {"Conv2D|DepthwiseConv2dNative",    // conv_node
                   {
                       {"*"},                    // input_node
                       {"Const"},                // weights_node
                   }
                  },
                  {"Const"},                     // block_shape
                  {"Const"},                     // crops
              }
             },
             {"Const"},                          // mean_node
             {"Const"},                          // variance_node
             {"Const"},                          // beta_node
             {"Const"},                          // gamma_node
         }
        },  // clang-format on
        [&did_graph_change](const NodeMatch& match,
                            const std::set<string>& input_nodes,
                            const std::set<string>& output_nodes,
                            std::vector<NodeDef>* new_nodes) {
          TF_RETURN_IF_ERROR(FuseBatchNormWithBatchToSpace(match, new_nodes));
          did_graph_change = true;
          return OkStatus();
        },
        {}, &replaced_graph_def));
    current_graph_def = replaced_graph_def;
  } while (did_graph_change);

  do {
    did_graph_change = false;
    GraphDef replaced_graph_def;
    // Replace BatchNorm nodes whose input is a concat of two convolutions.
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def,  // clang-format off
      {"BatchNormWithGlobalNormalization|FusedBatchNorm",    // batch_norm_node
        {
          {"ConcatV2|Concat",                   // concat of two convolutions
            {
              {"Conv2D|DepthwiseConv2dNative",  // conv_node
                {
                  {"*"},                        // input_node
                  {"Const"},                    // weights_node
                }
              },
              {"Conv2D|DepthwiseConv2dNative",  // conv_node
                {
                  {"*"},                        // input_node
                  {"Const"},                    // weights_node
                }
              },
              {"Const"},                        // axis
            },
          },
          {"Const"},                            // mean_node
          {"Const"},                            // variance_node
          {"Const"},                            // beta_node
          {"Const"},                            // gamma_node
        }
      },  // clang-format on
        [&did_graph_change](const NodeMatch& match,
                            const std::set<string>& input_nodes,
                            const std::set<string>& output_nodes,
                            std::vector<NodeDef>* new_nodes) {
          TF_RETURN_IF_ERROR(FuseBatchNormWithConvConcat(match, new_nodes));
          did_graph_change = true;
          return OkStatus();
        },
        {}, &replaced_graph_def));
    current_graph_def = replaced_graph_def;
  } while (did_graph_change);

  *output_graph_def = current_graph_def;
  return OkStatus();
}

REGISTER_GRAPH_TRANSFORM("fold_old_batch_norms", FoldOldBatchNorms);
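
// Example invocation via the Graph Transform Tool (file and node names are
// illustrative):
//
//   bazel run tensorflow/tools/graph_transforms:transform_graph -- \
//     --in_graph=frozen_graph.pb \
//     --out_graph=folded_graph.pb \
//     --inputs=input --outputs=output \
//     --transforms="fold_old_batch_norms"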

}  // namespace graph_transforms
}  // namespace tensorflow