1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7 http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_MODEL_LOADER_H_
16 #define TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_MODEL_LOADER_H_
17 
#include <stddef.h>
#ifndef _WIN32
#include <unistd.h>
#endif  // !_WIN32

#include <cstdint>
#include <cstdlib>
#include <memory>
#include <string>
#include <vector>
25 
26 #include "absl/strings/string_view.h"
27 #include "tensorflow/lite/experimental/acceleration/mini_benchmark/status_codes.h"
28 #include "tensorflow/lite/model_builder.h"
29 
30 namespace tflite {
31 namespace acceleration {
32 
33 // Class to load the Model.
34 class ModelLoader {
35  public:
~ModelLoader()36   virtual ~ModelLoader() {}
37 
38   // Return whether the model is loaded successfully.
39   virtual MinibenchmarkStatus Init();
40 
GetModel()41   const FlatBufferModel* GetModel() const { return model_.get(); }
42 
43  protected:
44   // ModelLoader() = default;
45 
46   // Interface for subclass to create model_. If failed, Init() will return the
47   // error status; If succeeded but model_ is null, Init() function will return
48   // ModelBuildFailed.
49   virtual MinibenchmarkStatus InitInternal() = 0;
50 
51   std::unique_ptr<FlatBufferModel> model_;
52 };
53 
54 // Load the Model from a file path.
55 class PathModelLoader : public ModelLoader {
56  public:
PathModelLoader(absl::string_view model_path)57   explicit PathModelLoader(absl::string_view model_path)
58       : ModelLoader(), model_path_(model_path) {}
59 
60  protected:
61   MinibenchmarkStatus InitInternal() override;
62 
63  private:
64   const std::string model_path_;
65 };
66 
67 #ifndef _WIN32
68 // Load the Model from a file descriptor. This class is not available on
69 // Windows.
70 class MmapModelLoader : public ModelLoader {
71  public:
72   // Create the model loader from file descriptor. The model_fd only has to be
73   // valid for the duration of the constructor (it's dup'ed inside).
MmapModelLoader(int model_fd,size_t model_offset,size_t model_size)74   MmapModelLoader(int model_fd, size_t model_offset, size_t model_size)
75       : ModelLoader(),
76         model_fd_(dup(model_fd)),
77         model_offset_(model_offset),
78         model_size_(model_size) {}
79 
~MmapModelLoader()80   ~MmapModelLoader() override {
81     if (model_fd_ >= 0) {
82       close(model_fd_);
83     }
84   }
85 
86  protected:
87   MinibenchmarkStatus InitInternal() override;
88 
89  private:
90   const int model_fd_ = -1;
91   const size_t model_offset_ = 0;
92   const size_t model_size_ = 0;
93 };
94 
95 // Load the Model from a pipe file descriptor.
96 // IMPORTANT: This class tries to read the model from a pipe file descriptor,
97 // and the caller needs to ensure that this pipe should be read from in a
98 // different process / thread than written to. It may block when running in the
99 // same process / thread.
100 class PipeModelLoader : public ModelLoader {
101  public:
PipeModelLoader(int pipe_fd,size_t model_size)102   PipeModelLoader(int pipe_fd, size_t model_size)
103       : ModelLoader(), pipe_fd_(pipe_fd), model_size_(model_size) {}
104 
105   // Move only.
106   PipeModelLoader(PipeModelLoader&&) = default;
107   PipeModelLoader& operator=(PipeModelLoader&&) = default;
108 
~PipeModelLoader()109   ~PipeModelLoader() override { std::free(model_buffer_); }
110 
111  protected:
112   // Read the serialized Model from read_pipe_fd. Return ModelReadFailed if the
113   // readin bytes is less than read_size. This function also closes the
114   // read_pipe_fd and write_pipe_fd.
115   MinibenchmarkStatus InitInternal() override;
116 
117  private:
118   const int pipe_fd_ = -1;
119   const size_t model_size_ = 0;
120   uint8_t* model_buffer_ = nullptr;
121 };
122 
123 #endif  // !_WIN32
124 
// Creates a model loader from a string path. The caller owns the returned
// loader. The path can take one of the following forms:
// 1) File descriptor path: must be in the format
//    "fd:%model_fd%:%model_offset%:%model_size%". Returns nullptr if the path
//    cannot be parsed. (Presumably backed by MmapModelLoader — confirm.)
// 2) Pipe descriptor path: must be in the format
//    "pipe:%read_pipe%:%write_pipe%:%model_size%". This function also closes
//    write_pipe when write_pipe >= 0, so it should be called on the reading
//    thread / process. Returns nullptr if the path cannot be parsed.
//    (Presumably backed by PipeModelLoader — confirm.)
// 3) Any other path is treated as a file path, and a PathModelLoader is always
//    returned.
// NOTE(review): MmapModelLoader and PipeModelLoader are not compiled on
// Windows; confirm how forms 1) and 2) are handled there.
// NOTE: This helper function is designed for creating the ModelLoader from
// command line parameters. Prefer to use the ModelLoader constructors directly
// when possible.
std::unique_ptr<ModelLoader> CreateModelLoaderFromPath(absl::string_view path);
138 
139 }  // namespace acceleration
140 }  // namespace tflite
141 
142 #endif  // TENSORFLOW_LITE_EXPERIMENTAL_ACCELERATION_MINI_BENCHMARK_MODEL_LOADER_H_
143