/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/toco/tooling_util.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "re2/re2.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/dump_graphviz.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"

namespace toco {

// Find the longest common prefix of two strings.
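// For example (hypothetical names), the longest common prefix of
// "conv/weights" and "conv/bias" is "conv/"; the result is a view into the
// first argument's storage.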
absl::string_view FindLongestCommonPrefix(absl::string_view a,
                                          absl::string_view b) {
  if (a.empty() || b.empty()) return absl::string_view();

  const char* pa = a.data();
  const char* pb = b.data();
  size_t count = 0;
  const size_t limit = std::min(a.size(), b.size());
  while (count < limit && *pa == *pb) {
    ++pa;
    ++pb;
    ++count;
  }

  return absl::string_view(a.data(), count);
}

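// Returns a human-readable label for an operator, used in log messages. For
// example (hypothetical output name), a Conv operator whose first output is
// "conv1/out" is logged as "{Conv operator with output conv1/out}".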
std::string LogName(const Operator& op) {
  const std::string& opname = HelpfulOperatorTypeName(op);
  if (op.outputs.empty()) {
    return toco::port::StringF("{%s operator}", opname);
  } else {
    return toco::port::StringF("{%s operator with output %s}", opname,
                               op.outputs[0]);
  }
}

std::string ArrayDataTypeName(ArrayDataType data_type) {
  switch (data_type) {
    case ArrayDataType::kFloat:
      return "float";
    case ArrayDataType::kInt8:
      return "int8";
    case ArrayDataType::kUint8:
      return "uint8";
    case ArrayDataType::kInt16:
      return "int16";
    case ArrayDataType::kUint16:
      return "uint16";
    case ArrayDataType::kInt32:
      return "int32";
    case ArrayDataType::kUint32:
      return "uint32";
    case ArrayDataType::kInt64:
      return "int64";
    case ArrayDataType::kUint64:
      return "uint64";
    case ArrayDataType::kString:
      return "string";
    case ArrayDataType::kBool:
      return "bool";
    case ArrayDataType::kComplex64:
      return "complex64";
    case ArrayDataType::kNone:
      return "None";
    default:
      LOG(FATAL) << "Unhandled array data type " << static_cast<int>(data_type);
  }
}

bool IsInputArray(const Model& model, const std::string& array_name) {
  for (const auto& input_array : model.flags.input_arrays()) {
    if (array_name == input_array.name()) {
      return true;
    }
  }
  return false;
}

bool IsOutputArray(const Model& model, const std::string& array_name) {
  for (const auto& output_array : model.flags.output_arrays()) {
    if (array_name == output_array) {
      return true;
    }
  }
  return false;
}

bool IsArrayConsumed(const Model& model, const std::string& name) {
  if (GetOpWithInput(model, name)) {
    return true;
  }
  if (IsOutputArray(model, name)) {
    return true;
  }
  for (const auto& rnn_state : model.flags.rnn_states()) {
    if (rnn_state.back_edge_source_array() == name) {
      return true;
    }
  }
  return false;
}

int CountTrueOutputs(const Model& model, const Operator& op) {
  int count = 0;
  for (const std::string& output : op.outputs) {
    if (IsArrayConsumed(model, output)) {
      ++count;
    }
  }
  return count;
}

int CountOpsWithInput(const Model& model, const std::string& array_name) {
  int count = 0;
  for (const auto& op : model.operators) {
    for (auto& input : op->inputs) {
      if (input == array_name) {
        count++;
        // Breaking here is important: some graphs have ops that use the
        // same array as more than one of their inputs, and in that case
        // we want it counted only once.
        break;
      }
    }
  }
  return count;
}

bool DeleteArrayIfUnused(const std::string& array_name, Model* model) {
  if (IsDiscardableArray(*model, array_name) &&
      CountOpsWithInput(*model, array_name) == 0 &&
      GetOpWithOutput(*model, array_name) == nullptr) {
    model->EraseArray(array_name);
    return true;
  }
  return false;
}

bool DeleteArrayIfUnusedOutsideOfOp(const std::string& array_name,
                                    const Operator* op, Model* model) {
  if (!IsDiscardableArray(*model, array_name)) {
    return false;
  }
  if (CountOpsWithInput(*model, array_name) > 1) {
    return false;
  }
  const Operator* op_having_this_as_input = GetOpWithInput(*model, array_name);
  if (op_having_this_as_input && op_having_this_as_input != op) {
    return false;
  }
  const Operator* op_having_this_as_output =
      GetOpWithOutput(*model, array_name);
  if (op_having_this_as_output && op_having_this_as_output != op) {
    return false;
  }
  model->EraseArray(array_name);
  return true;
}

void DeleteOpAndArrays(Model* model, const Operator* op) {
  for (const std::string& array_name : op->inputs) {
    DeleteArrayIfUnusedOutsideOfOp(array_name, op, model);
  }
  for (const std::string& array_name : op->outputs) {
    DeleteArrayIfUnusedOutsideOfOp(array_name, op, model);
  }
  auto op_it = FindOp(*model, op);
  CHECK(op_it != model->operators.end());
  model->operators.erase(op_it);
}

std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithOutput(
    const Model& model, const std::string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    for (auto& output : it->get()->outputs) {
      if (output == array_name) {
        return it;
      }
    }
  }
  return model.operators.end();
}

std::vector<std::unique_ptr<Operator>>::iterator FindOpWithOutput(
    Model& model, const std::string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    for (auto& output : it->get()->outputs) {
      if (output == array_name) {
        return it;
      }
    }
  }
  return model.operators.end();
}

Operator* GetOpWithOutput(const Model& model, const std::string& array_name) {
  auto it = FindOpWithOutput(model, array_name);
  return it == model.operators.end() ? nullptr : it->get();
}

// GetFirstOpWithInput assumes that this finds the first op.
std::vector<std::unique_ptr<Operator>>::const_iterator FindOpWithInput(
    const Model& model, const std::string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    for (auto& input : it->get()->inputs) {
      if (input == array_name) {
        return it;
      }
    }
  }
  return model.operators.end();
}

std::vector<std::unique_ptr<Operator>>::iterator FindOpWithInput(
    Model& model, const std::string& array_name) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    for (auto& input : it->get()->inputs) {
      if (input == array_name) {
        return it;
      }
    }
  }
  return model.operators.end();
}

std::vector<std::unique_ptr<Operator>>::const_iterator FindOp(
    const Model& model, const Operator* op) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    if (it->get() == op) {
      return it;
    }
  }
  return model.operators.end();
}

std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model,
                                                        const Operator* op) {
  for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
    if (it->get() == op) {
      return it;
    }
  }
  return model.operators.end();
}

Operator* GetOpWithInput(const Model& model, const std::string& array_name) {
  auto it = FindOpWithInput(model, array_name);
  return it == model.operators.end() ? nullptr : it->get();
}

Operator* GetFirstOpWithInput(const Model& model,
                              const std::string& array_name) {
  auto it = FindOpWithInput(model, array_name);
  return it == model.operators.end() ? nullptr : it->get();
}

void ReplaceArrayUsage(Model* model, const std::string& old_array_name,
                       const std::string& new_array_name) {
  for (auto& op_it : model->operators) {
    Operator* op = op_it.get();
    for (size_t i = 0; i < op->inputs.size(); ++i) {
      if (op->inputs[i] == old_array_name) {
        op->inputs[i] = new_array_name;
      }
    }
    for (size_t i = 0; i < op->outputs.size(); ++i) {
      if (op->outputs[i] == old_array_name) {
        op->outputs[i] = new_array_name;
      }
    }
  }
}

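// Formats a list of array names for logging. For example (hypothetical
// names), {"a"} formats as "a" and {"a", "b"} formats as "[ a, b ]".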
std::string FormatArraysList(const Model& model,
                             const std::vector<std::string>& list) {
  if (list.empty()) {
    return "[]";
  }
  std::string result = "";
  if (list.size() > 1) {
    result += "[ ";
  }
  for (std::size_t i = 0; i < list.size(); i++) {
    if (i > 0) {
      result += ", ";
    }
    result += list[i];
  }
  if (list.size() > 1) {
    result += " ]";
  }
  return result;
}

const char* OperatorTypeName(OperatorType type) {
  switch (type) {
#define HANDLE_OPERATORTYPENAME_CASE(c) \
  case OperatorType::k##c:              \
    return #c;
    HANDLE_OPERATORTYPENAME_CASE(Abs)
    HANDLE_OPERATORTYPENAME_CASE(Add)
    HANDLE_OPERATORTYPENAME_CASE(AddN)
    HANDLE_OPERATORTYPENAME_CASE(AveragePool)
    HANDLE_OPERATORTYPENAME_CASE(BatchMatMul)
    HANDLE_OPERATORTYPENAME_CASE(BatchNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Conv)
    HANDLE_OPERATORTYPENAME_CASE(Concatenation)
    HANDLE_OPERATORTYPENAME_CASE(DepthwiseConv)
    HANDLE_OPERATORTYPENAME_CASE(DepthToSpace)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToDepth)
    HANDLE_OPERATORTYPENAME_CASE(FullyConnected)
    HANDLE_OPERATORTYPENAME_CASE(HardSwish)
    HANDLE_OPERATORTYPENAME_CASE(Dequantize)
    HANDLE_OPERATORTYPENAME_CASE(L2Normalization)
    HANDLE_OPERATORTYPENAME_CASE(LocalResponseNormalization)
    HANDLE_OPERATORTYPENAME_CASE(Log)
    HANDLE_OPERATORTYPENAME_CASE(Logistic)
    HANDLE_OPERATORTYPENAME_CASE(LstmCell)
    HANDLE_OPERATORTYPENAME_CASE(MaxPool)
    HANDLE_OPERATORTYPENAME_CASE(L2Pool)
    HANDLE_OPERATORTYPENAME_CASE(FakeQuant)
    HANDLE_OPERATORTYPENAME_CASE(Mul)
    HANDLE_OPERATORTYPENAME_CASE(RandomUniform)
    HANDLE_OPERATORTYPENAME_CASE(Elu)
    HANDLE_OPERATORTYPENAME_CASE(Relu)
    HANDLE_OPERATORTYPENAME_CASE(Relu1)
    HANDLE_OPERATORTYPENAME_CASE(Relu6)
    HANDLE_OPERATORTYPENAME_CASE(PRelu)
    HANDLE_OPERATORTYPENAME_CASE(ReorderAxes)
    HANDLE_OPERATORTYPENAME_CASE(Softmax)
    HANDLE_OPERATORTYPENAME_CASE(LogSoftmax)
    HANDLE_OPERATORTYPENAME_CASE(Div)
    HANDLE_OPERATORTYPENAME_CASE(Tanh)
    HANDLE_OPERATORTYPENAME_CASE(Sin)
    HANDLE_OPERATORTYPENAME_CASE(All)
    HANDLE_OPERATORTYPENAME_CASE(Assert)
    HANDLE_OPERATORTYPENAME_CASE(ExpandDims)
    HANDLE_OPERATORTYPENAME_CASE(Fill)
    HANDLE_OPERATORTYPENAME_CASE(FloorMod)
    HANDLE_OPERATORTYPENAME_CASE(FloorDiv)
    HANDLE_OPERATORTYPENAME_CASE(Greater)
    HANDLE_OPERATORTYPENAME_CASE(GreaterEqual)
    HANDLE_OPERATORTYPENAME_CASE(Identity)
    HANDLE_OPERATORTYPENAME_CASE(Less)
    HANDLE_OPERATORTYPENAME_CASE(LessEqual)
    HANDLE_OPERATORTYPENAME_CASE(MatMul)
    HANDLE_OPERATORTYPENAME_CASE(ReduceMax)  //  Reduction Max
    HANDLE_OPERATORTYPENAME_CASE(Maximum)    //  Element-wise Maximum
    HANDLE_OPERATORTYPENAME_CASE(Merge)
    HANDLE_OPERATORTYPENAME_CASE(ReduceMin)  //  Reduction Min
    HANDLE_OPERATORTYPENAME_CASE(Minimum)    //  Element-wise Minimum
    HANDLE_OPERATORTYPENAME_CASE(Neg)
    HANDLE_OPERATORTYPENAME_CASE(OneHot)
    HANDLE_OPERATORTYPENAME_CASE(Pack)
    HANDLE_OPERATORTYPENAME_CASE(Pad)
    HANDLE_OPERATORTYPENAME_CASE(PadV2)
    HANDLE_OPERATORTYPENAME_CASE(StridedSlice)
    HANDLE_OPERATORTYPENAME_CASE(Range)
    HANDLE_OPERATORTYPENAME_CASE(Rank)
    HANDLE_OPERATORTYPENAME_CASE(Reshape)
    HANDLE_OPERATORTYPENAME_CASE(Squeeze)
    HANDLE_OPERATORTYPENAME_CASE(Rsqrt)
    HANDLE_OPERATORTYPENAME_CASE(SegmentSum)
    HANDLE_OPERATORTYPENAME_CASE(Shape)
    HANDLE_OPERATORTYPENAME_CASE(Slice)
    HANDLE_OPERATORTYPENAME_CASE(Split)
    HANDLE_OPERATORTYPENAME_CASE(SplitV)
    HANDLE_OPERATORTYPENAME_CASE(Sqrt)
    HANDLE_OPERATORTYPENAME_CASE(Square)
    HANDLE_OPERATORTYPENAME_CASE(Switch)
    HANDLE_OPERATORTYPENAME_CASE(Sub)
    HANDLE_OPERATORTYPENAME_CASE(Sum)
    HANDLE_OPERATORTYPENAME_CASE(Tile)
    HANDLE_OPERATORTYPENAME_CASE(Transpose)
    HANDLE_OPERATORTYPENAME_CASE(TransposeConv)
    HANDLE_OPERATORTYPENAME_CASE(Concat)
    HANDLE_OPERATORTYPENAME_CASE(ConcatV2)
    HANDLE_OPERATORTYPENAME_CASE(Cast)
    HANDLE_OPERATORTYPENAME_CASE(Floor)
    HANDLE_OPERATORTYPENAME_CASE(Ceil)
    HANDLE_OPERATORTYPENAME_CASE(Round)
    HANDLE_OPERATORTYPENAME_CASE(Gather)
    HANDLE_OPERATORTYPENAME_CASE(GatherNd)
    HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear)
    HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND)
    HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND)
    HANDLE_OPERATORTYPENAME_CASE(Mean)
    HANDLE_OPERATORTYPENAME_CASE(ReduceProd)
    HANDLE_OPERATORTYPENAME_CASE(Svdf)
    HANDLE_OPERATORTYPENAME_CASE(ArgMax)
    HANDLE_OPERATORTYPENAME_CASE(ArgMin)
    HANDLE_OPERATORTYPENAME_CASE(TopK_V2)
    HANDLE_OPERATORTYPENAME_CASE(Unsupported)
    HANDLE_OPERATORTYPENAME_CASE(Exp)
    HANDLE_OPERATORTYPENAME_CASE(DynamicPartition)
    HANDLE_OPERATORTYPENAME_CASE(DynamicStitch)
    HANDLE_OPERATORTYPENAME_CASE(Select)
    HANDLE_OPERATORTYPENAME_CASE(SparseToDense)
    HANDLE_OPERATORTYPENAME_CASE(Equal)
    HANDLE_OPERATORTYPENAME_CASE(NotEqual)
    HANDLE_OPERATORTYPENAME_CASE(Pow)
    HANDLE_OPERATORTYPENAME_CASE(Any)
    HANDLE_OPERATORTYPENAME_CASE(LogicalAnd)
    HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
    HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
    HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
    HANDLE_OPERATORTYPENAME_CASE(Unpack)
    HANDLE_OPERATORTYPENAME_CASE(ZerosLike)
    HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceLstm)
    HANDLE_OPERATORTYPENAME_CASE(BidirectionalSequenceLstm)
    HANDLE_OPERATORTYPENAME_CASE(BidirectionalSequenceRnn)
    HANDLE_OPERATORTYPENAME_CASE(ResizeNearestNeighbor)
    HANDLE_OPERATORTYPENAME_CASE(LeakyRelu)
    HANDLE_OPERATORTYPENAME_CASE(SquaredDifference)
    HANDLE_OPERATORTYPENAME_CASE(MirrorPad)
    HANDLE_OPERATORTYPENAME_CASE(Unique)
    HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn)
    HANDLE_OPERATORTYPENAME_CASE(ReverseV2)
    HANDLE_OPERATORTYPENAME_CASE(Cos)
    HANDLE_OPERATORTYPENAME_CASE(Where)
    HANDLE_OPERATORTYPENAME_CASE(ReverseSequence)
    HANDLE_OPERATORTYPENAME_CASE(MatrixDiag)
    HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiag)
    HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV2)
    HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV2)
    HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV3)
    HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV3)
    HANDLE_OPERATORTYPENAME_CASE(ScatterNd)
    default:
      LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
  }
}

std::string HelpfulOperatorTypeName(const Operator& op) {
  if (op.type == OperatorType::kUnsupported) {
    return toco::port::StringF(
        "(Unsupported TensorFlow op: %s)",
        static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op);
  }
  return OperatorTypeName(op.type);
}

bool OperatorSupportsFusedActivation(OperatorType type) {
  switch (type) {
    case OperatorType::kAdd:
    case OperatorType::kAveragePool:
    case OperatorType::kBatchNormalization:
    case OperatorType::kConv:
    case OperatorType::kDepthwiseConv:
    case OperatorType::kDiv:
    case OperatorType::kFullyConnected:
    case OperatorType::kL2Pool:
    case OperatorType::kMaxPool:
    case OperatorType::kMul:
    case OperatorType::kSub:
    case OperatorType::kSquaredDifference:
      return true;
    default:
      return false;
  }
}

void LogSummary(int log_level, const Model& model) {
  VLOG(log_level) << "Operators summary (" << model.operators.size()
                  << " operators):";
  std::unordered_multiset<OperatorType> ops_by_type;
  for (const auto& op : model.operators) {
    ops_by_type.insert(op->type);
  }
  auto it = ops_by_type.begin();
  while (it != ops_by_type.end()) {
    int count = ops_by_type.count(*it);
    VLOG(log_level) << "    " << OperatorTypeName(*it) << ": " << count;
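    // Equivalent keys are adjacent in an unordered_multiset, so advancing by
    // count moves the iterator to the first operator of the next type.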
    std::advance(it, count);
  }
}

void LogArray(int log_level, const Model& model, const std::string& name) {
  VLOG(log_level) << "Array: " << name;
  if (!model.HasArray(name)) {
    VLOG(log_level) << "  DOES NOT EXIST";
    return;
  }
  const auto& array = model.GetArray(name);
  VLOG(log_level) << "  Data type: " << ArrayDataTypeName(array.data_type);
  VLOG(log_level) << "  Final type: "
                  << ArrayDataTypeName(array.final_data_type);
  if (array.buffer) {
    VLOG(log_level) << "  Constant Buffer";
  }
  if (array.alloc) {
    VLOG(log_level) << "  Transient Alloc";
  }
  if (array.has_shape()) {
    const Shape& array_shape = array.shape();
    if (array_shape.dimensions_count() == 0) {
      VLOG(log_level) << "  (Zero dimensions)";
    } else {
      std::string message = "  Dims: ";
      bool first = true;
      for (const int dim : array_shape.dims()) {
        if (!first) {
          message += ", ";
        }
        first = false;
        toco::port::AppendF(&message, "%d", dim);
      }
      VLOG(log_level) << message;
    }
  }
  if (array.minmax) {
    VLOG(log_level) << "  MinMax: " << array.minmax->min << " .. "
                    << array.minmax->max;
  }
  if (array.quantization_params) {
    VLOG(log_level) << "  QuantizationParams: zero_point="
                    << static_cast<int>(array.quantization_params->zero_point)
                    << ", scale=" << array.quantization_params->scale;
  }
}

void DumpGraphvizVideoFrame(const Model& model) {
  namespace port = toco::port;

  const auto& dump_options = *GraphVizDumpOptions::singleton();
  if (!dump_options.dump_graphviz_video) {
    return;
  }
  CHECK(!dump_options.dump_graphviz.empty());
  // TODO(benoitjacob): the static data here means that this function
  // is stateful, not reentrant, and effectively leaks memory till exit
  // (since dump_hashes can only grow in size). It also means that it
  // really only is intended to be called for a single model during the
  // process' lifetime. So it's not great design at all. The overriding
  // design aspect here is to make the video-dumping code as unintrusive
  // and self-contained as possible. Eventually, we'll want to have that
  // cleaned-up, but that will require some form of general statefulness
  // in toco (some kind of 'tooling state' data structure) that does
  // not exist at present, and would be premature to design here just for
  // this new video-dumping feature.
  static int dump_id = 0;
  static std::unordered_set<std::size_t> dump_hashes;
  std::string graphviz_dump;
  DumpGraphviz(model, &graphviz_dump,
               toco::port::StringF("VIDEO frame:%05d", dump_id));
  std::size_t hash = std::hash<std::string>{}(graphviz_dump);
  if (!dump_hashes.count(hash)) {
    LOG(INFO) << "DUMPING GRAPHVIZ VIDEO FRAME: " << dump_id;
    dump_hashes.insert(hash);
    const auto result = port::file::SetContents(
        port::file::JoinPath(
            dump_options.dump_graphviz,
            toco::port::StringF("toco_video_%05d.dot", dump_id)),
        graphviz_dump, port::file::Defaults());
    QCHECK(result.ok()) << result.error_message();
    dump_id++;
  }
}

void LogDump(int log_level, const std::string& message, const Model& model) {
  namespace port = toco::port;
  const auto& dump_options = *GraphVizDumpOptions::singleton();

  DumpGraphvizVideoFrame(model);
  if (!dump_options.dump_graphviz.empty()) {
    std::string graphviz_dump;

    DumpGraphviz(model, &graphviz_dump, message);
    const auto result = port::file::SetContents(
        port::file::JoinPath(
            dump_options.dump_graphviz,
            absl::StrCat("toco_", absl::StrReplaceAll(message, {{" ", "_"}}),
                         ".dot")),
        graphviz_dump, port::file::Defaults());
    QCHECK(result.ok()) << result.error_message();
  }

  if (!VLOG_IS_ON(log_level)) {
    return;
  }
  VLOG(log_level) << "BEGIN DUMP OF TOCO MODEL (" << message << ")";
  LogSummary(log_level, model);
  std::unordered_set<std::string> already_printed_arrays;
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      if (!already_printed_arrays.count(input)) {
        already_printed_arrays.insert(input);
        LogArray(log_level, model, input);
      }
    }
    VLOG(log_level) << HelpfulOperatorTypeName(*op) << " :";
    VLOG(log_level) << "  " << FormatArraysList(model, op->inputs) << " -> "
                    << FormatArraysList(model, op->outputs);
    if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
      VLOG(log_level) << "    (with fused activation function)";
    }
    for (const auto& output : op->outputs) {
      if (!already_printed_arrays.count(output)) {
        already_printed_arrays.insert(output);
        LogArray(log_level, model, output);
      }
    }
  }
  VLOG(log_level) << "END DUMP OF TOCO MODEL (" << message << ")";
}

// Note remaining raw-array extension in ProcessTensorFlowReshapeOperator().
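// For example, extending a [3, 4] shape to new_shape_size 4 yields
// [1, 1, 3, 4]: 1-dimensions are inserted at the front.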
void ExtendShape(Shape* shape, int new_shape_size) {
  CHECK_GE(new_shape_size, shape->dimensions_count());
  const int size_increase = new_shape_size - shape->dimensions_count();
  auto* shape_dims = shape->mutable_dims();
  shape_dims->insert(shape_dims->begin(), size_increase, 1);
}

// TODO(b/62904716) Remove along with remaining uses.
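// For example, unextending [1, 1, 3, 4] to new_shape_size 2 yields [3, 4];
// the leading dimensions being dropped must all be 1, otherwise this
// CHECK-fails.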
void UnextendShape(Shape* shape, int new_shape_size) {
  CHECK_LE(new_shape_size, shape->dimensions_count());
  const int size_reduction = shape->dimensions_count() - new_shape_size;
  for (int i = 0; i < size_reduction; i++) {
    CHECK_EQ(shape->dims(i), 1);
  }
  std::vector<int>& shape_dims = *shape->mutable_dims();
  shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction);
}

// In general, zero-sized dimensions are disallowed, but there are exceptions,
// e.g., if the tensor data itself represents a scalar (rank 0) shape, its
// shape will have dimensions [0]. CheckNonEmptyShapeDimensions is more
// strict, and is appropriate for ops and comparisons where an empty shape
// doesn't make sense.
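// For example, dims of [0] are accepted (scalar payload), [2, 3] is accepted,
// and [2, 0] triggers a CHECK failure.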
template <typename Dims>
void CheckValidShapeDimensions(const Dims& dims) {
  if (dims.size() == 1 && dims[0] == 0) {
    return;
  }
  for (const auto& dim : dims) {
    CHECK_GE(dim, 1);
  }
}

void CheckValidShape(const Shape& shape) {
  CheckValidShapeDimensions(shape.dims());
}

bool IsNonEmpty(const Shape& shape) {
  for (int i = 0; i < shape.dimensions_count(); ++i) {
    if (shape.dims(i) < 1) return false;
  }
  return true;
}

void CheckNonEmptyShapeDimensions(const Shape& shape) {
  for (int i = 0; i < shape.dimensions_count(); ++i) {
    CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index " << i
                                 << ". shape = " << ShapeToString(shape);
  }
}

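// For example, shapes [8, 1, 3] and [1, 3] agree up to broadcasting: walking
// from the back, each pair of dimensions either matches or has a 1. Shapes
// [2, 3] and [4, 3] do not, since 2 vs 4 differ and neither is 1.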
bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) {
  CheckNonEmptyShapeDimensions(shape0);
  CheckNonEmptyShapeDimensions(shape1);

  const Shape* longer = &shape0;
  const Shape* shorter = &shape1;
  if (shape1.dimensions_count() > shape0.dimensions_count()) {
    longer = &shape1;
    shorter = &shape0;
  }

  // Walk dimensions back to front until we run out of dimensions in the shorter
  // shape.
  int longer_index = longer->dimensions_count() - 1;
  int shorter_index = shorter->dimensions_count() - 1;
  while (shorter_index >= 0) {
    const int d_long = longer->dims(longer_index);
    const int d_short = shorter->dims(shorter_index);
    // Broadcasting fails if the dimensions are different *and* neither is 1.
    if ((d_long != d_short) && (d_long != 1) && (d_short != 1)) {
      return false;
    }
    longer_index--;
    shorter_index--;
  }
  return true;
}

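// For example, [1, 1, 2, 3] and [2, 3] agree up to extending, but [4, 2, 3]
// and [2, 3] do not, because the extra leading dimension is 4 rather than 1.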
bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) {
  CheckNonEmptyShapeDimensions(shape0);
  CheckNonEmptyShapeDimensions(shape1);

  const Shape* longer = &shape0;
  const Shape* shorter = &shape1;
  if (shape1.dimensions_count() > shape0.dimensions_count()) {
    longer = &shape1;
    shorter = &shape0;
  }

  // Walk dimensions back to front until we run out of dimensions in the shorter
  // shape.
  int longer_index = longer->dimensions_count() - 1;
  int shorter_index = shorter->dimensions_count() - 1;
  while (shorter_index >= 0) {
    const int d_long = longer->dims(longer_index);
    const int d_short = shorter->dims(shorter_index);
    // Extending fails if the dimensions are different.
    if (d_long != d_short) {
      return false;
    }
    longer_index--;
    shorter_index--;
  }

  // The remaining dimensions in the longer shape must be 1.
  while (longer_index >= 0) {
    const int d_long = longer->dims(longer_index);
    if (d_long != 1) {
      return false;
    }
    longer_index--;
  }

  return true;
}

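// For example, a shape of [2, 3, 4] requires a buffer of 2 * 3 * 4 = 24
// elements.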
int RequiredBufferSizeForShape(const Shape& shape) {
  CheckValidShape(shape);
  int max_offset = 1;
  for (const auto& dim : shape.dims()) {
    max_offset *= dim;
  }
  return max_offset;
}

bool IsConstantParameterArray(const Model& model, const std::string& name) {
  if (!model.HasArray(name)) {
    return false;
  }

  return !!model.GetArray(name).buffer;
}

namespace {
template <ArrayDataType A>
bool CompareArrayBuffers(const Array& lhs_array, const Array& rhs_array) {
  CHECK(lhs_array.data_type == rhs_array.data_type) << "Data types must match";
  CHECK(lhs_array.buffer) << "LHS must be constant";
  CHECK(rhs_array.buffer) << "RHS must be constant";
  const auto& lhs_data = lhs_array.GetBuffer<A>().data;
  const auto& rhs_data = rhs_array.GetBuffer<A>().data;
  CHECK_EQ(lhs_data.size(), rhs_data.size())
      << "Buffer sizes must match in element count";
  for (int i = 0; i < lhs_data.size(); ++i) {
    if (lhs_data[i] != rhs_data[i]) {
      return false;
    }
  }
  return true;
}

bool HaveSameMinMax(const Array& lhs_array, const Array& rhs_array) {
  if (lhs_array.minmax || rhs_array.minmax) {
    if (!lhs_array.minmax || !rhs_array.minmax) {
      return false;
    }
    if (!(*lhs_array.minmax == *rhs_array.minmax)) {
      return false;
    }
  }
  return true;
}

bool HaveSameQuantizationParams(const Array& lhs_array,
                                const Array& rhs_array) {
  if (lhs_array.quantization_params || rhs_array.quantization_params) {
    if (!lhs_array.quantization_params || !rhs_array.quantization_params) {
      return false;
    }
    if (!(*lhs_array.quantization_params == *rhs_array.quantization_params)) {
      return false;
    }
  }
  return true;
}

}  // namespace

bool CompareConstantArrays(const Array& lhs_array, const Array& rhs_array) {
  bool attrs_equal = lhs_array.shape() == rhs_array.shape() &&
                     lhs_array.data_type == rhs_array.data_type &&
                     lhs_array.final_data_type == rhs_array.final_data_type &&
                     HaveSameMinMax(lhs_array, rhs_array) &&
                     HaveSameQuantizationParams(lhs_array, rhs_array) &&
                     lhs_array.narrow_range == rhs_array.narrow_range;
  if (!attrs_equal) {
    return false;
  }
  switch (lhs_array.data_type) {
    case ArrayDataType::kBool:
      return CompareArrayBuffers<ArrayDataType::kBool>(lhs_array, rhs_array);
    case ArrayDataType::kFloat:
      return CompareArrayBuffers<ArrayDataType::kFloat>(lhs_array, rhs_array);
    case ArrayDataType::kInt8:
      return CompareArrayBuffers<ArrayDataType::kInt8>(lhs_array, rhs_array);
    case ArrayDataType::kUint8:
      return CompareArrayBuffers<ArrayDataType::kUint8>(lhs_array, rhs_array);
    case ArrayDataType::kInt16:
      return CompareArrayBuffers<ArrayDataType::kInt16>(lhs_array, rhs_array);
    case ArrayDataType::kUint16:
      return CompareArrayBuffers<ArrayDataType::kUint16>(lhs_array, rhs_array);
    case ArrayDataType::kInt32:
      return CompareArrayBuffers<ArrayDataType::kInt32>(lhs_array, rhs_array);
    case ArrayDataType::kUint32:
      return CompareArrayBuffers<ArrayDataType::kUint32>(lhs_array, rhs_array);
    case ArrayDataType::kInt64:
      return CompareArrayBuffers<ArrayDataType::kInt64>(lhs_array, rhs_array);
    case ArrayDataType::kUint64:
      return CompareArrayBuffers<ArrayDataType::kUint64>(lhs_array, rhs_array);
    case ArrayDataType::kString:
      return CompareArrayBuffers<ArrayDataType::kString>(lhs_array, rhs_array);
    case ArrayDataType::kComplex64:
      return CompareArrayBuffers<ArrayDataType::kComplex64>(lhs_array,
                                                            rhs_array);
    default:
      LOG(FATAL) << "Unsupported data type: "
                 << ArrayDataTypeName(lhs_array.data_type);
      return false;
  }
}

namespace {
// Take an array name, which may be something like "name:3_5" and make it
// acceptable as a TF node name, say "name_3_5";
std::string SanitizeNameForTFNode(const std::string& array_name) {
  auto node_name = array_name;
  std::replace(node_name.begin(), node_name.end(), ':', '_');
  return node_name;
}

void CheckInputArraysAreNotOutputArrays(const ModelFlags& model_flags) {
  for (const auto& input_array : model_flags.input_arrays()) {
    for (const std::string& output_array : model_flags.output_arrays()) {
      QCHECK_NE(input_array.name(), output_array)
          << "The array " << output_array
          << " is listed in both --input_arrays and --output_arrays.";
    }
  }
}

bool IsAsciiPrintable(const std::string& name) {
  for (char c : name) {
    if (!absl::ascii_isprint(c)) {
      return false;
    }
  }
  return true;
}

std::string DumpAscii(const std::string& name) {
  std::string result;
  port::AppendF(&result, "ASCII | Hex\n");
  port::AppendF(&result, "------+----\n");
  for (char c : name) {
    if (absl::ascii_isprint(c)) {
      port::AppendF(&result, "%c     | %x\n", c, c);
    } else {
      port::AppendF(&result, "      | %x   Not ASCII printable!\n", c);
    }
  }
  return result;
}

void CheckNonAsciiIOArrays(const ModelFlags& model_flags) {
  if (model_flags.allow_nonascii_arrays()) {
    return;
  }
  for (const auto& input_array : model_flags.input_arrays()) {
    QCHECK(IsAsciiPrintable(input_array.name()))
        << "Non-ASCII-printable character found in --input_arrays: "
        << input_array.name()
        << ". Pass --allow_nonascii_arrays to allow that. "
        << "Here is a dump of the string:\n\n"
        << DumpAscii(input_array.name());
  }
  for (const std::string& output_array : model_flags.output_arrays()) {
    QCHECK(IsAsciiPrintable(output_array))
        << "Non-ASCII-printable character found in --output_arrays: "
        << output_array << ". Pass --allow_nonascii_arrays to allow that. "
        << "Here is a dump of the string:\n\n"
        << DumpAscii(output_array);
  }
}

void CheckNonExistentIOArrays(const Model& model) {
  // "non-existent" is interpreted in the stronger sense of
  // "not actually produced/consumed by an op".
  // Rationale: we have to artificially fix up TensorFlow graphs by creating
  // any array that it refers to, so just checking that arrays exist isn't
  // sufficient. The real invariant here is whether arrays are produced/consumed
  // by something.
  if (model.flags.allow_nonexistent_arrays()) {
    return;
  }
  static constexpr char general_comment[] =
      "Is it a typo? This should not happen. If you trigger this error "
      "please send a bug report (with code to reproduce this error), to the "
      "TensorFlow Lite team.";
  for (const std::string& output_array : model.flags.output_arrays()) {
    if (IsConstantParameterArray(model, output_array)) {
      continue;  // It is OK to request that a constant be an output.
    }
    QCHECK(GetOpWithOutput(model, output_array))
        << "Specified output array \"" << output_array
        << "\" is not produced by any op in this graph. " << general_comment;
  }
  for (const auto& rnn_state : model.flags.rnn_states()) {
    if (!rnn_state.discardable()) {
      // Check that all RNN states are consumed
      QCHECK(GetOpWithInput(model, rnn_state.state_array()))
          << "Specified RNN state \"" << rnn_state.state_array()
          << "\" is not consumed by any op in this graph. " << general_comment;
      // Check that all RNN back-edge source arrays are produced
      QCHECK(GetOpWithOutput(model, rnn_state.back_edge_source_array()))
          << "Specified RNN back-edge source array \""
          << rnn_state.back_edge_source_array()
          << "\" is not produced by any op in this graph. " << general_comment;
    }
  }
}

}  // namespace

void CheckNoMissingArray(const Model& model) {
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      CHECK(model.HasArray(input) || model.optional_arrays.count(input))
          << "Input: " << input << " missing for op: " << op->outputs[0] << ".";
    }
    for (const auto& output : op->outputs) {
      CHECK(model.HasArray(output)) << "Output: " << output << " missing.";
    }
  }
  CheckNonExistentIOArrays(model);
}

void FixNoMissingArray(Model* model) {
  for (const auto& op : model->operators) {
    for (const auto& input : op->inputs) {
      if (!model->HasArray(input) && !model->IsOptionalArray(input)) {
        model->GetOrCreateArray(input);
      }
    }
    for (const auto& output : op->outputs) {
      if (!model->HasArray(output) && !model->IsOptionalArray(output)) {
        model->GetOrCreateArray(output);
      }
    }
  }
  if (model->flags.allow_nonexistent_arrays()) {
    for (const std::string& output_array : model->flags.output_arrays()) {
      model->GetOrCreateArray(output_array);
    }
    for (const auto& rnn_state : model->flags.rnn_states()) {
      model->GetOrCreateArray(rnn_state.state_array());
      model->GetOrCreateArray(rnn_state.back_edge_source_array());
    }
  }
}

void CheckNoOrphanedArray(const Model& model) {
  std::unordered_set<std::string> arrays_without_known_use;
  for (const auto& array : model.GetArrayMap()) {
    if (IsDiscardableArray(model, array.first)) {
      arrays_without_known_use.insert(array.first);
    }
  }
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      arrays_without_known_use.erase(input);
    }
    for (const auto& output : op->outputs) {
      arrays_without_known_use.erase(output);
    }
  }
  for (const auto& rnn_state : model.flags.rnn_states()) {
    arrays_without_known_use.erase(rnn_state.state_array());
    arrays_without_known_use.erase(rnn_state.back_edge_source_array());
  }
  if (!arrays_without_known_use.empty()) {
    for (const auto& array : arrays_without_known_use) {
      LOG(INFO) << "Error: Orphaned array: " << array;
    }
  }
  CHECK(arrays_without_known_use.empty());
}

void FixNoOrphanedArray(Model* model) {
  std::unordered_set<std::string> arrays_without_known_use;
  for (const auto& array : model->GetArrayMap()) {
    arrays_without_known_use.insert(array.first);
  }
  for (const auto& op : model->operators) {
    for (const auto& input : op->inputs) {
      arrays_without_known_use.erase(input);
    }
    for (const auto& output : op->outputs) {
      arrays_without_known_use.erase(output);
    }
  }
  for (const auto& rnn_state : model->flags.rnn_states()) {
    arrays_without_known_use.erase(rnn_state.state_array());
    arrays_without_known_use.erase(rnn_state.back_edge_source_array());
  }
  for (const auto& array : arrays_without_known_use) {
    if (IsDiscardableArray(*model, array)) {
      model->EraseArray(array);
    }
  }
}

// Apply checks to arrays individually (for-each fashion).
//
// Check consistency of array fields, check name.
void CheckEachArray(const Model& model) {
  for (const auto& array_entry : model.GetArrayMap()) {
    const auto& array = array_entry.second;
    // It's OK to have a buffer or an alloc, but not both.
    // (Since allocs are for transient arrays without a buffer).
    CHECK(!array->buffer || !array->alloc) << "Tensor: " << array_entry.first;
    if (array->buffer) {
      // If there is a buffer, its type should be consistent with data_type.
      CHECK(array->buffer->type == array->data_type)
          << "Tensor: " << array_entry.first;
      // The presence of a fixed buffer should imply the presence of a fixed
      // shape.
      CHECK(array->has_shape()) << array_entry.first;
      // A constant buffer should have a valid shape.
      CheckValidShape(array->shape());
      // The shape flat-size should agree with the buffer length.
      CHECK_EQ(array->buffer->Length(),
               RequiredBufferSizeForShape(array->shape()))
          << "Tensor: " << array_entry.first;
    }

    // Check the name: either "name_with_suffix_8" or "name_with_port:3" is
    // valid, but not "name_with_both:3_8".
    const std::string& name = array_entry.first;
    auto colon_pos = name.find_first_of(':');
    if (colon_pos != std::string::npos) {
      CHECK_EQ(name.substr(colon_pos + 1).find_first_not_of("0123456789"),
               std::string::npos)
          << "Array '" << name << "' has non-digit characters after colon.";
    }
    CHECK_GT(colon_pos, 0) << "Array '" << name
                           << "' must not start with a colon.";
  }
}

void CheckOperatorOrdering(const Model& model) {
  std::unordered_set<std::string> arrays_behind_us;
  for (const auto& array_entry : model.GetArrayMap()) {
    if (!GetOpWithOutput(model, array_entry.first)) {
      arrays_behind_us.insert(array_entry.first);
    }
  }
  arrays_behind_us.insert(model.optional_arrays.begin(),
                          model.optional_arrays.end());
  for (const auto& op : model.operators) {
    for (const auto& input : op->inputs) {
      if (!IsConstantParameterArray(model, input)) {
        CHECK(arrays_behind_us.count(input));
      }
    }
    for (const auto& output : op->outputs) {
      CHECK(!arrays_behind_us.count(output));
      arrays_behind_us.insert(output);
    }
  }
  for (const std::string& output_array : model.flags.output_arrays()) {
    CHECK(arrays_behind_us.count(output_array));
  }
}

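// Reorders model->operators into a feasible execution order: repeatedly moves
// any remaining operator whose non-constant inputs are already available (a
// Kahn-style topological sort). If no such operator remains, the trace below
// pinpoints the missing or cyclic input.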
void FixOperatorOrdering(Model* model) {
  std::unordered_set<std::string> arrays_behind_us;
  for (const auto& array_entry : model->GetArrayMap()) {
    if (!GetOpWithOutput(*model, array_entry.first)) {
      arrays_behind_us.insert(array_entry.first);
    }
  }
  arrays_behind_us.insert(model->optional_arrays.begin(),
                          model->optional_arrays.end());
  std::vector<std::unique_ptr<Operator>> old_operators;
  std::swap(old_operators, model->operators);
  std::set<std::size_t> remaining;
  for (std::size_t i = 0; i < old_operators.size(); i++) {
    remaining.insert(i);
  }
  std::unordered_map<std::string, std::string> reason_why_leftover;
  while (true) {
    bool inserted_something = false;
    for (const auto& i : remaining) {
      bool can_insert = true;
      auto& op = old_operators[i];
      CHECK(op);
      for (const auto& input : op->inputs) {
        if (!IsConstantParameterArray(*model, input) &&
            !arrays_behind_us.count(input)) {
          for (const std::string& output : op->outputs) {
            reason_why_leftover[output] = input;
          }
          can_insert = false;
          break;
        }
      }
      if (can_insert) {
        model->operators.emplace_back(nullptr);
        for (const auto& output : op->outputs) {
          arrays_behind_us.insert(output);
        }
        std::swap(op, model->operators.back());
        remaining.erase(i);
        inserted_something = true;
        break;
      }
    }
    if (!inserted_something) {
      break;
    }
  }
  if (!remaining.empty()) {
    LOG(ERROR)
        << "No viable ordering of operators was found. "
        << "Here is a 'backtrace' of at least one part of the graph that is "
        << "problematic. It starts with the first operator that has a "
        << "problematic input array, and then walks back the graph to "
        << "the operator that produced that input array, etc., until we find "
        << "the root cause:";
    LOG(ERROR) << "BEGIN TRACE OF OPERATOR WITH BAD INPUT";
    LOG(ERROR) << "Here is the first-encountered operator with a bad input: ";
    const Operator* bad_op = old_operators[*remaining.begin()].get();
    std::unordered_set<std::string> bad_inputs_already_traced;
    // The following while(true) loop should always end with a LOG(FATAL).
    while (true) {
      LOG(ERROR) << HelpfulOperatorTypeName(*bad_op) << " : "
                 << FormatArraysList(*model, bad_op->inputs) << " -> "
                 << FormatArraysList(*model, bad_op->outputs);
      bool found_bad_output = false;
      std::string bad_output;
      for (const std::string& output : bad_op->outputs) {
        if (reason_why_leftover.count(output)) {
          found_bad_output = true;
          bad_output = output;
          break;
        }
      }
      CHECK(found_bad_output);
      const std::string& bad_input = reason_why_leftover[bad_output];
      LOG(ERROR) << "The bad input here is: " << bad_input;
      if (bad_inputs_already_traced.count(bad_input)) {
        LOG(FATAL)
            << "Cycle found! We already encountered that "
            << "input array, " << bad_input << ", earlier in the "
            << "above trace! We expect graphs to be acyclic, even "
            << "RNNs. Let us know if some graph actually needs to have "
            << "cycles, but first, please check if it really is "
            << "an *inference* graph. *Training* graphs are out-of-scope "
            << "for toco.";
      }
      bad_inputs_already_traced.insert(bad_input);
      bad_op = nullptr;
      for (const auto& i : remaining) {
        const Operator* op = old_operators[i].get();
        for (const std::string& output : op->outputs) {
          if (bad_input == output) {
            bad_op = op;
            break;
          }
        }
        if (bad_op) {
          break;
        }
      }
      if (!bad_op) {
        LOG(ERROR) << "And that's the root cause: "
                   << "that array, " << bad_input << ", isn't produced by any "
                   << "operator, or provided in any other way.";
        LOG(ERROR) << "END TRACE OF OPERATOR WITH BAD INPUT";
        LOG(FATAL) << "(The above was a multi-line fatal error)";
      }
      LOG(ERROR) << "And that array is the output of the following operator:";
    }
  }
  CHECK(remaining.empty())
      << "Should never get here! In case of bad graph, "
      << "the above code should have generated a FATAL error already!";
}
1229 
CheckInvariants(const Model & model)1230 void CheckInvariants(const Model& model) {
1231   CheckInputArraysAreNotOutputArrays(model.flags);
1232   CheckNonAsciiIOArrays(model.flags);
1233   CheckNoMissingArray(model);
1234   CheckNoOrphanedArray(model);
1235   CheckEachArray(model);
1236   CheckOperatorOrdering(model);
1237 }
1238 
CheckCountInRange(const::toco::ModelFlags::ModelCheck & model_check,const int count,const std::string & count_description)1239 void CheckCountInRange(const ::toco::ModelFlags::ModelCheck& model_check,
1240                        const int count, const std::string& count_description) {
1241   if (model_check.count_min() >= 0) {
1242     CHECK_GE(count, model_check.count_min())
1243         << "Mismatch in " << count_description << ": count  was " << count
1244         << ", but the specified "
1245         << (model_check.count_max() > model_check.count_min() ? "minimum"
1246                                                               : "value")
1247         << " was " << model_check.count_min() << ".";
1248   }
1249   if (model_check.count_max() > model_check.count_min()) {
1250     CHECK_LE(count, model_check.count_max())
1251         << "Mismatch in " << count_description << ": count  was " << count
1252         << ", but the specified maximum was " << model_check.count_max() << ".";
1253   }
1254 }
1255 
CheckModelCounts(const Model & model)1256 void CheckModelCounts(const Model& model) {
1257   std::unordered_multiset<OperatorType> ops_by_type;
1258   std::unordered_map<std::string, OperatorType> op_type_by_name;
1259   if (model.flags.model_checks_size() == 0) {
1260     return;
1261   }
1262 
1263   for (const auto& op : model.operators) {
1264     ops_by_type.insert(op->type);
1265     op_type_by_name[OperatorTypeName(op->type)] = op->type;
1266   }
1267   for (const auto& model_check : model.flags.model_checks()) {
1268     std::string count_type = model_check.count_type();
1269     if (count_type == "None") {
1270       continue;
1271     } else if (count_type == "Arrays") {
1272       CheckCountInRange(model_check, model.GetArrayMap().size(),
1273                         "count of arrays");
1274     } else if (count_type == "Total") {
1275       CheckCountInRange(model_check, model.operators.size(),
1276                         "count of all operator instances");
1277     } else {
1278       // The check type is not itself checked against the set of valid
1279       // operators, mainly because the enum set cannot be iterated in C++.
1280       const int found_count =
1281           op_type_by_name.count(count_type) > 0
1282               ? ops_by_type.count(op_type_by_name[count_type])
1283               : 0;
1284       CheckCountInRange(model_check, found_count,
1285                         "count of instances of " + count_type + " operator");
1286     }
1287   }
1288 }
1289 
FixEdgeArrays(Model * model)1290 void FixEdgeArrays(Model* model) {
1291   for (const std::string& output_array_name : model->flags.output_arrays()) {
1292     if (!GetOpWithOutput(*model, output_array_name)) {
1293       // Output has no operator producing it. Change that by inserting a copy.
1294       LOG(WARNING) << "Fixing constant output array " << output_array_name
1295                    << " by inserting a copy. This is not optimal.";
1296       std::string intermediate_array_name =
1297           AvailableArrayName(*model, output_array_name + "_copy");
1298       CloneArray(model, output_array_name, intermediate_array_name);
1299       InsertCopyOperator(model, intermediate_array_name, output_array_name);
1300     }
1301   }
1302 }
1303 
DedupeConstantArrays(Model * model,size_t min_size)1304 void DedupeConstantArrays(Model* model, size_t min_size) {
1305   // Walk all 0..N and compare with the remaining n+1..N.
1306   // This lets us avoid N^2 comparisons and erase duplicate arrays while
1307   // iterating.
1308   const auto& array_map = model->GetArrayMap();
1309   for (auto lhs_array_it = array_map.begin(); lhs_array_it != array_map.end();
1310        ++lhs_array_it) {
1311     const auto& lhs_array_name = lhs_array_it->first;
1312     const auto& lhs_array = *lhs_array_it->second;
1313     if (!IsConstantParameterArray(*model, lhs_array_name)) {
1314       // Not a constant array; skip.
1315       continue;
1316     }
1317     ArrayDataType final_data_type =
1318         lhs_array.final_data_type != ArrayDataType::kNone
1319             ? lhs_array.final_data_type
1320             : lhs_array.data_type;
1321     // Ignore small arrays, don't check string arrays because it is not possible
1322     // to estimate its size.
1323     if (final_data_type != ArrayDataType::kString) {
1324       size_t array_byte_size =
1325           lhs_array.buffer->Length() * ElementSize(final_data_type);
1326       if (array_byte_size < min_size) {
1327         // Too small; skip.
1328         continue;
1329       }
1330     }
1331 
1332     auto next_lhs_array_it = lhs_array_it;
1333     ++next_lhs_array_it;
1334     for (auto rhs_array_it = next_lhs_array_it;
1335          rhs_array_it != array_map.end();) {
1336       const auto& rhs_array_name = rhs_array_it->first;
1337       const auto& rhs_array = *rhs_array_it->second;
1338       ++rhs_array_it;
1339       if (!IsConstantParameterArray(*model, rhs_array_name)) {
1340         // Not a constant array; skip.
1341         continue;
1342       }
1343       if (!IsDiscardableArray(*model, rhs_array_name)) {
1344         // Can't remove the array as it's not discardable (such as an IO edge).
1345         continue;
1346       }
1347       if (!CompareConstantArrays(lhs_array, rhs_array)) {
1348         // Arrays aren't equal; skip.
1349         continue;
1350       }
1351 
1352       // Arrays can be deduped!
1353       VLOG(1) << "Deduplicating arrays; using " << lhs_array_name
1354               << " in place of " << rhs_array_name;
1355       ReplaceArrayUsage(model, rhs_array_name, lhs_array_name);
1356       // Note: rhs_array_it above is already incremented so this is safe.
1357       model->EraseArray(rhs_array_name);
1358     }
1359   }
1360 }
1361 
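// Illustrative usage sketch for the helper above (hedged: the `model` pointer
// and the 64-byte threshold are assumptions made for this example, not values
// defined in this file):
//
//   // Collapse identical constant arrays of at least 64 bytes each.
//   DedupeConstantArrays(model, /*min_size=*/64);
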
1362 namespace {
1363 void CopyArrayAttribs(const Array& source_array, Array* target_array) {
1364   target_array->data_type = source_array.data_type;
1365   target_array->final_data_type = source_array.final_data_type;
1366   if (source_array.has_shape()) {
1367     target_array->copy_shape(source_array.shape());
1368   }
1369 
1370   if (source_array.minmax) {
1371     target_array->GetOrCreateMinMax() = source_array.GetMinMax();
1372   } else {
1373     target_array->minmax.reset();
1374   }
1375 
1376   if (source_array.quantization_params) {
1377     target_array->GetOrCreateQuantizationParams() =
1378         source_array.GetQuantizationParams();
1379   } else {
1380     target_array->quantization_params.reset();
1381   }
1382 }
1383 }  // namespace
1384 
1385 void InsertCopyOperator(Model* model, const std::string& source_array_name,
1386                         const std::string& target_array_name) {
1387   // Reshape to the same size. This should be a no-op.
1388   const Array& source_array = model->GetArray(source_array_name);
1389   std::vector<int> shape = source_array.shape().dims();
1390 
1391   // Drop constant data from the target array as the copy will be done at
1392   // runtime.
1393   Array& target_array = model->GetOrCreateArray(target_array_name);
1394   target_array.buffer.reset();
1395   CopyArrayAttribs(source_array, &target_array);
1396 
1397   // Insert copy operator.
1398   auto* copy_op = new TensorFlowReshapeOperator;
1399   copy_op->inputs = {
1400       source_array_name,
1401       CreateInt32Array(
1402           model, AvailableArrayName(*model, target_array_name + "_copy_shape"),
1403           shape)};
1404   copy_op->outputs = {target_array_name};
1405   if (target_array.has_shape()) {
1406     copy_op->shape = target_array.shape().dims();
1407   }
1408   model->operators.emplace_back(copy_op);
1409 }
1410 
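// Illustrative sketch of the graph edit performed above (the array names and
// the [1, 8] shape are hypothetical): after
//
//   InsertCopyOperator(model, "a", "a_out");
//
// the target "a_out" has no constant buffer and the model contains
//
//   Reshape(inputs: {"a", "a_out_copy_shape" /* int32 const {1, 8} */},
//           outputs: {"a_out"})
//
// so the copy is performed at runtime by a shape-preserving Reshape.
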
1411 void CloneArray(Model* model, const std::string& source_array_name,
1412                 const std::string& target_array_name) {
1413   CHECK(!model->HasArray(target_array_name));
1414   const Array& source_array = model->GetArray(source_array_name);
1415   Array& target_array = model->GetOrCreateArray(target_array_name);
1416   CopyArrayAttribs(source_array, &target_array);
1417 
1418   if (!source_array.buffer) {
1419     return;
1420   }
1421 
1422   switch (source_array.data_type) {
1423     case ArrayDataType::kBool:
1424       CopyArrayBuffer<ArrayDataType::kBool>(source_array, &target_array);
1425       break;
1426     case ArrayDataType::kFloat:
1427       CopyArrayBuffer<ArrayDataType::kFloat>(source_array, &target_array);
1428       break;
1429     case ArrayDataType::kInt8:
1430       CopyArrayBuffer<ArrayDataType::kInt8>(source_array, &target_array);
1431       break;
1432     case ArrayDataType::kUint8:
1433       CopyArrayBuffer<ArrayDataType::kUint8>(source_array, &target_array);
1434       break;
1435     case ArrayDataType::kInt16:
1436       CopyArrayBuffer<ArrayDataType::kInt16>(source_array, &target_array);
1437       break;
1438     case ArrayDataType::kUint16:
1439       CopyArrayBuffer<ArrayDataType::kUint16>(source_array, &target_array);
1440       break;
1441     case ArrayDataType::kInt32:
1442       CopyArrayBuffer<ArrayDataType::kInt32>(source_array, &target_array);
1443       break;
1444     case ArrayDataType::kUint32:
1445       CopyArrayBuffer<ArrayDataType::kUint32>(source_array, &target_array);
1446       break;
1447     case ArrayDataType::kInt64:
1448       CopyArrayBuffer<ArrayDataType::kInt64>(source_array, &target_array);
1449       break;
1450     case ArrayDataType::kUint64:
1451       CopyArrayBuffer<ArrayDataType::kUint64>(source_array, &target_array);
1452       break;
1453     case ArrayDataType::kString:
1454       CopyArrayBuffer<ArrayDataType::kString>(source_array, &target_array);
1455       break;
1456     case ArrayDataType::kComplex64:
1457       CopyArrayBuffer<ArrayDataType::kComplex64>(source_array, &target_array);
1458       break;
1459     default:
1460       LOG(FATAL) << "Unsupported data type: "
1461                  << ArrayDataTypeName(source_array.data_type);
1462       return;
1463   }
1464 }
1465 
1466 void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
1467                    std::vector<int>* out_dims) {
1468   CHECK(out_dims->empty());
1469   if (num_dims == 0) {
1470     return;
1471   } else if (num_dims == 1) {
1472     CHECK_EQ(batch, 1);
1473     *out_dims = {depth};
1474   } else if (num_dims == 2) {
1475     *out_dims = {batch, depth};
1476   } else if (num_dims == 3) {
1477     CHECK_EQ(batch, 1);
1478     *out_dims = {height, width, depth};
1479   } else if (num_dims == 4) {
1480     *out_dims = {batch, height, width, depth};
1481   } else {
1482     LOG(FATAL) << "Should not get here: " << num_dims;
1483   }
1484 }
1485 
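// Worked examples (illustrative; `dims` is assumed to be an empty vector
// before each call, as the CHECK above requires):
//
//   MakeArrayDims(1, /*batch=*/1, 3, 4, 5, &dims);  // dims == {5}
//   MakeArrayDims(2, 2, 3, 4, 5, &dims);            // dims == {2, 5}
//   MakeArrayDims(3, /*batch=*/1, 3, 4, 5, &dims);  // dims == {3, 4, 5}
//   MakeArrayDims(4, 2, 3, 4, 5, &dims);            // dims == {2, 3, 4, 5}
//
// num_dims == 1 and num_dims == 3 additionally require batch == 1.
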
1486 void CreateOrCheckRnnStateArray(const std::string& name, int size,
1487                                 int state_num_dims, Model* model) {
1488   int batch = 1;
1489   int num_dims = -1;
1490   if (state_num_dims > 0) {
1491     num_dims = state_num_dims;
1492   } else {
1493     // state_num_dims is not given. We will infer it from an input tensor.
1494     for (const auto& input_array : model->flags.input_arrays()) {
1495       // Pick 'num_dims' and 'batch' from the first input array, unless we
1496       // find a better match by name.
1497       if (input_array.name() == name || num_dims == -1) {
1498         num_dims = input_array.shape().dims_size();
1499         if (num_dims > 0) {
1500           batch = input_array.shape().dims(0);
1501         }
1502       }
1503     }
1504   }
1505   Array& array = model->GetOrCreateArray(name);
1506   if (array.has_shape()) {
1507     num_dims = array.shape().dimensions_count();
1508   }
1509   if (!array.has_shape() && num_dims >= 0) {
1510     Shape* shape = array.mutable_shape();
1511     std::vector<int> dims;
1512     MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
1513     *shape->mutable_dims() = dims;
1514   }
1515 }
1516 
1517 void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
1518   // Merge info about input_arrays from model_flags into model->flags
1519   for (const auto& specified_input_array : model_flags.input_arrays()) {
1520     toco::InputArray* dst_input_array = nullptr;
1521     for (int i = 0; i < model->flags.input_arrays_size(); i++) {
1522       toco::InputArray* candidate_dst_input_array =
1523           model->flags.mutable_input_arrays(i);
1524       if (candidate_dst_input_array->name() == specified_input_array.name()) {
1525         // specified_input_array from model_flags maps to dst_input_array
1526         // in model->flags
1527         dst_input_array = candidate_dst_input_array;
1528         break;
1529       }
1530     }
1531     if (!dst_input_array) {
1532       // specified_input_array from model_flags was not found in model->flags.
1533       // Match a nameless specified input array when there can be no ambiguity,
1534       // i.e. when there is only 1 input array.
1535       if (model->flags.input_arrays_size() == 1 &&
1536           model_flags.input_arrays_size() == 1 &&
1537           !specified_input_array.has_name()) {
1538         dst_input_array = model->flags.mutable_input_arrays(0);
1539       }
1540     }
1541     if (!dst_input_array) {
1542       // Still no match, so create a new input array to copy
1543       // specified_input_array into.
1544       dst_input_array = model->flags.add_input_arrays();
1545       dst_input_array->set_name(specified_input_array.name());
1546     }
1547 
1548 #define RESOLVE_MODEL_FLAG(field_name)                                       \
1549   if (specified_input_array.has_##field_name()) {                            \
1550     if (dst_input_array->has_##field_name()) {                               \
1551       QCHECK_EQ(dst_input_array->field_name(),                               \
1552                 specified_input_array.field_name())                          \
1553           << "For input array '" << dst_input_array->name() << "', "         \
1554           << "specified " #field_name " flag with value: "                   \
1555           << specified_input_array.field_name()                              \
1556           << " does not agree with already defined " #field_name             \
1557              " of this model, with value: "                                  \
1558           << dst_input_array->field_name();                                  \
1559     } else {                                                                 \
1560       dst_input_array->set_##field_name(specified_input_array.field_name()); \
1561     }                                                                        \
1562   }
1563     RESOLVE_MODEL_FLAG(std_value);
1564     RESOLVE_MODEL_FLAG(mean_value);
1565 #undef RESOLVE_MODEL_FLAG
1566 
1567     if (specified_input_array.has_shape()) {
1568       if (dst_input_array->has_shape()) {
1569         QCHECK_EQ(specified_input_array.shape().dims_size(),
1570                   dst_input_array->shape().dims_size())
1571             << "For input array '" << specified_input_array.name() << "', "
1572             << "size of specified input shape flag with size: "
1573             << specified_input_array.shape().dims_size()
1574             << " does not agree with already defined input shape"
1575                " of this model, with size: "
1576             << dst_input_array->shape().dims_size();
1577         // We treat the first dimension as a special case, since it is often
1578         // a batch size and the input_shape flag is effectively overriding
1579         // the model.
1580         for (int i = 1; i < specified_input_array.shape().dims_size(); i++) {
1581           QCHECK_EQ(specified_input_array.shape().dims(i),
1582                     dst_input_array->shape().dims(i))
1583               << "At dimension number " << i << " of input array "
1584               << specified_input_array.name() << ", the specified shape's "
1585               << "dimension flag with dimension: "
1586               << specified_input_array.shape().dims(i)
1587               << " does not agree with already defined shape"
1588               << " of this model, with dimension: "
1589               << dst_input_array->shape().dims(i);
1590         }
1591       } else {
1592         *dst_input_array->mutable_shape() = specified_input_array.shape();
1593       }
1594     }
1595 
1596     if (specified_input_array.has_data_type()) {
1597       QCHECK(!dst_input_array->has_data_type());
1598       dst_input_array->set_data_type(specified_input_array.data_type());
1599     }
1600   }
1601 
1602   if (model_flags.output_arrays_size() > 0) {
1603     model->flags.mutable_output_arrays()->CopyFrom(model_flags.output_arrays());
1604   }
1605 
1606 #define RESOLVE_MODEL_FLAG(name)                                           \
1607   if (model_flags.has_##name()) {                                          \
1608     if (model->flags.has_##name()) {                                       \
1609       QCHECK_EQ(model_flags.name(), model->flags.name())                   \
1610           << "Specified " #name " flag with value: " << model_flags.name() \
1611           << " does not agree with already defined " #name                 \
1612              " of this model, with value: "                                \
1613           << model->flags.name();                                          \
1614     } else {                                                               \
1615       model->flags.set_##name(model_flags.name());                         \
1616     }                                                                      \
1617   }
1618 
1619   RESOLVE_MODEL_FLAG(variable_batch)
1620 
1621 #undef RESOLVE_MODEL_FLAG
1622 
1623   if (!model_flags.rnn_states().empty()) {
1624     model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states());
1625   }
1626 
1627   if (model->flags.model_checks_size() == 0) {
1628     model->flags.mutable_model_checks()->CopyFrom(model_flags.model_checks());
1629   }
1630 
1631   QCHECK_GT(model->flags.output_arrays_size(), 0)
1632       << "This model does not define output arrays, so a "
1633          "--output_arrays flag must be given on the command-line.";
1634 
1635   for (auto& input_array_proto : *model->flags.mutable_input_arrays()) {
1636     auto& input_array = model->GetOrCreateArray(input_array_proto.name());
1637     if (input_array_proto.has_data_type()) {
1638       const ArrayDataType specified_type =
1639           ConvertIODataTypeToArrayDataType(input_array_proto.data_type());
1640       QCHECK(specified_type != ArrayDataType::kNone);
1641       if (input_array.data_type != ArrayDataType::kNone) {
1642         QCHECK(specified_type == input_array.data_type)
1643             << "For input array " << input_array_proto.name()
1644             << " the specified input data type "
1645             << IODataType_Name(input_array_proto.data_type())
1646             << " conflicts with the existing type.";
1647       }
1648       input_array.data_type = specified_type;
1649     }
1650 
1651     if (input_array.data_type == ArrayDataType::kNone) {
1652       // We start out with a float input array;
1653       // that may get replaced by a uint8 array later, by
1654       // MakeInitialDequantizeOp.
1655       input_array.data_type = ArrayDataType::kFloat;
1656     }
1657 
1658     // Compare/merge the model->flags describing the input_shape with
1659     // the actual input array's shape.
1660     if (!input_array.has_shape()) {
1661       if (input_array_proto.has_shape()) {
1662         auto& input_array_dims = *input_array.mutable_shape()->mutable_dims();
1663         CheckValidShapeDimensions(input_array_proto.shape().dims());
1664         for (const auto& dim : input_array_proto.shape().dims()) {
1665           input_array_dims.push_back(dim);
1666         }
1667       }
1668     } else {
1669       if (input_array_proto.has_shape()) {
1670         // If an input shape was specified on the flags ensure that it matches
1671         // the actual shape in the model.
1672         const auto& input_array_dims =
1673             *input_array.mutable_shape()->mutable_dims();
1674         CHECK_EQ(input_array_dims.size(),
1675                  input_array_proto.shape().dims_size());
1676         for (int i = 0; i < input_array_dims.size(); i++) {
1677           CHECK_EQ(input_array_dims[i], input_array_proto.shape().dims(i));
1678         }
1679       } else {
1680         for (int i = 0; i < input_array.shape().dimensions_count(); i++) {
1681           input_array_proto.mutable_shape()->add_dims(
1682               input_array.shape().dims(i));
1683         }
1684       }
1685     }
1686 
1687     const float mean_value = input_array_proto.mean_value();
1688     const float std_value = input_array_proto.std_value();
1689     MinMax input_minmax;
1690     float qmin = 0, qmax = 255;
1691     if (input_array.data_type == ArrayDataType::kInt16) {
1692       qmin = -32768;
1693       qmax = 32767;
1694     }
1695     input_minmax.min = (qmin - mean_value) / std_value;
1696     input_minmax.max = (qmax - mean_value) / std_value;
1697     if (!input_array.minmax) {
1698       input_array.GetOrCreateMinMax() = input_minmax;
1699     }
1700   }
1701 
1702   // Creation of the RNN state arrays
1703   for (const auto& rnn_state : model->flags.rnn_states()) {
1704     CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
1705                                rnn_state.num_dims(), model);
1706   }
1707 
1708   model->flags.set_change_concat_input_ranges(
1709       model_flags.change_concat_input_ranges());
1710   model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays());
1711   model->flags.set_allow_nonexistent_arrays(
1712       model_flags.allow_nonexistent_arrays());
1713 
1714   CHECK(!model->flags.has_arrays_extra_info());
1715   *model->flags.mutable_arrays_extra_info() = model_flags.arrays_extra_info();
1716 }
1717 
1718 void CheckIsReadyForQuantization(const Model& model) {
1719   for (const auto& op : model.operators) {
1720     for (const auto& input : op->inputs) {
1721       const auto& input_array = model.GetArray(input);
1722       if (input_array.data_type != ArrayDataType::kFloat) {
1723         // The array is not floats, no quantization needed.
1724         continue;
1725       }
1726       if (input_array.minmax) {
1727         // The array has minmax, we're good.
1728         continue;
1729       }
1730       if (input_array.buffer) {
1731         // The array has a constant buffer, so we can
1732         // fall back to computing the minmax from actual array entries
1733         // (with a WARNING about possible accuracy implications).
1734         continue;
1735       }
1736       LOG(FATAL)
1737           << "Array " << input << ", which is an input to the "
1738           << HelpfulOperatorTypeName(*op) << " operator producing the output "
1739           << "array " << op->outputs[0] << ", is lacking min/max data, "
1740           << "which is necessary for quantization. If accuracy matters, either "
1741           << "target a non-quantized output format, or run quantized training "
1742           << "with your model from a floating point checkpoint to change the "
1743           << "input graph to contain min/max information. If you don't care "
1744           << "about accuracy, you can pass --default_ranges_min= and "
1745           << "--default_ranges_max= for easy experimentation.";
1746     }
1747   }
1748 }
1749 
1750 int ElementSize(ArrayDataType data_type) {
1751   switch (data_type) {
1752     case ArrayDataType::kBool:
1753       return sizeof(bool);
1754     case ArrayDataType::kFloat:
1755       return 4;
1756     case ArrayDataType::kInt8:
1757       return 1;
1758     case ArrayDataType::kUint8:
1759       return 1;
1760     case ArrayDataType::kInt16:
1761       return 2;
1762     case ArrayDataType::kUint16:
1763       return 2;
1764     case ArrayDataType::kInt32:
1765       return 4;
1766     case ArrayDataType::kUint32:
1767       return 4;
1768     case ArrayDataType::kInt64:
1769       return 8;
1770     case ArrayDataType::kUint64:
1771       return 8;
1772     case ArrayDataType::kComplex64:
1773       return 8;
1774     case ArrayDataType::kComplex128:
1775       return 16;
1776     case ArrayDataType::kFloat64:
1777       return 8;
1778 
1779     // Usually not a critical limitation, because strings are only used as
1780     // inputs and/or outputs.
1781     case ArrayDataType::kString:
1782       LOG(FATAL) << "Transient arrays with strings are not supported yet";
1783       return 0;
1784     default:
1785       LOG(FATAL) << "Unknown data_type = " << static_cast<int>(data_type);
1786       return 0;
1787   }
1788 }
1789 
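// Example (illustrative) of how ElementSize() is typically combined with an
// element count, as DedupeConstantArrays() does above (`array` here stands for
// any constant Array with a buffer; it is an assumption of this example):
//
//   size_t bytes = array.buffer->Length() * ElementSize(ArrayDataType::kFloat);
//   // A 2x3 float buffer yields 6 * 4 == 24 bytes.
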
1790 void DropMinMax(Model* model, const std::string& array_name) {
1791   auto& array = model->GetArray(array_name);
1792   if (!!array.minmax) {
1793     LOG(WARNING) << "Dropping MinMax information in array " << array_name
1794                  << ". Expect inaccuracy in quantized inference.";
1795     array.minmax = nullptr;
1796   }
1797 }
1798 
1799 bool IsAllocatableTransientArray(const Model& model,
1800                                  const std::string& array_name) {
1801   // An optional array is not a transient array.
1802   if (model.IsOptionalArray(array_name)) return false;
1803   // The model's input and output arrays are externally allocated.
1804   // They are not transient arrays.
1805   if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
1806     return false;
1807   }
1808   const auto& array = &model.GetArray(array_name);
1809   // An array with a constant buffer isn't a transient array.
1810   if (!!array->buffer) {
1811     return false;
1812   }
1813   // An array without shape isn't allocatable.
1814   if (!array->has_shape()) {
1815     return false;
1816   }
1817 
1818   // The size of string tensors is rarely known ahead of time, so all transient
1819   // tensors of this type will need to be dynamically allocated.
1820   if (array->final_data_type == ArrayDataType::kString ||
1821       array->data_type == ArrayDataType::kString) {
1822     return false;
1823   }
1824 
1825   return true;
1826 }
1827 
1828 std::string AvailableArrayName(const Model& model, const std::string& name) {
1829   std::string sanitized_name = SanitizeNameForTFNode(name);
1830   if (!model.HasArray(sanitized_name) &&
1831       !model.IsOptionalArray(sanitized_name)) {
1832     return sanitized_name;
1833   }
1834   const int kNumSuffixesToTry = 1000;
1835   for (int i = 0; i < kNumSuffixesToTry; i++) {
1836     const std::string& name_with_suffix =
1837         toco::port::StringF("%s_%d", sanitized_name, i);
1838     if (!model.HasArray(name_with_suffix) &&
1839         !model.IsOptionalArray(name_with_suffix)) {
1840       return name_with_suffix;
1841     }
1842   }
1843   LOG(FATAL) << "Could not find an available array name starting with "
1844              << sanitized_name << ". Tried " << kNumSuffixesToTry
1845              << " suffixes, all were taken!";
1846   return "";
1847 }
1848 
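// Example (illustrative, with a hypothetical array name): if the model already
// contains an array called "conv/weights", then
//
//   AvailableArrayName(model, "conv/weights")
//
// sanitizes the name for use as a TF node name and returns the first free
// suffixed variant, e.g. "conv/weights_0" (or "conv/weights_1" if that is
// taken too).
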
1849 std::string ShapeToString(const Shape& shape) {
1850   if (shape.dimensions_count() == 0) {
1851     return "[]";
1852   }
1853 
1854   return absl::StrCat("[ ", absl::StrJoin(shape.dims(), ", "), " ]");
1855 }
1856 
1857 void PrintArrayShape(Model* model, const std::string& name) {
1858   if (!model->GetArray(name).has_shape()) {
1859     LOG(INFO) << name << " has no shape";
1860     return;
1861   }
1862   LOG(INFO) << name
1863             << " has shape: " << ShapeToString(model->GetArray(name).shape());
1864 }
1865 
1866 bool IsArrayFullyConnectedWeights(const Model& model, const std::string& name) {
1867   bool is_fc_weights = false;
1868   bool is_something_else = false;
1869   for (const auto& op : model.operators) {
1870     for (int input_index = 0; input_index < op->inputs.size(); input_index++) {
1871       if (op->inputs[input_index] == name) {
1872         if (op->type == OperatorType::kFullyConnected && input_index == 1) {
1873           is_fc_weights = true;
1874         } else {
1875           is_something_else = true;
1876         }
1877       }
1878     }
1879   }
1880   CHECK(!(is_fc_weights && is_something_else));
1881   return is_fc_weights;
1882 }
1883 
1884 std::string CreateInt32Array(Model* model, const std::string& param_name,
1885                              const std::vector<int>& value) {
1886   auto param_array_name = AvailableArrayName(*model, param_name);
1887   auto& param_array = model->GetOrCreateArray(param_array_name);
1888   param_array.mutable_shape()->ReplaceDims({static_cast<int>(value.size())});
1889   param_array.data_type = ArrayDataType::kInt32;
1890   auto& param_array_data =
1891       param_array.GetMutableBuffer<ArrayDataType::kInt32>().data;
1892   param_array_data.resize(RequiredBufferSizeForShape(param_array.shape()));
1893   for (int i = 0; i < value.size(); ++i) {
1894     param_array_data[i] = value[i];
1895   }
1896   return param_array_name;
1897 }
1898 
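// Example (illustrative, hypothetical array name): creating a 1-D int32
// constant holding {1, 8}, as InsertCopyOperator() does for the shape operand
// of its Reshape:
//
//   std::string shape_name = CreateInt32Array(model, "a_copy_shape", {1, 8});
//   // model->GetArray(shape_name) now has shape [2], data type kInt32 and a
//   // constant buffer containing {1, 8}.
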
1899 bool EstimateArithmeticOpsCount(const Model& model, const Operator& op,
1900                                 int64_t* result) {
1901   switch (op.type) {
1902     case OperatorType::kFullyConnected:
1903     case OperatorType::kConv:
1904     case OperatorType::kDepthwiseConv: {
1905       const auto& output_array = model.GetArray(op.outputs[0]);
1906       const auto& weights_array = model.GetArray(op.inputs[1]);
1907       if (!output_array.has_shape() || !weights_array.has_shape()) {
1908         return false;
1909       }
1910       int64_t cols = 1;
1911       for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) {
1912         cols *= output_array.shape().dims(i);
1913       }
1914       const int64_t cost_per_col =
1915           2 * RequiredBufferSizeForShape(weights_array.shape());
1916       *result = cost_per_col * cols;
1917       if (op.inputs.size() > 2) {
1918         // There is a bias vector. One more op per output value.
1919         *result += RequiredBufferSizeForShape(output_array.shape());
1920       }
1921       break;
1922     }
1923     case OperatorType::kTransposeConv: {
1924       const auto& input_array = model.GetArray(op.inputs[2]);
1925       const auto& weights_array = model.GetArray(op.inputs[1]);
1926       if (!input_array.has_shape() || !weights_array.has_shape()) {
1927         return false;
1928       }
1929       const Shape& input = input_array.shape();
1930       const Shape& weights = weights_array.shape();
1931       // Compute op count from the seven nested loops of
1932       // tflite::reference_ops::TransposeConv():
1933       *result = 2 * input.dims(0) * input.dims(1) * input.dims(2) *
1934                 input.dims(3) * weights.dims(1) * weights.dims(2) *
1935                 weights.dims(0);
1936       // Note that tflite::optimized_ops::TransposeConv() uses an im2col matrix
1937       // and has a higher op count, by a factor of (output_height*output_width)
1938       // vs. (input_height*input_width). Yet it generally performs better
1939       // because of coherent memory access. (At least for 2x2 striding. But not
1940       // likely for all cases.)
1941       break;
1942     }
1943     case OperatorType::kAdd:
1944     case OperatorType::kSub:
1945     case OperatorType::kMul: {
1946       const auto& output_array = model.GetArray(op.outputs[0]);
1947       if (!output_array.has_shape()) {
1948         return false;
1949       }
1950       *result = RequiredBufferSizeForShape(output_array.shape());
1951       break;
1952     }
1953     case OperatorType::kAddN: {
1954       const auto& output_array = model.GetArray(op.outputs[0]);
1955       if (!output_array.has_shape()) {
1956         return false;
1957       }
1958       // AddN cost is roughly the same cost as N-1 Adds.
1959       const int64_t num_adds = op.inputs.size() - 1;
1960       *result = num_adds * RequiredBufferSizeForShape(output_array.shape());
1961       break;
1962     }
1963     case OperatorType::kLogistic:
1964     case OperatorType::kSoftmax:
1965     case OperatorType::kLogSoftmax:
1966     case OperatorType::kTanh: {
1967       const auto& output_array = model.GetArray(op.outputs[0]);
1968       if (!output_array.has_shape()) {
1969         return false;
1970       }
1971       // As a very rough ballpark, the cost of evaluating a math function
1972       // such as tanh or logistic is about 32 multiplications, and about as
1973       // many additions/subtractions. (Just a power-of-two order-of-magnitude
1974       // estimate, based on the actual implementations we use in runtime/ code.)
1975       *result = 64 * RequiredBufferSizeForShape(output_array.shape());
1976       break;
1977     }
1978     case OperatorType::kMaxPool: {
1979       const auto& maxpool = *static_cast<const MaxPoolOperator*>(&op);
1980       const auto& output_array = model.GetArray(op.outputs[0]);
1981       if (!output_array.has_shape()) {
1982         return false;
1983       }
1984       *result = RequiredBufferSizeForShape(output_array.shape()) *
1985                 maxpool.kheight * maxpool.kwidth;
1986       break;
1987     }
1988     case OperatorType::kAveragePool: {
1989       const auto& avgpool = *static_cast<const AveragePoolOperator*>(&op);
1990       const auto& output_array = model.GetArray(op.outputs[0]);
1991       if (!output_array.has_shape()) {
1992         return false;
1993       }
1994       *result = RequiredBufferSizeForShape(output_array.shape()) *
1995                 avgpool.kheight * avgpool.kwidth;
1996       break;
1997     }
1998     case OperatorType::kL2Pool: {
1999       const auto* maxpool = static_cast<const MaxPoolOperator*>(&op);
2000       const auto& output_array = model.GetArray(op.outputs[0]);
2001       if (!output_array.has_shape()) {
2002         return false;
2003       }
2004       // The sum of squares requires (kheight*kwidth) multiply-adds,
2005       // and then there is the sqrt which we ballpark at 32 ops.
2006       const int64_t cost_per_val = 2 * maxpool->kheight * maxpool->kwidth + 32;
2007       *result = RequiredBufferSizeForShape(output_array.shape()) * cost_per_val;
2008       break;
2009     }
2010     case OperatorType::kL2Normalization: {
2011       const auto& output_array = model.GetArray(op.outputs[0]);
2012       if (!output_array.has_shape()) {
2013         return false;
2014       }
2015       // Computing the squared L2 norm is N multiply-adds so 2N ops,
2016       // then the single inverse-sqrt is negligible, then we multiply each
2017       // value by the resulting multiplier, so an extra N ops. count 3N ops.
2018       *result = 3 * RequiredBufferSizeForShape(output_array.shape());
2019       break;
2020     }
2021     default:
2022       *result = 0;
2023       break;
2024   }
2025   return true;
2026 }
2027 
2028 bool EstimateArithmeticOpsCount(const Model& model, int64_t* result) {
2029   int64_t total = 0;
2030   for (const auto& op : model.operators) {
2031     int64_t num_ops;
2032     if (!EstimateArithmeticOpsCount(model, *op, &num_ops)) {
2033       return false;
2034     }
2035     total += num_ops;
2036   }
2037   *result = total;
2038   return true;
2039 }
2040 
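// Example (illustrative) of reporting a whole-model estimate, using
// FormattedNumber() below for human-readable output (the `model` object is
// assumed to exist):
//
//   int64_t ops = 0;
//   if (EstimateArithmeticOpsCount(model, &ops)) {
//     LOG(INFO) << "Estimated arithmetic ops: " << FormattedNumber(ops);
//   }
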
2041 std::string FormattedNumber(int64_t x) {
2042   const int64_t million = 1000000;
2043   const int64_t billion = 1000000000;
2044   if (x < 10000) {
2045     return toco::port::StringF("%d ", x);
2046   } else if (x < billion) {
2047     return toco::port::StringF("%.3f M", static_cast<double>(x) / million);
2048   } else {
2049     return toco::port::StringF("%.3f G", static_cast<double>(x) / billion);
2050   }
2051 }
2052 
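// Examples (illustrative) of the formatting above:
//
//   FormattedNumber(1234)        == "1234 "
//   FormattedNumber(12345678)    == "12.346 M"
//   FormattedNumber(2500000000)  == "2.500 G"
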
2053 void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
2054                      std::vector<int>* shuffle) {
2055   CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order));
2056   shuffle->resize(4);
2057   for (int i = 0; i < 4; i++) {
2058     (*shuffle)[i] = i;
2059   }
2060   if (input_axes_order == output_axes_order) {
2061     // nothing to do
2062   } else if (AxesCount(input_axes_order) == 2) {
2063     shuffle->resize(2);
2064     (*shuffle)[0] = 1;
2065     (*shuffle)[1] = 0;
2066   } else if (input_axes_order == AxesOrder::kOHWI &&
2067              output_axes_order == AxesOrder::kHWIO) {
2068     // 3210 <- 3210
2069     // HWIO <- OHWI
2070     *shuffle = {1, 2, 3, 0};
2071   } else if (input_axes_order == AxesOrder::kHWIO &&
2072              output_axes_order == AxesOrder::kOHWI) {
2073     // 3210 <- 3210
2074     // OHWI <- HWIO
2075     *shuffle = {3, 0, 1, 2};
2076   } else if (input_axes_order == AxesOrder::kOHWI &&
2077              output_axes_order == AxesOrder::kHWOI) {
2078     *shuffle = {1, 2, 0, 3};
2079   } else {
2080     LOG(FATAL) << "Bad shuffle";
2081   }
2082 }
2083 
2084 void ExtendShuffle(const std::vector<int>& input_shuffle, int newdim,
2085                    std::vector<int>* extended_shuffle) {
2086   *extended_shuffle = input_shuffle;
2087   CHECK(newdim >= input_shuffle.size());
2088   const int pad_size = newdim - input_shuffle.size();
2089   extended_shuffle->resize(newdim);
2090   for (int i = 0; i < pad_size; i++) {
2091     (*extended_shuffle)[i] = i;
2092   }
2093   for (int i = pad_size; i < newdim; i++) {
2094     (*extended_shuffle)[i] = input_shuffle[i - pad_size] + pad_size;
2095   }
2096 }
2097 
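// Worked example (illustrative): extending the 2-D transpose shuffle {1, 0}
// to 4 dimensions pads identity entries at the front and offsets the original
// entries by the padding size:
//
//   ExtendShuffle({1, 0}, 4, &extended);  // extended == {0, 1, 3, 2}
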
2098 void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order,
2099                  AxesOrder output_axes_order, Shape* output_shape) {
2100   if (input_axes_order == AxesOrder::kHWIM &&
2101       output_axes_order == AxesOrder::k1HWO) {
2102     // This special case isn't just a permutation: the IM pair of dims gets
2103     // merged into the O dim (dim 3), so we have to special-case it.
2104     *output_shape = Shape({1, input_shape.dims(0), input_shape.dims(1),
2105                            input_shape.dims(3) * input_shape.dims(2)});
2106   } else {
2107     std::vector<int> shuffle;
2108     GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
2109     std::vector<int>* output_dims = output_shape->mutable_dims();
2110     output_dims->resize(input_shape.dimensions_count());
2111     for (int i = 0; i < input_shape.dimensions_count(); i++) {
2112       (*output_dims)[i] = input_shape.dims(shuffle[i]);
2113     }
2114   }
2115 }
2116 
2117 template <typename T>
2118 void ShuffleArrayTemplate(const Shape& input_shape, AxesOrder input_axes_order,
2119                           AxesOrder output_axes_order,
2120                           const Shape& output_shape, const T* input_data,
2121                           T* output_data) {
2122   if (input_axes_order == AxesOrder::kHWIM &&
2123       output_axes_order == AxesOrder::k1HWO) {
2124     // This special case isn't just a permutation: the IM pair of dims gets
2125     // merged into the O dim, so we have to special-case it. Fortunately,
2126     // as far as array shuffling is concerned, it's just the identity
2127     // transformation.
2128     memcpy(output_data, input_data,
2129            RequiredBufferSizeForShape(input_shape) * sizeof(output_data[0]));
2130     return;
2131   }
2132   CHECK(input_shape.dimensions_count() == output_shape.dimensions_count());
2133   const int dim = input_shape.dimensions_count();
2134   CHECK_LE(dim, 4);
2135   std::vector<int> shuffle;
2136   GetShuffleShape(input_axes_order, output_axes_order, &shuffle);
2137   CHECK(shuffle.size() >= dim);
2138   for (int i = 0; i < dim; i++) {
2139     CHECK(shuffle[i] >= 0 && shuffle[i] < dim);
2140     CHECK(input_shape.dims(shuffle[i]) == output_shape.dims(i));
2141   }
2142   Shape extended_input_shape = input_shape;
2143   ExtendShape(&extended_input_shape, 4);
2144   Shape extended_output_shape = output_shape;
2145   ExtendShape(&extended_output_shape, 4);
2146   std::vector<int> extended_shuffle;
2147   ExtendShuffle(shuffle, 4, &extended_shuffle);
2148 
2149   const std::vector<int>& extended_input_dims = extended_input_shape.dims();
2150   const std::vector<int>& extended_output_dims = extended_output_shape.dims();
2151 
2152   // TODO(starka): Rework to handle different numbers of dimensions.
2153   int input_strides[4];
2154   input_strides[3] = 1;
2155   input_strides[2] = extended_input_dims[3];
2156   input_strides[1] = input_strides[2] * extended_input_dims[2];
2157   input_strides[0] = input_strides[1] * extended_input_dims[1];
2158   const int input_stride_0 = input_strides[extended_shuffle[3]];
2159   const int input_stride_1 = input_strides[extended_shuffle[2]];
2160   const int input_stride_2 = input_strides[extended_shuffle[1]];
2161   const int input_stride_3 = input_strides[extended_shuffle[0]];
2162 
2163   const int output_size_0 = extended_output_dims[3];
2164   const int output_size_1 = extended_output_dims[2];
2165   const int output_size_2 = extended_output_dims[1];
2166   const int output_size_3 = extended_output_dims[0];
2167   const int output_stride_0 = 1;
2168   const int output_stride_1 = output_size_0;
2169   const int output_stride_2 = output_stride_1 * output_size_1;
2170   const int output_stride_3 = output_stride_2 * output_size_2;
2171 
2172   for (int i3 = 0; i3 < output_size_3; i3++) {
2173     const T* const input_ptr_3 = input_data + i3 * input_stride_3;
2174     T* const output_ptr_3 = output_data + i3 * output_stride_3;
2175     for (int i2 = 0; i2 < output_size_2; i2++) {
2176       const T* const input_ptr_2 = input_ptr_3 + i2 * input_stride_2;
2177       T* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2;
2178       for (int i1 = 0; i1 < output_size_1; i1++) {
2179         const T* input_ptr = input_ptr_2 + i1 * input_stride_1;
2180         T* output_ptr = output_ptr_2 + i1 * output_stride_1;
2181         T* const output_ptr_end = output_ptr + output_size_0 * output_stride_0;
2182         while (output_ptr != output_ptr_end) {
2183           *output_ptr = *input_ptr;
2184           input_ptr += input_stride_0;
2185           output_ptr += output_stride_0;
2186         }
2187       }
2188     }
2189   }
2190 }
2191 
2192 void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
2193                   AxesOrder output_axes_order, const Shape& output_shape,
2194                   const uint8* input_data, uint8* output_data) {
2195   ShuffleArrayTemplate<uint8>(input_shape, input_axes_order, output_axes_order,
2196                               output_shape, input_data, output_data);
2197 }
2198 
2199 void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order,
2200                   AxesOrder output_axes_order, const Shape& output_shape,
2201                   const float* input_data, float* output_data) {
2202   ShuffleArrayTemplate<float>(input_shape, input_axes_order, output_axes_order,
2203                               output_shape, input_data, output_data);
2204 }
2205 
2206 int AxesCount(AxesOrder axes_order) {
2207   switch (axes_order) {
2208     case AxesOrder::kOneAxis:
2209       return 1;
2210     case AxesOrder::kRC:
2211       return 2;
2212     case AxesOrder::kCR:
2213       return 2;
2214     case AxesOrder::kHWIO:
2215       return 4;
2216     case AxesOrder::kOHWI:
2217       return 4;
2218     case AxesOrder::kHWIM:
2219       return 4;
2220     case AxesOrder::k1HWO:
2221       return 4;
2222     case AxesOrder::kNHWC:
2223       return 4;
2224     case AxesOrder::kHWOI:
2225       return 4;
2226     default:
2227       LOG(FATAL) << "Bad AxesOrder";
2228       return 0;
2229   }
2230 }
2231 
2232 bool IsDiscardableArray(const Model& model, const std::string& array_name) {
2233   if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
2234     return false;
2235   }
2236   for (const auto& rnn_state : model.flags.rnn_states()) {
2237     if (!rnn_state.discardable()) {
2238       if (array_name == rnn_state.state_array()) {
2239         return false;
2240       }
2241       if (array_name == rnn_state.back_edge_source_array()) {
2242         return false;
2243       }
2244     }
2245   }
2246   return true;
2247 }
2248 
2249 bool ReshapeIsEquivalentToTranspose(const Model& model,
2250                                     const TensorFlowReshapeOperator* op,
2251                                     bool allow_extra_unary_dims) {
2252   CHECK(!op->shape.empty());
2253   CHECK(model.HasArray(op->inputs[0]));
2254   CHECK(model.HasArray(op->outputs[0]));
2255 
2256   const auto& input_array = model.GetArray(op->inputs[0]);
2257   const auto& output_array = model.GetArray(op->outputs[0]);
2258 
2259   CHECK(input_array.has_shape());
2260   CHECK(output_array.has_shape());
2261 
2262   std::vector<int> in_shape = input_array.shape().dims();
2263   std::vector<int> out_shape = output_array.shape().dims();
2264 
2265   // If the reshape changes the number of dimensions, it cannot be
2266   // interpreted as a transpose.
2267   if (!allow_extra_unary_dims && in_shape.size() != out_shape.size()) {
2268     return false;
2269   }
2270 
2271   in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1),
2272                  in_shape.end());
2273   out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1),
2274                   out_shape.end());
2275   return in_shape == out_shape;
2276 }
2277 
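// Example (illustrative, hypothetical shapes): a Reshape from [1, 4, 8] to
// [4, 8] only drops a unary dimension; with allow_extra_unary_dims == true
// both shapes reduce to {4, 8} once the 1s are removed, so it is reported as
// transpose-equivalent. A Reshape from [4, 8] to [8, 4] is not: the trimmed
// dimension lists differ, so the function returns false.
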
2278 void CheckFinalDataTypesSatisfied(const Model& model) {
2279   for (const auto& array_entry : model.GetArrayMap()) {
2280     const auto& array = *array_entry.second;
2281     if (array.data_type == ArrayDataType::kBool) {
2282       // Boolean values are never quantized.
2283       continue;
2284     }
2285 
2286     // If the final data type is int16, the data type may be float, for example
2287     // after dequantization.
2288     if (array.final_data_type != ArrayDataType::kNone &&
2289         array.final_data_type != ArrayDataType::kInt16) {
2290       CHECK(array.data_type == array.final_data_type)
2291           << "Array \"" << array_entry.first
2292           << "\" has mis-matching actual and final data types (data_type="
2293           << ArrayDataTypeName(array.data_type)
2294           << ", final_data_type=" << ArrayDataTypeName(array.final_data_type)
2295           << ").";
2296     }
2297   }
2298 }
2299 
2300 ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
2301   switch (type) {
2302     case FLOAT:
2303       return ArrayDataType::kFloat;
2304     case UINT8:
2305     case QUANTIZED_UINT8:
2306       return ArrayDataType::kUint8;
2307     case INT8:
2308     case QUANTIZED_INT8:
2309       return ArrayDataType::kInt8;
2310     case INT16:
2311     case QUANTIZED_INT16:
2312       return ArrayDataType::kInt16;
2313     case UINT16:
2314       return ArrayDataType::kUint16;
2315     case INT32:
2316       return ArrayDataType::kInt32;
2317     case UINT32:
2318       return ArrayDataType::kUint32;
2319     case INT64:
2320       return ArrayDataType::kInt64;
2321     case UINT64:
2322       return ArrayDataType::kUint64;
2323     case BOOL:
2324       return ArrayDataType::kBool;
2325     case STRING:
2326       return ArrayDataType::kString;
2327     case COMPLEX64:
2328       return ArrayDataType::kComplex64;
2329     case COMPLEX128:
2330       return ArrayDataType::kComplex128;
2331     case FLOAT16:
2332       return ArrayDataType::kFloat16;
2333     case FLOAT64:
2334       return ArrayDataType::kFloat64;
2335     case RESOURCE:
2336     case VARIANT:
2337     default:
2338       return ArrayDataType::kNone;
2339   }
2340 }
2341 
2342 void FinishBuildingRNNStates(Model* model) {
2343   for (const auto& rnn_state : model->flags.rnn_states()) {
2344     if (!model->HasArray(rnn_state.back_edge_source_array()) ||
2345         !model->HasArray(rnn_state.state_array())) {
2346       CHECK(model->HasArray(rnn_state.back_edge_source_array()));
2347       CHECK(model->HasArray(rnn_state.state_array()));
2348       continue;
2349     }
2350     const auto& src_array = model->GetArray(rnn_state.back_edge_source_array());
2351     auto& dst_array = model->GetArray(rnn_state.state_array());
2352     if (src_array.data_type == ArrayDataType::kNone &&
2353         dst_array.data_type == ArrayDataType::kNone) {
2354       dst_array.data_type = ArrayDataType::kFloat;
2355     }
2356   }
2357 }
2358 
2359 // Returns the array names that match the ArraysExtraInfo's name and
2360 // name_regexp fields. The regexp must match the full array name.
2361 std::unordered_set<std::string> ScanArrayNames(
2362     const Model& model, const toco::ArraysExtraInfo_Entry& entry) {
2363   std::unordered_set<std::string> matches;
2364   if (model.HasArray(entry.name())) {
2365     matches.insert(entry.name());
2366   }
2367   if (!entry.name_regexp().empty()) {
2368     const auto& arrays = model.GetArrayMap();
2369     const RE2 name_regexp = {entry.name_regexp()};
2370     for (auto it = arrays.begin(); it != arrays.end(); ++it) {
2371       if (RE2::FullMatch(it->first, name_regexp)) {
2372         matches.insert(it->first);
2373       }
2374     }
2375   }
2376   return matches;
2377 }
2378 
2379 void UseArraysExtraInfo(Model* model, bool quantize_output) {
2380   for (const auto& entry : model->flags.arrays_extra_info().entries()) {
2381     const auto matches = ScanArrayNames(*model, entry);
2382     if (matches.empty()) {
2383       LOG(ERROR) << "arrays_extra_info_file: No matching arrays found for "
2384                  << (entry.has_name() ? entry.name() : "")
2385                  << (entry.has_name_regexp() ? entry.name_regexp() : "");
2386       continue;
2387     }
2388     for (const auto& matched_name : matches) {
2389       auto& array = model->GetArray(matched_name);
2390       if (entry.has_min() || entry.has_max()) {
2391         CHECK_EQ(entry.has_min(), entry.has_max());
2392         auto& minmax = array.GetOrCreateMinMax();
2393         minmax.min = entry.min();
2394         minmax.max = entry.max();
2395       }
2396       if (entry.has_data_type() && quantize_output) {
2397         array.final_data_type =
2398             ConvertIODataTypeToArrayDataType(entry.data_type());
2399       }
2400       if (entry.has_shape()) {
2401         array.clear_shape();
2402         // Make sure to create the shape even if there are no dims, to
2403         // correctly record 0-D shapes.
2404         array.mutable_shape();
2405         for (const auto& dim : entry.shape().dims()) {
2406           array.mutable_shape()->mutable_dims()->push_back(dim);
2407         }
2408       }
2409       if (entry.has_constant_float_value()) {
2410         CHECK(array.has_shape());
2411         if (array.data_type == ArrayDataType::kFloat) {
2412           auto& data = array.GetMutableBuffer<ArrayDataType::kFloat>().data;
2413           data.resize(RequiredBufferSizeForShape(array.shape()));
2414           for (float& f : data) {
2415             f = entry.constant_float_value();
2416           }
2417         }
2418       }
2419     }
2420   }
2421 }
2422 
2423 void UndoWeightsShuffling(Model* model) {
2424   for (const auto& op : model->operators) {
2425     if (op->type != toco::OperatorType::kFullyConnected) {
2426       continue;
2427     }
2428     const auto& fc_op = static_cast<toco::FullyConnectedOperator&>(*op);
2429     if (fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) {
2430       continue;
2431     }
2432     const std::string& weights_name = fc_op.inputs[1];
2433     QCHECK_EQ(CountOpsWithInput(*model, weights_name), 1);
2434     auto& weights_array = model->GetArray(weights_name);
2435     QCHECK(weights_array.data_type == ArrayDataType::kUint8);
2436     auto& weights_data =
2437         weights_array.GetMutableBuffer<toco::ArrayDataType::kUint8>().data;
2438     const auto& weights_shape = weights_array.shape();
2439     QCHECK_EQ(weights_shape.dimensions_count(), 2);
2440     const int rows = weights_shape.dims(0);
2441     const int cols = weights_shape.dims(1);
2442     QCHECK_EQ(rows % 4, 0);
2443     QCHECK_EQ(cols % 16, 0);
2444     CHECK_EQ(rows * cols, weights_data.size());
2445     // Compute the de-shuffled weights
2446     std::vector<uint8> deshuffled_data(weights_data.size());
2447     uint8* shuffled_data_ptr = weights_data.data();
2448     for (int r = 0; r < rows; r += 4) {
2449       for (int c = 0; c < cols; c += 16) {
2450         for (int i = 0; i < 4; i++) {
2451           uint8* deshuffled_data_ptr =
2452               deshuffled_data.data() + (r + i) * cols + c;
2453           for (int j = 0; j < 16; j++) {
2454             uint8 shuffled_val = *shuffled_data_ptr++;
2455             // Deshuffling isn't only about restoring the storage layout;
2456             // it also undoes the flipping of the sign bit that was applied
2457             // to the shuffled weights.
2458             uint8 deshuffled_val = shuffled_val ^ 0x80;
2459             *deshuffled_data_ptr++ = deshuffled_val;
2460           }
2461         }
2462       }
2463     }
2464     CHECK_EQ(shuffled_data_ptr, weights_data.data() + rows * cols);
2465     // Switch this FC op to using the deshuffled weights.
2466     weights_data = std::move(deshuffled_data);
2467   }
2468 }
2469 
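// Worked example (illustrative) of the inner de-shuffling step above: a stored
// shuffled byte 0x85 becomes 0x85 ^ 0x80 == 0x05, i.e. the sign-bit flip that
// was applied when the weights were shuffled is undone in addition to
// restoring the original row-major storage order.
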
2470 void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst) {
2471   if (src.minmax) {
2472     dst->GetOrCreateMinMax() = src.GetMinMax();
2473   }
2474   if (src.quantization_params) {
2475     dst->GetOrCreateQuantizationParams() = src.GetQuantizationParams();
2476   }
2477   dst->narrow_range = src.narrow_range;
2478 }
2479 
2480 }  // namespace toco
2481