xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h"
16 
17 #include <cstring>
18 #include <fstream>
19 #include <string>
20 #include <utility>
21 
22 #include "tensorflow/core/platform/logging.h"
23 #include "tensorflow/lite/c/common.h"
24 #include "tensorflow/lite/profiling/time.h"
25 #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
26 #include "tensorflow/lite/tools/evaluation/utils.h"
27 
28 namespace tflite {
29 namespace evaluation {
30 namespace {
31 
GetTfliteModelInfo(const Interpreter & interpreter)32 TfLiteModelInfo GetTfliteModelInfo(const Interpreter& interpreter) {
33   TfLiteModelInfo model_info;
34   for (int i : interpreter.inputs()) {
35     model_info.inputs.push_back(interpreter.tensor(i));
36   }
37   for (int i : interpreter.outputs()) {
38     model_info.outputs.push_back(interpreter.tensor(i));
39   }
40   return model_info;
41 }
42 
43 }  // namespace
44 
UpdateModelInfo()45 void TfliteInferenceStage::UpdateModelInfo() {
46   model_info_ = GetTfliteModelInfo(*interpreter_);
47 
48   outputs_.clear();
49   outputs_.reserve(interpreter_->outputs().size());
50   for (int i : interpreter_->outputs()) {
51     TfLiteTensor* tensor = interpreter_->tensor(i);
52     outputs_.push_back(tensor->data.raw);
53   }
54 }
55 
ResizeInputs(const std::vector<std::vector<int>> & shapes)56 TfLiteStatus TfliteInferenceStage::ResizeInputs(
57     const std::vector<std::vector<int>>& shapes) {
58   const std::vector<int>& intepreter_inputs = interpreter_->inputs();
59   if (intepreter_inputs.size() != shapes.size()) {
60     LOG(ERROR) << "New shape is not compatible";
61     return kTfLiteError;
62   }
63 
64   for (int j = 0; j < shapes.size(); ++j) {
65     int i = intepreter_inputs[j];
66     TfLiteTensor* t = interpreter_->tensor(i);
67     if (t->type != kTfLiteString) {
68       TF_LITE_ENSURE_STATUS(interpreter_->ResizeInputTensor(i, shapes[j]));
69     }
70   }
71 
72   TF_LITE_ENSURE_STATUS(interpreter_->AllocateTensors());
73   UpdateModelInfo();
74   return kTfLiteOk;
75 }
76 
ApplyCustomDelegate(Interpreter::TfLiteDelegatePtr delegate)77 TfLiteStatus TfliteInferenceStage::ApplyCustomDelegate(
78     Interpreter::TfLiteDelegatePtr delegate) {
79   if (!interpreter_) {
80     LOG(ERROR) << "Stage not initialized before calling ApplyCustomDelegate";
81     return kTfLiteError;
82   }
83   // Skip if delegate is a nullptr.
84   if (!delegate) {
85     LOG(WARNING)
86         << "Tried to apply null TfLiteDelegatePtr to TfliteInferenceStage";
87     return kTfLiteOk;
88   }
89   delegates_.push_back(std::move(delegate));
90   TF_LITE_ENSURE_STATUS(
91       interpreter_->ModifyGraphWithDelegate(delegates_.back().get()));
92   UpdateModelInfo();
93   return kTfLiteOk;
94 }
95 
Init(const DelegateProviders * delegate_providers)96 TfLiteStatus TfliteInferenceStage::Init(
97     const DelegateProviders* delegate_providers) {
98   if (!config_.specification().has_tflite_inference_params()) {
99     LOG(ERROR) << "TfliteInferenceParams not provided";
100     return kTfLiteError;
101   }
102   auto& params = config_.specification().tflite_inference_params();
103   if (!params.has_model_file_path()) {
104     LOG(ERROR) << "Model path not provided";
105     return kTfLiteError;
106   }
107   std::ifstream model_check(params.model_file_path());
108   if (!model_check.good()) {
109     LOG(ERROR) << "Model file not found";
110     return kTfLiteError;
111   }
112 
113   model_ = FlatBufferModel::BuildFromFile(params.model_file_path().c_str());
114 
115   bool apply_default_delegates = true;
116   if (delegate_providers != nullptr) {
117     const auto& provider_params = delegate_providers->GetAllParams();
118     // When --use_xnnpack is explicitly set to false, to honor this, skip
119     // applying the XNNPACK delegate by default in TfLite runtime.
120     if (provider_params.HasParam("use_xnnpack") &&
121         provider_params.HasValueSet<bool>("use_xnnpack") &&
122         !provider_params.Get<bool>("use_xnnpack")) {
123       apply_default_delegates = false;
124     }
125   }
126 
127   resolver_.reset(
128       apply_default_delegates
129           ? new ops::builtin::BuiltinOpResolver()
130           : new ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
131   InterpreterBuilder(*model_, *resolver_)(&interpreter_);
132   if (!interpreter_) {
133     LOG(ERROR) << "Could not build interpreter";
134     return kTfLiteError;
135   }
136   interpreter_->SetNumThreads(params.num_threads());
137 
138   if (!delegate_providers) {
139     std::string error_message;
140     auto delegate = CreateTfLiteDelegate(params, &error_message);
141     if (delegate) {
142       delegates_.push_back(std::move(delegate));
143       LOG(INFO) << "Successfully created "
144                 << params.Delegate_Name(params.delegate()) << " delegate.";
145     } else {
146       LOG(WARNING) << error_message;
147     }
148   } else {
149     auto delegates = delegate_providers->CreateAllDelegates(params);
150     for (auto& one : delegates) delegates_.push_back(std::move(one.delegate));
151   }
152 
153   for (int i = 0; i < delegates_.size(); ++i) {
154     if (interpreter_->ModifyGraphWithDelegate(delegates_[i].get()) !=
155         kTfLiteOk) {
156       LOG(FATAL) << "Failed to apply delegate " << i;
157     }
158   }
159   interpreter_->AllocateTensors();
160   UpdateModelInfo();
161 
162   return kTfLiteOk;
163 }
164 
Run()165 TfLiteStatus TfliteInferenceStage::Run() {
166   if (!inputs_) {
167     LOG(ERROR) << "Input data not set";
168     return kTfLiteError;
169   }
170 
171   // Copy input data.
172   for (int i = 0; i < interpreter_->inputs().size(); ++i) {
173     TfLiteTensor* tensor = interpreter_->tensor(interpreter_->inputs()[i]);
174     tensor->data.raw = static_cast<char*>(inputs_->at(i));
175   }
176 
177   // Invoke.
178   auto& params = config_.specification().tflite_inference_params();
179   for (int i = 0; i < params.invocations_per_run(); ++i) {
180     int64_t start_us = profiling::time::NowMicros();
181     if (interpreter_->Invoke() != kTfLiteOk) {
182       LOG(ERROR) << "TFLite interpreter failed to invoke at run " << i;
183       return kTfLiteError;
184     }
185     latency_stats_.UpdateStat(profiling::time::NowMicros() - start_us);
186   }
187 
188   return kTfLiteOk;
189 }
190 
LatestMetrics()191 EvaluationStageMetrics TfliteInferenceStage::LatestMetrics() {
192   auto& params = config_.specification().tflite_inference_params();
193   EvaluationStageMetrics metrics;
194   auto* latency_metrics =
195       metrics.mutable_process_metrics()->mutable_total_latency();
196   latency_metrics->set_last_us(latency_stats_.newest());
197   latency_metrics->set_max_us(latency_stats_.max());
198   latency_metrics->set_min_us(latency_stats_.min());
199   latency_metrics->set_sum_us(latency_stats_.sum());
200   latency_metrics->set_avg_us(latency_stats_.avg());
201   latency_metrics->set_std_deviation_us(latency_stats_.std_deviation());
202   metrics.set_num_runs(
203       static_cast<int>(latency_stats_.count() / params.invocations_per_run()));
204   auto* inference_metrics =
205       metrics.mutable_process_metrics()->mutable_tflite_inference_metrics();
206   inference_metrics->set_num_inferences(latency_stats_.count());
207   return metrics;
208 }
209 
210 }  // namespace evaluation
211 }  // namespace tflite
212