/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_MODEL_LOADER_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_MODEL_LOADER_H_

#include <stddef.h>
#include <unistd.h>

#include <cstdint>
#include <cstdlib>
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h"
#include "tensorflow/lite/model_builder.h"

namespace tflite {
namespace acceleration {

// Class that loads the Model.
class ModelLoader {
 public:
  virtual ~ModelLoader() {}

  // Returns whether the model was loaded successfully.
  virtual MinibenchmarkStatus Init();

  const FlatBufferModel* GetModel() const { return model_.get(); }

 protected:
  // ModelLoader() = default;

  // Interface for subclasses to create model_. If it fails, Init() returns
  // the error status; if it succeeds but model_ is null, Init() returns
  // ModelBuildFailed.
  virtual MinibenchmarkStatus InitInternal() = 0;

  std::unique_ptr<FlatBufferModel> model_;
};

// Loads the Model from a file path.
class PathModelLoader : public ModelLoader {
 public:
  explicit PathModelLoader(absl::string_view model_path)
      : ModelLoader(), model_path_(model_path) {}

 protected:
  MinibenchmarkStatus InitInternal() override;

 private:
  const std::string model_path_;
};

#ifndef _WIN32
// Loads the Model from a file descriptor. This class is not available on
// Windows.
class MmapModelLoader : public ModelLoader {
 public:
  // Creates the model loader from a file descriptor. model_fd only has to be
  // valid for the duration of the constructor (it is dup()'ed inside).
  MmapModelLoader(int model_fd, size_t model_offset, size_t model_size)
      : ModelLoader(),
        model_fd_(dup(model_fd)),
        model_offset_(model_offset),
        model_size_(model_size) {}

  ~MmapModelLoader() override {
    if (model_fd_ >= 0) {
      close(model_fd_);
    }
  }

 protected:
  MinibenchmarkStatus InitInternal() override;

 private:
  const int model_fd_ = -1;
  const size_t model_offset_ = 0;
  const size_t model_size_ = 0;
};
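// Usage sketch for MmapModelLoader (illustrative only, not part of this API).
// Assumptions: the path "model.tflite", the bail-out error handling, and the
// extra includes <fcntl.h> and <sys/stat.h> for open()/fstat();
// kMinibenchmarkSuccess comes from status_codes.h.
//
//   int fd = open("model.tflite", O_RDONLY);
//   if (fd < 0) return;
//   struct stat stat_buf;
//   if (fstat(fd, &stat_buf) != 0) return;
//   MmapModelLoader loader(fd, /*model_offset=*/0,
//                          /*model_size=*/stat_buf.st_size);
//   close(fd);  // Safe: the constructor dup()'ed the descriptor.
//   if (loader.Init() == kMinibenchmarkSuccess) {
//     const FlatBufferModel* model = loader.GetModel();
//   }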
// Loads the Model from a pipe file descriptor.
// IMPORTANT: This class reads the model from a pipe file descriptor, and the
// caller must ensure that the pipe is written to from a different process or
// thread than the one reading from it. The read may block when both ends run
// in the same process / thread.
class PipeModelLoader : public ModelLoader {
 public:
  PipeModelLoader(int pipe_fd, size_t model_size)
      : ModelLoader(), pipe_fd_(pipe_fd), model_size_(model_size) {}

  // Move only.
  PipeModelLoader(PipeModelLoader&&) = default;
  PipeModelLoader& operator=(PipeModelLoader&&) = default;

  ~PipeModelLoader() override { std::free(model_buffer_); }

 protected:
  // Reads the serialized Model from pipe_fd_. Returns ModelReadFailed if
  // fewer than model_size_ bytes are read. This function also closes
  // pipe_fd_.
  MinibenchmarkStatus InitInternal() override;

 private:
  const int pipe_fd_ = -1;
  const size_t model_size_ = 0;
  uint8_t* model_buffer_ = nullptr;
};

#endif  // !_WIN32

// Creates the model loader from a string path. The path can be one of the
// following:
// 1) File descriptor path: the path must be in the format
//    "fd:%model_fd%:%model_offset%:%model_size%". Returns null if the path
//    cannot be parsed.
// 2) Pipe descriptor path: the path must be in the format
//    "pipe:%read_pipe%:%write_pipe%:%model_size%". This function also closes
//    write_pipe when write_pipe >= 0, so it should be called from the read
//    thread / process. Returns null if the path cannot be parsed.
// 3) File path: always returns a PathModelLoader.
// NOTE: This helper function is designed for creating the ModelLoader from
// command line parameters. Prefer to use the ModelLoader constructors directly
// when possible.
std::unique_ptr<ModelLoader> CreateModelLoaderFromPath(absl::string_view path);

}  // namespace acceleration
}  // namespace tflite

#endif  // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_MODEL_LOADER_H_
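// Usage sketch for CreateModelLoaderFromPath (illustrative only; the concrete
// file path and the fd/offset/size values below are assumptions, and
// kMinibenchmarkSuccess is assumed from status_codes.h):
//
//   using tflite::acceleration::CreateModelLoaderFromPath;
//
//   // 1) From a plain file path; always yields a PathModelLoader.
//   std::unique_ptr<tflite::acceleration::ModelLoader> loader =
//       CreateModelLoaderFromPath("/data/local/tmp/model.tflite");
//
//   // 2) From a file descriptor spec: fd 42, offset 0, size 1024 bytes.
//   auto fd_loader = CreateModelLoaderFromPath("fd:42:0:1024");
//
//   if (loader != nullptr &&
//       loader->Init() == tflite::acceleration::kMinibenchmarkSuccess) {
//     const tflite::FlatBufferModel* model = loader->GetModel();
//   }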