xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/optimize/calibration/calibrator.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/optimize/calibration/calibrator.h"
16 
17 #include <fstream>
18 #include <memory>
19 #include <string>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/memory/memory.h"
27 #include "tensorflow/lite/c/common.h"
28 #include "tensorflow/lite/core/api/error_reporter.h"
29 #include "tensorflow/lite/core/api/op_resolver.h"
30 #include "tensorflow/lite/interpreter.h"
31 #include "tensorflow/lite/kernels/kernel_util.h"
32 #include "tensorflow/lite/kernels/register.h"
33 #include "tensorflow/lite/minimal_logging.h"
34 #include "tensorflow/lite/model.h"
35 #include "tensorflow/lite/op_resolver.h"
36 #include "tensorflow/lite/schema/schema_generated.h"
37 #include "tensorflow/lite/schema/schema_utils.h"
38 #include "tensorflow/lite/stderr_reporter.h"
39 #include "tensorflow/lite/string_util.h"
40 #include "tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h"
41 #include "tensorflow/lite/tools/optimize/calibration/calibration_common.h"
42 #include "tensorflow/lite/tools/optimize/calibration/calibration_logger.h"
43 #include "tensorflow/lite/tools/optimize/calibration/calibration_reader.h"
44 #include "tensorflow/lite/tools/optimize/calibration/custom_logging_ops/lstm.h"
45 #include "tensorflow/lite/tools/optimize/calibration/logging_op.h"
46 #include "tensorflow/lite/tools/optimize/calibration/logging_op_resolver.h"
47 
48 namespace tflite {
49 namespace optimize {
50 namespace calibration {
51 
52 namespace {
53 
54 // Calibrator is used to hold information that can be accessed during kernel
55 // invocations.
56 // TfLite kernel invocations are C functions and cannot look at the global
57 // structure of the graph. Calibrator allows the kernel invoke functions to
58 // access the global structure of graph and know which node is currently being
59 // executed. This also allows us to write a simple kernel invoke wrapper
60 // (see LoggingEval) that can work for most builtin ops.
61 class Calibrator {
62  public:
Calibrator(const std::unordered_map<const TfLiteNode *,OperatorInfo> & node_ptr_opinfo_map,std::unique_ptr<LoggingOpResolver> logging_op_resolver,ErrorReporter * error_reporter)63   Calibrator(const std::unordered_map<const TfLiteNode*, OperatorInfo>&
64                  node_ptr_opinfo_map,
65              std::unique_ptr<LoggingOpResolver> logging_op_resolver,
66              ErrorReporter* error_reporter)
67       : node_ptr_opinfo_map_(node_ptr_opinfo_map),
68         logging_op_resolver_(std::move(logging_op_resolver)),
69         error_reporter_(error_reporter) {
70     logger_ = std::make_unique<Logger>();
71   }
72 
73   // Returns the wrapped kernel invoke function |TfLiteRegistration.invoke|.
74   KernelEvalFuncPtr GetKernelInvoke(const TfLiteNode* node) const;
75 
76   // Gets the instance of logger associated with the current context.
GetLogger() const77   Logger* GetLogger() const { return logger_.get(); }
78 
79   // Gets the error reporter.
GetErrorReporter() const80   ErrorReporter* GetErrorReporter() const { return error_reporter_; }
81 
82   // Gets the operator information about the given TfLiteNode.
GetOpInfo(const TfLiteNode * node) const83   const OperatorInfo& GetOpInfo(const TfLiteNode* node) const {
84     return node_ptr_opinfo_map_.at(node);
85   }
86 
GetNodesUnderCalibration()87   std::vector<const TfLiteNode*> GetNodesUnderCalibration() {
88     std::vector<const TfLiteNode*> nodes;
89     nodes.reserve(node_ptr_opinfo_map_.size());
90     for (const auto& entry : node_ptr_opinfo_map_) {
91       nodes.push_back(entry.first);
92     }
93     return nodes;
94   }
95 
96  private:
97   std::unordered_map<const TfLiteNode*, OperatorInfo> node_ptr_opinfo_map_;
98   std::unique_ptr<LoggingOpResolver> logging_op_resolver_;
99   const std::unordered_map<int, OperatorInfo> index_opinfo_;
100   std::unique_ptr<Logger> logger_;
101   ErrorReporter* error_reporter_;
102 };
103 
GetKernelInvoke(const TfLiteNode * node) const104 KernelEvalFuncPtr Calibrator::GetKernelInvoke(const TfLiteNode* node) const {
105   auto op_info = node_ptr_opinfo_map_.at(node);
106   if (op_info.is_custom_op) {
107     return logging_op_resolver_->GetWrappedKernelInvoke(op_info.name.c_str(),
108                                                         op_info.version);
109   }
110   return logging_op_resolver_->GetWrappedKernelInvoke(op_info.builtin_op_code,
111                                                       op_info.version);
112 }
113 
114 // A registry of |Calibrator| objects per |TfLiteContext|.
115 // This global registry is needed to access |Calibrator| objects in the kernel
116 // invoke functions i.e. |TfLiteRegistration.invoke|.
117 // Kernel invoke functions are C functions that have limited access to
118 // |TfLiteContext|. Kernel invoke functions don't have access to global state of
119 // graph. That means during a kernel invocation, the function cannot know which
120 // node it was invoked for. E.g. in case of a model with |Conv| op at two
121 // locations, there is no easy way for the Conv.invoke function to disambiguate
122 // the calls.
123 //
124 // For calibration we solve this problem by creating a map of calibrators
125 // per |TfLiteContext|. This map is |GlobalCalibrationRegistry|.
126 //
127 // This registry is then accessed using a global getter function:
128 // |GetCalibratorRegistry|.
129 // E.g.
130 // TfLiteStatus SomeKernelInvokeFn(TfLiteContext* context, TfLiteNode* node) {
131 //   .... code ....
132 //   auto registry = GetCalibratorRegistry();
133 //   auto calibrator = registry->GetCalibrator(context);
134 //   ..... code ....
135 //  }
136 //
137 // This way the kernel invoke functions can get the access to the Calibrator
138 // object associated with the |TfLiteContext|.
139 class GlobalCalibratorRegistry {
140  public:
141   // Get the |Calibrator| associated with given context, returns null if no
142   // calibrator is associated with the given context.
GetCalibrator(const TfLiteNode * node) const143   Calibrator* GetCalibrator(const TfLiteNode* node) const {
144     if (node_to_calibrator_.find(node) == node_to_calibrator_.cend()) {
145       return nullptr;
146     }
147     return node_to_calibrator_.at(node);
148   }
149 
150   // Removes the association between calibrator and context.
151   // Note: This deletes the calibrator as well.
RemoveCalibrator(const TfLiteContext * context)152   void RemoveCalibrator(const TfLiteContext* context) {
153     Calibrator* calibrator = calibrator_registry_.at(context).get();
154     auto nodes = calibrator->GetNodesUnderCalibration();
155     for (auto node : nodes) {
156       node_to_calibrator_.erase(node);
157     }
158     calibrator_registry_.erase(context);
159   }
160 
161   // Creates an instance of |Calibrator|.
162   // Registry owns the |Calibrator| object which can be deleted by calling
163   // |RemoveCalibrator|.
CreateCalibrator(const TfLiteContext * context,const std::unordered_map<const TfLiteNode *,OperatorInfo> & node_to_opinfo,std::unique_ptr<LoggingOpResolver> logging_op_resolver,Calibrator ** calibrator_ptr,ErrorReporter * reporter)164   TfLiteStatus CreateCalibrator(
165       const TfLiteContext* context,
166       const std::unordered_map<const TfLiteNode*, OperatorInfo>& node_to_opinfo,
167       std::unique_ptr<LoggingOpResolver> logging_op_resolver,
168       Calibrator** calibrator_ptr, ErrorReporter* reporter) {
169     if (calibrator_registry_.find(context) != calibrator_registry_.cend()) {
170       reporter->Report(
171           "Failed to create calibrator, context already registered.");
172       return kTfLiteError;
173     }
174     auto calibrator = std::make_unique<Calibrator>(
175         node_to_opinfo, std::move(logging_op_resolver), reporter);
176     calibrator_registry_[context] = std::move(calibrator);
177     *calibrator_ptr = calibrator_registry_.at(context).get();
178     for (const auto& entry : node_to_opinfo) {
179       node_to_calibrator_[entry.first] = *calibrator_ptr;
180     }
181     return kTfLiteOk;
182   }
183 
184  private:
185   absl::flat_hash_map<const TfLiteContext*, std::unique_ptr<Calibrator>>
186       calibrator_registry_;
187   absl::flat_hash_map<const TfLiteNode*, Calibrator*> node_to_calibrator_;
188 };
189 
GetCalibratorRegistry()190 GlobalCalibratorRegistry* GetCalibratorRegistry() {
191   static GlobalCalibratorRegistry* registry = new GlobalCalibratorRegistry();
192   return registry;
193 }
194 
195 // Get the logging kernel if there are any.
196 // TODO(jianlijianli): extend this to support multiple recipe for the same
197 // model.
GetLoggingEvalFunc(TfLiteContext * context,TfLiteNode * node,int builtin_op_code)198 logging_kernel_func_ptr GetLoggingEvalFunc(TfLiteContext* context,
199                                            TfLiteNode* node,
200                                            int builtin_op_code) {
201   switch (builtin_op_code) {
202     case BuiltinOperator_LSTM: {
203       if (node->intermediates->size == 12) {
204         return tflite::optimize::calibration::custom::lstm_logging_kernel;
205       }
206       return tflite::optimize::calibration::builtin::lstm_logging_kernel;
207     }
208     case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
209       return tflite::optimize::calibration::builtin::
210           unidirectional_sequence_lstm_logging_kernel;
211     default:
212       return nullptr;
213   }
214 }
215 
216 // A wrapper implementation for |TfLiteRegistration.invoke| that logs inputs,
217 // invokes the wrapped implementation and then logs the outputs.
LoggingEval(TfLiteContext * context,TfLiteNode * node)218 TfLiteStatus LoggingEval(TfLiteContext* context, TfLiteNode* node) {
219   Calibrator* calibrator = GetCalibratorRegistry()->GetCalibrator(node);
220 
221   if (!calibrator) {
222     TF_LITE_KERNEL_LOG(context, "No calibrator found for context.");
223     return kTfLiteError;
224   }
225 
226   auto kernel_invoke = calibrator->GetKernelInvoke(node);
227   auto logger = calibrator->GetLogger();
228   auto op_info = calibrator->GetOpInfo(node);
229   auto error_reporter = calibrator->GetErrorReporter();
230 
231   for (int i : op_info.loggable_inputs) {
232     auto tensor = context->tensors[i];
233     TF_LITE_ENSURE_STATUS(
234         logger->LogTensorValue(op_info.subgraph_index, i, tensor.data.f,
235                                tensor.bytes / sizeof(float), error_reporter));
236   }
237   auto builtin_op_code = calibrator->GetOpInfo(node).builtin_op_code;
238   auto kernel_invoke_intermediate =
239       GetLoggingEvalFunc(context, node, builtin_op_code);
240   if (kernel_invoke_intermediate == nullptr) {
241     TF_LITE_ENSURE_STATUS(kernel_invoke(context, node));
242   } else {
243     TF_LITE_ENSURE_STATUS(
244         kernel_invoke_intermediate(context, op_info.subgraph_index, node,
245                                    calibrator->GetLogger(), error_reporter));
246   }
247 
248   // TODO(shashishekhar): An intermediate tensor in graph will get logged twice
249   // once as an input and second time as output. This doesn't change the min max
250   // values but is inefficient.
251   // Using moving average will also break this.
252 
253   // Log input again to make sure the state tensors are captured after lstm
254   // cell.
255   for (int i : op_info.loggable_inputs) {
256     auto tensor = context->tensors[i];
257     TF_LITE_ENSURE_STATUS(
258         logger->LogTensorValue(op_info.subgraph_index, i, tensor.data.f,
259                                tensor.bytes / sizeof(float), error_reporter));
260   }
261 
262   for (int i : op_info.loggable_outputs) {
263     auto tensor = context->tensors[i];
264     TF_LITE_ENSURE_STATUS(
265         logger->LogTensorValue(op_info.subgraph_index, i, tensor.data.f,
266                                tensor.bytes / sizeof(float), error_reporter));
267   }
268 
269   return kTfLiteOk;
270 }
271 
272 // Returns the loggable tensors. Not all inputs and outputs need to be logged.
273 // For example, const weight tensors which have buffers associated with them
274 // don't need to be logged.
GetLoggableTensorIndices(const std::vector<int> & tensor_indices,const flatbuffers::Vector<flatbuffers::Offset<Tensor>> * tensors,const flatbuffers::Vector<flatbuffers::Offset<Buffer>> * tensor_buffers)275 std::vector<int> GetLoggableTensorIndices(
276     const std::vector<int>& tensor_indices,
277     const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
278     const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* tensor_buffers) {
279   std::vector<int> loggable;
280   for (auto tensor_index : tensor_indices) {
281     if (tensor_index == kTfLiteOptionalTensor) {
282       continue;
283     }
284     auto tensor = tensors->Get(tensor_index);
285     auto buffer_index = tensor->buffer();
286     const bool has_no_buffer =
287         (tensor_buffers->Get(buffer_index) == nullptr) ||
288         (tensor_buffers->Get(buffer_index)->data() == nullptr) ||
289         (tensor_buffers->Get(buffer_index)->data()->size() == 0);
290     if (has_no_buffer && tensor->type() == tflite::TensorType_FLOAT32) {
291       loggable.push_back(tensor_index);
292     }
293   }
294   return loggable;
295 }
296 
297 // Creates a mapping between the static model graph and the runtime TfLiteNode*
298 // nodes in the graph for the given context.
299 // This is done by querying the TfLiteContext for node and registrations using
300 // the |NodeInfoDelegateObserver|.
GetNodeOpInfoMapAndContext(const absl::flat_hash_map<std::tuple<int,int>,OperatorInfo> & node_to_opinfo,tflite::Interpreter * const interpreter,std::unordered_map<const TfLiteNode *,OperatorInfo> * node_ptr_opinfo_map,TfLiteContext ** context)301 TfLiteStatus GetNodeOpInfoMapAndContext(
302     const absl::flat_hash_map<std::tuple<int, int>, OperatorInfo>&
303         node_to_opinfo,
304     tflite::Interpreter* const interpreter,
305     std::unordered_map<const TfLiteNode*, OperatorInfo>* node_ptr_opinfo_map,
306     TfLiteContext** context) {
307   *context = interpreter->primary_subgraph().context();
308 
309   // Since we only consider the primary subgraph while populating
310   // node_to_opinfo, do the same here.
311   // Because Flex delegate can merge multiple op nodes into one Delegate node if
312   // they are located in a row, the size of the execution plan can be lesser
313   // than the size of the graph's op nodes.
314   TF_LITE_ENSURE(*context,
315                  interpreter->execution_plan().size() <= node_to_opinfo.size());
316   for (const auto& entry : node_to_opinfo) {
317     auto op_info = entry.second;
318     int subgraph_index, op_index;
319     std::tie(subgraph_index, op_index) = entry.first;
320     const auto* node_and_reg =
321         interpreter->node_and_registration(subgraph_index, op_index);
322     op_info.registration = &node_and_reg->second;
323     node_ptr_opinfo_map->insert({&node_and_reg->first, op_info});
324   }
325   return kTfLiteOk;
326 }
327 
GetOpName(const tflite::OperatorCode & opcode)328 string GetOpName(const tflite::OperatorCode& opcode) {
329   if (opcode.custom_code() != nullptr) {
330     return opcode.custom_code()->str();
331   }
332   return tflite::EnumNamesBuiltinOperator()[GetBuiltinCode(&opcode)];
333 }
334 
335 // A |CalibrationReader| that owns the Calibrator.
336 class Reader : public CalibrationReader {
337  public:
Reader(const TfLiteContext * context,const Logger * logger)338   Reader(const TfLiteContext* context, const Logger* logger)
339       : CalibrationReader(logger), context_(context) {}
340 
~Reader()341   ~Reader() override { GetCalibratorRegistry()->RemoveCalibrator(context_); }
342 
343  private:
344   const TfLiteContext* context_;
345 };
346 
HasInputs(BuiltinOperator code)347 bool HasInputs(BuiltinOperator code) {
348   switch (code) {
349     case BuiltinOperator_CALL_ONCE:
350     case BuiltinOperator_VAR_HANDLE:
351     // Custom ops, including Flex ops, might not have inputs.
352     case BuiltinOperator_CUSTOM:
353       return false;
354     default:
355       return true;
356   }
357 }
358 
HasOutputs(BuiltinOperator code)359 bool HasOutputs(BuiltinOperator code) {
360   switch (code) {
361     case BuiltinOperator_ASSIGN_VARIABLE:
362     case BuiltinOperator_CALL_ONCE:
363     // Custom ops, including Flex ops, might not have outputs.
364     case BuiltinOperator_CUSTOM:
365       return false;
366     default:
367       return true;
368   }
369 }
370 
371 }  // namespace
372 
BuildLoggingInterpreter(const FlatBufferModel & model,const OpResolver & op_resolver,std::unique_ptr<Interpreter> * interpreter,std::unique_ptr<CalibrationReader> * calibration_reader)373 TfLiteStatus BuildLoggingInterpreter(
374     const FlatBufferModel& model, const OpResolver& op_resolver,
375     std::unique_ptr<Interpreter>* interpreter,
376     std::unique_ptr<CalibrationReader>* calibration_reader) {
377   return BuildLoggingInterpreter(model.GetModel(), model.error_reporter(),
378                                  op_resolver, interpreter, calibration_reader);
379 }
380 
BuildLoggingInterpreter(const tflite::Model * tflite_model,ErrorReporter * error_reporter,const OpResolver & op_resolver,std::unique_ptr<Interpreter> * interpreter,std::unique_ptr<CalibrationReader> * calibration_reader)381 TfLiteStatus BuildLoggingInterpreter(
382     const tflite::Model* tflite_model, ErrorReporter* error_reporter,
383     const OpResolver& op_resolver, std::unique_ptr<Interpreter>* interpreter,
384     std::unique_ptr<CalibrationReader>* calibration_reader) {
385   if (error_reporter == nullptr) {
386     // Make sure error_reporter is valid.
387     error_reporter = DefaultErrorReporter();
388   }
389   auto subgraphs = tflite_model->subgraphs();
390   auto tensor_buffers = tflite_model->buffers();
391 
392   // Populate the node index to operator info map.
393   // We want to collect this information so we can use it during runtime to
394   // log details of which inputs and outputs.
395   // At runtime TFLite kernel invoke functions can only look into their
396   // own node in the graph (TFLiteNode*) and some limited context information.
397   absl::flat_hash_map<std::tuple<int, int>, OperatorInfo> node_to_opinfo;
398   BuiltinOpsSet builtin_op_and_versions;
399   CustomOpsSet custom_op_and_versions;
400 
401   for (size_t subgraph_index = 0; subgraph_index < subgraphs->size();
402        subgraph_index++) {
403     auto subgraph = subgraphs->Get(subgraph_index);
404     auto operator_codes = tflite_model->operator_codes();
405     auto operators = subgraph->operators();
406     auto tensors = subgraph->tensors();
407     if (!operators) {
408       continue;
409     }
410 
411     for (size_t i = 0; i < operators->size(); i++) {
412       OperatorInfo op_info;
413       op_info.subgraph_index = subgraph_index;
414       op_info.node_index = i;
415       auto op = operators->Get(i);
416       auto operator_code = operator_codes->Get(op->opcode_index());
417       op_info.builtin_op_code = GetBuiltinCode(operator_code);
418       op_info.name = GetOpName(*operator_code);
419       op_info.is_custom_op = operator_code->custom_code() != nullptr;
420       op_info.version = operator_code->version();
421 
422       auto op_inputs = op->inputs();
423       auto op_outputs = op->outputs();
424       if (op_inputs) {
425         op_info.inputs = std::vector<int>(op_inputs->begin(), op_inputs->end());
426       } else if (HasInputs(op_info.builtin_op_code)) {
427         TFLITE_LOG(TFLITE_LOG_WARNING, "Op %s missing inputs",
428                    op_info.name.c_str());
429       }
430       if (op_outputs) {
431         op_info.outputs =
432             std::vector<int>(op_outputs->begin(), op_outputs->end());
433       } else if (HasOutputs(op_info.builtin_op_code)) {
434         TFLITE_LOG(TFLITE_LOG_WARNING, "Op %s missing outputs",
435                    op_info.name.c_str());
436       }
437       op_info.loggable_inputs =
438           GetLoggableTensorIndices(op_info.inputs, tensors, tensor_buffers);
439       op_info.loggable_outputs =
440           GetLoggableTensorIndices(op_info.outputs, tensors, tensor_buffers);
441       if (op_info.is_custom_op) {
442         op_info.registration =
443             op_resolver.FindOp(op_info.name.c_str(), operator_code->version());
444         custom_op_and_versions.insert(
445             {op_info.name.c_str(), operator_code->version()});
446       } else {
447         op_info.registration = op_resolver.FindOp(GetBuiltinCode(operator_code),
448                                                   operator_code->version());
449         builtin_op_and_versions.insert(
450             {op_info.builtin_op_code, operator_code->version()});
451       }
452       std::tuple<int, int> key{subgraph_index, i};
453       node_to_opinfo[key] = op_info;
454     }
455   }
456 
457   // Prepare the logging op resolver to use |LoggingEval| for kernel
458   // invocations.
459   auto logging_op_resolver = std::make_unique<LoggingOpResolver>(
460       builtin_op_and_versions, custom_op_and_versions, op_resolver, LoggingEval,
461       error_reporter);
462   tflite::InterpreterBuilder(tflite_model, *logging_op_resolver,
463                              error_reporter)(interpreter);
464 
465   if (!(*interpreter)) {
466     error_reporter->Report("Failed to construct interpreter");
467     return kTfLiteError;
468   }
469 
470   // Compute the mapping between runtime and static graph structure, i.e.
471   // (TfLiteContext, TfLiteNode) -> OperatorInfo
472   std::unordered_map<const TfLiteNode*, OperatorInfo> node_ptr_opinfo_map;
473   TfLiteContext* context = nullptr;
474   TF_LITE_ENSURE_STATUS(GetNodeOpInfoMapAndContext(
475       node_to_opinfo, interpreter->get(), &node_ptr_opinfo_map, &context));
476 
477   Calibrator* calibrator = nullptr;
478   // Register a calibrator object for the context. This can be accessed
479   // during invocations by the logging kernels.
480   TF_LITE_ENSURE_STATUS(GetCalibratorRegistry()->CreateCalibrator(
481       context, node_ptr_opinfo_map, std::move(logging_op_resolver), &calibrator,
482       error_reporter));
483   *calibration_reader = std::unique_ptr<CalibrationReader>(
484       new Reader(context, calibrator->GetLogger()));
485 
486   return kTfLiteOk;
487 }
488 
489 }  // namespace calibration
490 }  // namespace optimize
491 }  // namespace tflite
492