/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/evaluation/stages/tflite_inference_stage.h"

#include <cstring>
#include <fstream>
#include <string>
#include <utility>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/profiling/time.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/utils.h"

namespace tflite {
namespace evaluation {
namespace {

// Collects pointers to the interpreter's input and output tensors.
TfLiteModelInfo GetTfliteModelInfo(const Interpreter& interpreter) {
  TfLiteModelInfo model_info;
  for (int i : interpreter.inputs()) {
    model_info.inputs.push_back(interpreter.tensor(i));
  }
  for (int i : interpreter.outputs()) {
    model_info.outputs.push_back(interpreter.tensor(i));
  }
  return model_info;
}

}  // namespace

// Refreshes model_info_ and the cached output buffer pointers. Called whenever
// tensor allocations may have changed (after AllocateTensors or after applying
// a delegate).
void TfliteInferenceStage::UpdateModelInfo() {
  model_info_ = GetTfliteModelInfo(*interpreter_);

  outputs_.clear();
  outputs_.reserve(interpreter_->outputs().size());
  for (int i : interpreter_->outputs()) {
    TfLiteTensor* tensor = interpreter_->tensor(i);
    outputs_.push_back(tensor->data.raw);
  }
}

TfLiteStatus TfliteInferenceStage::ResizeInputs(
    const std::vector<std::vector<int>>& shapes) {
  const std::vector<int>& interpreter_inputs = interpreter_->inputs();
  if (interpreter_inputs.size() != shapes.size()) {
    LOG(ERROR) << "Number of provided shapes (" << shapes.size()
               << ") does not match the number of model inputs ("
               << interpreter_inputs.size() << ")";
    return kTfLiteError;
  }

  for (int j = 0; j < shapes.size(); ++j) {
    int i = interpreter_inputs[j];
    TfLiteTensor* t = interpreter_->tensor(i);
    if (t->type != kTfLiteString) {
      TF_LITE_ENSURE_STATUS(interpreter_->ResizeInputTensor(i, shapes[j]));
    }
  }

  TF_LITE_ENSURE_STATUS(interpreter_->AllocateTensors());
  UpdateModelInfo();
  return kTfLiteOk;
}
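// Example (illustrative sketch, not compiled here): resizing the model's
// inputs before a run. The shape values below are hypothetical and assume a
// single NHWC image input being batched to 4; pass one shape vector per model
// input, in the interpreter's input order.
//
//   TfliteInferenceStage stage(config);
//   if (stage.Init(/*delegate_providers=*/nullptr) != kTfLiteOk) {
//     /* handle error */
//   }
//   if (stage.ResizeInputs({{4, 224, 224, 3}}) != kTfLiteOk) {
//     /* handle error */
//   }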
TfLiteStatus TfliteInferenceStage::ApplyCustomDelegate(
    Interpreter::TfLiteDelegatePtr delegate) {
  if (!interpreter_) {
    LOG(ERROR) << "Stage not initialized before calling ApplyCustomDelegate";
    return kTfLiteError;
  }
  // Skip if the delegate is a nullptr.
  if (!delegate) {
    LOG(WARNING)
        << "Tried to apply null TfLiteDelegatePtr to TfliteInferenceStage";
    return kTfLiteOk;
  }
  delegates_.push_back(std::move(delegate));
  TF_LITE_ENSURE_STATUS(
      interpreter_->ModifyGraphWithDelegate(delegates_.back().get()));
  UpdateModelInfo();
  return kTfLiteOk;
}

TfLiteStatus TfliteInferenceStage::Init(
    const DelegateProviders* delegate_providers) {
  if (!config_.specification().has_tflite_inference_params()) {
    LOG(ERROR) << "TfliteInferenceParams not provided";
    return kTfLiteError;
  }
  auto& params = config_.specification().tflite_inference_params();
  if (!params.has_model_file_path()) {
    LOG(ERROR) << "Model path not provided";
    return kTfLiteError;
  }
  std::ifstream model_check(params.model_file_path());
  if (!model_check.good()) {
    LOG(ERROR) << "Model file not found: " << params.model_file_path();
    return kTfLiteError;
  }

  model_ = FlatBufferModel::BuildFromFile(params.model_file_path().c_str());
  if (!model_) {
    LOG(ERROR) << "Could not build FlatBufferModel from: "
               << params.model_file_path();
    return kTfLiteError;
  }

  bool apply_default_delegates = true;
  if (delegate_providers != nullptr) {
    const auto& provider_params = delegate_providers->GetAllParams();
    // If --use_xnnpack is explicitly set to false, honor it by not letting the
    // TfLite runtime apply the XNNPACK delegate by default.
    if (provider_params.HasParam("use_xnnpack") &&
        provider_params.HasValueSet<bool>("use_xnnpack") &&
        !provider_params.Get<bool>("use_xnnpack")) {
      apply_default_delegates = false;
    }
  }

  resolver_.reset(
      apply_default_delegates
          ? new ops::builtin::BuiltinOpResolver()
          : new ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
  InterpreterBuilder(*model_, *resolver_)(&interpreter_);
  if (!interpreter_) {
    LOG(ERROR) << "Could not build interpreter";
    return kTfLiteError;
  }
  interpreter_->SetNumThreads(params.num_threads());

  if (!delegate_providers) {
    std::string error_message;
    auto delegate = CreateTfLiteDelegate(params, &error_message);
    if (delegate) {
      delegates_.push_back(std::move(delegate));
      LOG(INFO) << "Successfully created "
                << params.Delegate_Name(params.delegate()) << " delegate.";
    } else {
      LOG(WARNING) << error_message;
    }
  } else {
    auto delegates = delegate_providers->CreateAllDelegates(params);
    for (auto& one : delegates) delegates_.push_back(std::move(one.delegate));
  }

  for (int i = 0; i < delegates_.size(); ++i) {
    if (interpreter_->ModifyGraphWithDelegate(delegates_[i].get()) !=
        kTfLiteOk) {
      LOG(FATAL) << "Failed to apply delegate " << i;
    }
  }
  TF_LITE_ENSURE_STATUS(interpreter_->AllocateTensors());
  UpdateModelInfo();

  return kTfLiteOk;
}
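// Example (illustrative sketch, not compiled here): constructing a minimal
// EvaluationStageConfig for this stage. The model path is a placeholder; the
// field names follow TfliteInferenceParams as used above.
//
//   EvaluationStageConfig config;
//   auto* inference_params =
//       config.mutable_specification()->mutable_tflite_inference_params();
//   inference_params->set_model_file_path("/path/to/model.tflite");
//   inference_params->set_num_threads(2);
//   inference_params->set_invocations_per_run(1);
//   TfliteInferenceStage stage(config);
//   if (stage.Init(/*delegate_providers=*/nullptr) != kTfLiteOk) {
//     /* handle error */
//   }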
TfLiteStatus TfliteInferenceStage::Run() {
  if (!inputs_) {
    LOG(ERROR) << "Input data not set";
    return kTfLiteError;
  }

  // Point the interpreter's input tensors at the caller-provided buffers; no
  // copy is made.
  for (int i = 0; i < interpreter_->inputs().size(); ++i) {
    TfLiteTensor* tensor = interpreter_->tensor(interpreter_->inputs()[i]);
    tensor->data.raw = static_cast<char*>(inputs_->at(i));
  }

  // Invoke the interpreter the configured number of times, recording the
  // latency of each invocation.
  auto& params = config_.specification().tflite_inference_params();
  for (int i = 0; i < params.invocations_per_run(); ++i) {
    int64_t start_us = profiling::time::NowMicros();
    if (interpreter_->Invoke() != kTfLiteOk) {
      LOG(ERROR) << "TFLite interpreter failed to invoke at run " << i;
      return kTfLiteError;
    }
    latency_stats_.UpdateStat(profiling::time::NowMicros() - start_us);
  }

  return kTfLiteOk;
}

EvaluationStageMetrics TfliteInferenceStage::LatestMetrics() {
  auto& params = config_.specification().tflite_inference_params();
  EvaluationStageMetrics metrics;
  auto* latency_metrics =
      metrics.mutable_process_metrics()->mutable_total_latency();
  latency_metrics->set_last_us(latency_stats_.newest());
  latency_metrics->set_max_us(latency_stats_.max());
  latency_metrics->set_min_us(latency_stats_.min());
  latency_metrics->set_sum_us(latency_stats_.sum());
  latency_metrics->set_avg_us(latency_stats_.avg());
  latency_metrics->set_std_deviation_us(latency_stats_.std_deviation());
  // num_runs counts calls to Run(); num_inferences counts Invoke() calls
  // (invocations_per_run per Run()).
  metrics.set_num_runs(
      static_cast<int>(latency_stats_.count() / params.invocations_per_run()));
  auto* inference_metrics =
      metrics.mutable_process_metrics()->mutable_tflite_inference_metrics();
  inference_metrics->set_num_inferences(latency_stats_.count());
  return metrics;
}

}  // namespace evaluation
}  // namespace tflite
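// Example (illustrative sketch, not compiled here): driving the stage after
// Init(). SetInputs() stores raw pointers, so the vector and the buffers it
// points to must outlive Run(); GetOutputs() exposes the interpreter's output
// buffers and LatestMetrics() reports the latency statistics accumulated
// above. The input buffer name below is hypothetical.
//
//   std::vector<void*> inputs = {preprocessed_image.data()};
//   stage.SetInputs(inputs);
//   if (stage.Run() != kTfLiteOk) { /* handle error */ }
//   const std::vector<void*>* outputs = stage.GetOutputs();
//   EvaluationStageMetrics metrics = stage.LatestMetrics();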