1 /* 2 * Copyright 2021 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef FCP_CLIENT_ENGINE_TFLITE_WRAPPER_H_ 17 #define FCP_CLIENT_ENGINE_TFLITE_WRAPPER_H_ 18 19 #include <functional> 20 #include <string> 21 #include <utility> 22 23 #include "absl/status/status.h" 24 #include "absl/status/statusor.h" 25 #include "fcp/client/engine/caching_error_reporter.h" 26 #include "fcp/client/interruptible_runner.h" 27 #include "fcp/client/log_manager.h" 28 #include "fcp/client/simple_task_environment.h" 29 #include "tensorflow/lite/delegates/flex/delegate.h" 30 #include "tensorflow/lite/interpreter.h" 31 #include "tensorflow/lite/model_builder.h" 32 33 namespace fcp { 34 namespace client { 35 namespace engine { 36 37 struct OutputTensors { 38 std::vector<std::string> output_tensor_names; 39 std::vector<tensorflow::Tensor> output_tensors; 40 }; 41 42 // Options for TFLite interpreter. 43 struct TfLiteInterpreterOptions { 44 // When true, TFLite uses dynamic tensor allocation and release tensors that 45 // are no longer needed. 46 bool ensure_dynamic_tensors_are_released = false; 47 // When the threshold is zero, dynamic allocation is not enabled for any 48 // tensor. 49 int32_t large_tensor_threshold_for_dynamic_allocation = 0; 50 // Whether to disable the graph-reordering optimization that clusters delegate 51 // ops together. 52 bool disable_delegate_clustering = false; 53 }; 54 55 // A class to call into TFLite. 56 // All functions in this interface indicate errors as follows: 57 // - CANCELLED: interrupted execution 58 // - INVALID_ARGUMENT: 59 // 1. Invalid model. 60 // 2. Initialization failure for TFLite required classes such as Interpreter, 61 // Delegate etc. 62 // 3. Missing required inputs. 63 // 4. TensorFlow error. The TensorFlow error messages are included in the 64 // Status message. 65 // This class supports aborting ongoing calls, by polling the provided 66 // should_abort function. 67 // Parameters: 68 // 1. model: The serialized TFLite model. 69 // 2. should_abort: A function which will be polled periodically to determine 70 // if the computation should be aborted. 71 // 3. timing_config: The TimingConfig for an InterruptibleRunner. 72 // 4. log_manager: A LogManager. 73 // 5. inputs: A hashmap which has input tensor name as key, tensor data as 74 // value. 75 // 6. output_names: The names of the output tensors. The order for these 76 // tensor names must be deterministic. 77 class TfLiteWrapper { 78 public: 79 static absl::StatusOr<std::unique_ptr<TfLiteWrapper>> Create( 80 const std::string& model, std::function<bool()> should_abort, 81 const InterruptibleRunner::TimingConfig& timing_config, 82 LogManager* log_manager, 83 std::unique_ptr<absl::flat_hash_map<std::string, std::string>> inputs, 84 std::vector<std::string> output_names, 85 const TfLiteInterpreterOptions& interpreter_options, int32_t num_threads); 86 87 // Wrapper around TfLite's Interpreter::Invoke method. 88 // If the run succeeds, a vector of output tensors (empty if there's no 89 // output tensors), or CANCELLED if the training run was cancelled or 90 // INVALID_ARGUMENT for the rest of errors. 91 absl::StatusOr<OutputTensors> Run(); 92 93 private: TfLiteWrapper(std::unique_ptr<tflite::FlatBufferModel> model,std::unique_ptr<CachingErrorReporter> error_reporter,tflite::TfLiteDelegateUniquePtr delegate,std::unique_ptr<tflite::Interpreter> interpreter,std::unique_ptr<InterruptibleRunner> interruptible_runner,std::vector<std::string> output_names)94 TfLiteWrapper(std::unique_ptr<tflite::FlatBufferModel> model, 95 std::unique_ptr<CachingErrorReporter> error_reporter, 96 tflite::TfLiteDelegateUniquePtr delegate, 97 std::unique_ptr<tflite::Interpreter> interpreter, 98 std::unique_ptr<InterruptibleRunner> interruptible_runner, 99 std::vector<std::string> output_names) 100 : model_(std::move(model)), 101 error_reporter_(std::move(error_reporter)), 102 delegate_(std::move(delegate)), 103 interpreter_(std::move(interpreter)), 104 interruptible_runner_(std::move(interruptible_runner)), 105 output_names_(std::move(output_names)) {} 106 absl::Status ConvertTfLiteStatus(TfLiteStatus status); 107 absl::StatusOr<OutputTensors> ConstructOutputs(); 108 109 std::unique_ptr<tflite::FlatBufferModel> model_; 110 std::unique_ptr<CachingErrorReporter> error_reporter_; 111 tflite::TfLiteDelegateUniquePtr delegate_; 112 std::unique_ptr<tflite::Interpreter> interpreter_; 113 std::unique_ptr<InterruptibleRunner> interruptible_runner_; 114 const std::vector<std::string> output_names_; 115 }; 116 117 } // namespace engine 118 } // namespace client 119 } // namespace fcp 120 121 #endif // FCP_CLIENT_ENGINE_TFLITE_WRAPPER_H_ 122