xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/tflite_wrapper.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef FCP_CLIENT_ENGINE_TFLITE_WRAPPER_H_
17 #define FCP_CLIENT_ENGINE_TFLITE_WRAPPER_H_
18 
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "fcp/client/engine/caching_error_reporter.h"
#include "fcp/client/interruptible_runner.h"
#include "fcp/client/log_manager.h"
#include "fcp/client/simple_task_environment.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model_builder.h"
32 
33 namespace fcp {
34 namespace client {
35 namespace engine {
36 
37 struct OutputTensors {
38   std::vector<std::string> output_tensor_names;
39   std::vector<tensorflow::Tensor> output_tensors;
40 };
41 
42 // Options for TFLite interpreter.
43 struct TfLiteInterpreterOptions {
44   // When true, TFLite uses dynamic tensor allocation and release tensors that
45   // are no longer needed.
46   bool ensure_dynamic_tensors_are_released = false;
47   // When the threshold is zero, dynamic allocation is not enabled for any
48   // tensor.
49   int32_t large_tensor_threshold_for_dynamic_allocation = 0;
50   // Whether to disable the graph-reordering optimization that clusters delegate
51   // ops together.
52   bool disable_delegate_clustering = false;
53 };
54 
// A class to call into TFLite.
// All functions in this interface indicate errors as follows:
// - CANCELLED: interrupted execution
// - INVALID_ARGUMENT:
//    1. Invalid model.
//    2. Initialization failure for required TFLite classes such as the
//    Interpreter, Delegate etc.
//    3. Missing required inputs.
//    4. TensorFlow error. The TensorFlow error messages are included in the
//    Status message.
// This class supports aborting ongoing calls by polling the provided
// should_abort function.
// Parameters:
//    1. model: The serialized TFLite model.
//    2. should_abort: A function which will be polled periodically to determine
//    if the computation should be aborted.
//    3. timing_config: The TimingConfig for an InterruptibleRunner.
//    4. log_manager: A LogManager.
//    5. inputs: A hashmap which has the input tensor name as key and the tensor
//    data as value.
//    6. output_names: The names of the output tensors. The order of these
//    tensor names must be deterministic.
77 class TfLiteWrapper {
78  public:
79   static absl::StatusOr<std::unique_ptr<TfLiteWrapper>> Create(
80       const std::string& model, std::function<bool()> should_abort,
81       const InterruptibleRunner::TimingConfig& timing_config,
82       LogManager* log_manager,
83       std::unique_ptr<absl::flat_hash_map<std::string, std::string>> inputs,
84       std::vector<std::string> output_names,
85       const TfLiteInterpreterOptions& interpreter_options, int32_t num_threads);
86 
87   // Wrapper around TfLite's Interpreter::Invoke method.
88   // If the run succeeds, a vector of output tensors (empty if there's no
89   // output tensors), or CANCELLED if the training run was cancelled or
90   // INVALID_ARGUMENT for the rest of errors.
91   absl::StatusOr<OutputTensors> Run();
92 
93  private:
TfLiteWrapper(std::unique_ptr<tflite::FlatBufferModel> model,std::unique_ptr<CachingErrorReporter> error_reporter,tflite::TfLiteDelegateUniquePtr delegate,std::unique_ptr<tflite::Interpreter> interpreter,std::unique_ptr<InterruptibleRunner> interruptible_runner,std::vector<std::string> output_names)94   TfLiteWrapper(std::unique_ptr<tflite::FlatBufferModel> model,
95                 std::unique_ptr<CachingErrorReporter> error_reporter,
96                 tflite::TfLiteDelegateUniquePtr delegate,
97                 std::unique_ptr<tflite::Interpreter> interpreter,
98                 std::unique_ptr<InterruptibleRunner> interruptible_runner,
99                 std::vector<std::string> output_names)
100       : model_(std::move(model)),
101         error_reporter_(std::move(error_reporter)),
102         delegate_(std::move(delegate)),
103         interpreter_(std::move(interpreter)),
104         interruptible_runner_(std::move(interruptible_runner)),
105         output_names_(std::move(output_names)) {}
106   absl::Status ConvertTfLiteStatus(TfLiteStatus status);
107   absl::StatusOr<OutputTensors> ConstructOutputs();
108 
109   std::unique_ptr<tflite::FlatBufferModel> model_;
110   std::unique_ptr<CachingErrorReporter> error_reporter_;
111   tflite::TfLiteDelegateUniquePtr delegate_;
112   std::unique_ptr<tflite::Interpreter> interpreter_;
113   std::unique_ptr<InterruptibleRunner> interruptible_runner_;
114   const std::vector<std::string> output_names_;
115 };
116 
117 }  // namespace engine
118 }  // namespace client
119 }  // namespace fcp
120 
121 #endif  // FCP_CLIENT_ENGINE_TFLITE_WRAPPER_H_
122