xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tfrt/saved_model/saved_model.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_H_
16 #define TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_H_
17 
18 #include <functional>
19 #include <limits>
20 #include <memory>
21 #include <optional>
22 #include <string>
23 #include <unordered_set>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/strings/string_view.h"
29 #include "absl/types/span.h"
30 #include "tensorflow/core/framework/graph.pb.h"
31 #include "tensorflow/core/platform/thread_annotations.h"
32 #include "tensorflow/core/protobuf/meta_graph.pb.h"
33 #include "tensorflow/core/tfrt/fallback/fallback_state.h"
34 #include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
35 #include "tensorflow/core/tfrt/graph_executor/graph_executor.h"
36 #include "tensorflow/core/tfrt/runtime/runtime.h"
37 #include "tensorflow/core/tfrt/tpu/tpu_resources.h"  // NOLINT(unused-includes): For tfrt::tpu::TpuModelResource
38 #include "tfrt/host_context/function.h"  // from @tf_runtime
39 #include "tfrt/host_context/request_deadline_tracker.h"  // from @tf_runtime
40 #include "tfrt/host_context/resource_context.h"  // from @tf_runtime
41 
42 namespace tfrt {
43 class BEFFile;
44 class HostContext;
45 }  // namespace tfrt
46 
47 namespace tensorflow {
48 namespace tfrt_stub {
49 
50 // TODO(tfrt-dev): Replace tfrt::TensorSpec with tensorflow::TensorSpec once the
51 // latter is checked in.
52 struct TensorSpec {
53   tensorflow::DataType dtype;
54   tensorflow::PartialTensorShape shape;
55 
TensorSpecTensorSpec56   explicit TensorSpec(tensorflow::DataType dtype) : dtype(dtype) {}
TensorSpecTensorSpec57   TensorSpec(tensorflow::DataType dtype, tensorflow::PartialTensorShape shape)
58       : dtype(dtype), shape(std::move(shape)) {}
59 };
60 
61 inline bool operator==(const TensorSpec& a, const TensorSpec& b) {
62   return a.dtype == b.dtype && a.shape.IsIdenticalTo(b.shape);
63 }
64 
65 namespace internal {
66 
67 struct Signature {
68   std::vector<tensorflow::Tensor> captures;
69 
70   // The following three fields should have the same size.
71   std::vector<std::string> input_names;
72   std::vector<TensorSpec> input_specs;
73   std::vector<std::string> input_devices;
74 
75   // The following two fields should have the same size.
76   std::vector<std::string> output_names;
77   std::vector<TensorSpec> output_specs;
78 };
79 
80 }  // namespace internal
81 
82 class FunctionMetadata {
83  public:
FunctionMetadata(const internal::Signature * signature)84   explicit FunctionMetadata(const internal::Signature* signature)
85       : signature_(signature) {
86     assert(signature);
87   }
88 
GetInputNames()89   const std::vector<std::string>& GetInputNames() const {
90     return signature_->input_names;
91   }
92 
GetInputSpecs()93   const std::vector<TensorSpec>& GetInputSpecs() const {
94     return signature_->input_specs;
95   }
96 
GetOutputNames()97   const std::vector<std::string>& GetOutputNames() const {
98     return signature_->output_names;
99   }
100 
GetOutputSpecs()101   const std::vector<TensorSpec>& GetOutputSpecs() const {
102     return signature_->output_specs;
103   }
104 
105  private:
106   friend class SavedModelImpl;
107 
108   const internal::Signature* signature_ = nullptr;
109 };
110 
111 // SavedModel represents the in-memory states (graphs and variables) loaded from
112 // a tensorflow saved model directory.
113 class SavedModel {
114  public:
115   struct Options {
OptionsOptions116     explicit Options(const Runtime* rt) : graph_execution_options(rt) {}
117 
118     // If the number of signagures is greater than the threshold, the loading of
119     // any signature (or signature combination) will be deferred until the first
120     // corresponding invocationof running. Otherwise, the individual signatures
121     // will be loaded along with the saved model.
122     int32_t lazy_loading_threshold = std::numeric_limits<int32_t>::max();
123 
124     // If true, we'll attempt to find MLArchive within the given loading path.
125     // If not found, will use the path as a normal SavedModel directory.
126     bool maybe_load_from_mla = false;
127 
128     GraphExecutionOptions graph_execution_options;
129   };
130 
131   // Per-request options.
132   using RunOptions = GraphExecutionRunOptions;
133 
SavedModel(const Runtime * runtime)134   explicit SavedModel(const Runtime* runtime) : runtime_(runtime) {
135     DCHECK(runtime_);
136   }
137   virtual ~SavedModel();
138 
runtime()139   const Runtime& runtime() const {
140     DCHECK(runtime_);
141     return *runtime_;
142   }
143   tfrt::HostContext* GetHostContext() const;
144 
145   // Returns meta graph def. Note that the graph_def field in the MetaGraphDef
146   // has already been removed.
147   //
148   // TODO(b/191931702): Change the method to return SignatureDefs instead.
149   virtual const tensorflow::MetaGraphDef& GetMetaGraphDef() const = 0;
150 
151   // Returns all the function names.
152   virtual std::vector<std::string> GetFunctionNames() const = 0;
153 
154   // Returns the `FunctionMetadata` for a function. If the function is not
155   // found, returns nullopt instead.
156   virtual std::optional<FunctionMetadata> GetFunctionMetadata(
157       absl::string_view func_name) const = 0;
158 
159   // Runs the signature specified by `name`. Both `inputs` and `outputs`
160   // are all host tensors. The `outputs` must be non-null. If the returned
161   // status is non-OK, the `outputs` are invalid.
162   virtual tensorflow::Status Run(const RunOptions& run_options,
163                                  absl::string_view name,
164                                  absl::Span<const tensorflow::Tensor> inputs,
165                                  std::vector<tensorflow::Tensor>* outputs) = 0;
166 
167   // Runs the signatures specified by `names`. Both `inputs` and `outputs` are
168   // all host tensors. The `outputs` must be non-null. If the returned status is
169   // non-OK, the `outputs` are invalid.
170   //
171   // NOTE: If the given signatures have overlapping input nodes, the input
172   // tensors for these overlapping nodes must be the same. Having different
173   // input tensors for overlapping nodes results UNDEFINED BEHAVIOR.
174   //
175   // NOTE: The input/output tensors can only be dense tensors (as opposed to
176   // sparse tensors or composite tensors).
177   virtual tensorflow::Status RunMultipleSignatures(
178       const RunOptions& run_options, absl::Span<const std::string> names,
179       absl::Span<const std::vector<tensorflow::Tensor>> multi_inputs,
180       std::vector<std::vector<tensorflow::Tensor>>* multi_outputs) = 0;
181 
182   // Runs the graphs specified by the tensor names terminal tensors (eg. feed
183   // tensors, fetch tesnors) in the graph.
184   virtual tensorflow::Status RunByTensorNames(
185       const RunOptions& run_options,
186       absl::Span<const std::pair<std::string, tensorflow::Tensor>> inputs,
187       absl::Span<const std::string> output_tensor_names,
188       absl::Span<const std::string> target_node_names,
189       std::vector<tensorflow::Tensor>* outputs) = 0;
190 
191  private:
192   const Runtime* runtime_ = nullptr;
193 };
194 
195 class SavedModelImpl final : public SavedModel {
196  public:
197   struct JoinedSignature;
198 
199   // Loads all SignatureDefs in a MetaGraphDef that matches the `tags` in the
200   // tensorflow saved model from `saved_model_dir`. Refer to
201   // http://g3doc/learning/serving/g3doc/saved_model/overview.md
202   // for explanations on SavedModel.
203   //
204   // If `options.maybe_load_from_mla` is true, tries opening `saved_model_dir`
205   // as an MLA. If it's not an MLA, uses it as a normal SavedModel directory.
206   static std::unique_ptr<SavedModel> LoadSavedModel(
207       Options options, absl::string_view saved_model_dir,
208       const std::unordered_set<std::string>& tags, tensorflow::Status* status);
209 
210   SavedModelImpl(
211       Options options, tensorflow::MetaGraphDef meta_graph_def,
212       tfrt::BefBuffer bef, tfrt::RCReference<tfrt::BEFFile> bef_file,
213       absl::flat_hash_map<std::string, internal::Signature> signatures,
214       std::unique_ptr<FallbackState> fallback_state,
215       std::unique_ptr<tfrt::tpu::TpuModelResource> tpu_model_resource,
216       std::unique_ptr<tfrt::ResourceContext> resource_context,
217       std::unique_ptr<GraphExecutor> graph_executor);
218 
219   ~SavedModelImpl() override;
220 
221   SavedModelImpl(const SavedModelImpl&) = delete;
222   SavedModelImpl& operator=(const SavedModelImpl&) = delete;
223 
224   const tensorflow::MetaGraphDef& GetMetaGraphDef() const override;
225 
226   std::vector<std::string> GetFunctionNames() const override;
227 
228   std::optional<FunctionMetadata> GetFunctionMetadata(
229       absl::string_view func_name) const override;
230 
231   tensorflow::Status Run(const RunOptions& run_options, absl::string_view name,
232                          absl::Span<const tensorflow::Tensor> inputs,
233                          std::vector<tensorflow::Tensor>* outputs) override;
234 
235   tensorflow::Status RunMultipleSignatures(
236       const RunOptions& run_options, absl::Span<const std::string> names,
237       absl::Span<const std::vector<tensorflow::Tensor>> multi_inputs,
238       std::vector<std::vector<tensorflow::Tensor>>* multi_outputs) override;
239 
240   tensorflow::Status RunByTensorNames(
241       const RunOptions& run_options,
242       absl::Span<const std::pair<std::string, tensorflow::Tensor>> inputs,
243       absl::Span<const std::string> output_tensor_names,
244       absl::Span<const std::string> target_node_names,
245       std::vector<tensorflow::Tensor>* outputs) override;
246 
247  private:
248   // The result of loading signature(s).
249   struct LoadingResult {
250     std::string name;
251     tfrt::BefBuffer bef;
252     tfrt::RCReference<tfrt::BEFFile> bef_file;
253     std::unique_ptr<tfrt::ResourceContext> resource_context;
254   };
255 
256   // Imports a subgraph as an MLIR module with the specified `input_nodes`,
257   // `output_nodes`.
258   tensorflow::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportSubgraph(
259       mlir::MLIRContext* context,
260       const tensorflow::GraphImportConfig::InputArrays& input_nodes,
261       const std::vector<std::string>& output_nodes,
262       const std::vector<std::string>& target_nodes);
263 
264   // Given the joined signature, loads the subgraph and returns loading result.
265   tensorflow::StatusOr<
266       std::reference_wrapper<const SavedModelImpl::LoadingResult>>
267   LoadJoinedSignature(const JoinedSignature& joined_signature)
268       TF_EXCLUSIVE_LOCKS_REQUIRED(loading_result_cache_mu_);
269 
270   // Returns the loading result given the signature names.
271   tensorflow::StatusOr<
272       std::reference_wrapper<const SavedModelImpl::LoadingResult>>
273   GetOrCreateLoadingResult(absl::Span<const std::string> names)
274       TF_LOCKS_EXCLUDED(loading_result_cache_mu_);
275 
276   // Runs `func` with the given inputs, and outputs the result.
277   tensorflow::Status RunInternal(const RunOptions& run_options,
278                                  absl::string_view signature_name,
279                                  const tfrt::Function& func,
280                                  absl::Span<const tensorflow::Tensor> inputs,
281                                  absl::Span<const tensorflow::Tensor> captures,
282                                  std::vector<tensorflow::Tensor>* outputs,
283                                  tfrt::ResourceContext* resource_context);
284 
285   Options options_;
286   // `meta_graph_def_` only contains metadata of the model. The graph_def field
287   // is removed.
288   //
289   // TODO(b/191931702): We should only keep content that are actually used
290   // (eg. SignatureDefs), instead of keeping the whole saved model, to avoid
291   // unnecessary memory usage.
292   tensorflow::MetaGraphDef meta_graph_def_;
293   tfrt::BefBuffer bef_;
294   tfrt::RCReference<tfrt::BEFFile> bef_file_;
295   tfrt::RequestDeadlineTracker req_deadline_tracker_;
296   absl::flat_hash_map<std::string, internal::Signature> signatures_;
297   std::unique_ptr<FallbackState> fallback_state_;
298   // TODO(b/178227859): Change the hardcoding of this specific TPU resource
299   // (TpuModelResource) to a general and plugable interface.
300   std::unique_ptr<tfrt::tpu::TpuModelResource> tpu_model_resource_;
301   std::unique_ptr<tfrt::ResourceContext> resource_context_;
302   tensorflow::mutex loading_result_cache_mu_;
303   // For pointer stability of values in `absl::flat_hash_map<>`, additional
304   // `std::unique_ptr<>` is necessary. (See https://abseil.io/tips/136.)
305   absl::flat_hash_map<std::string /*joined_name*/,
306                       std::unique_ptr<LoadingResult>>
307       loading_result_cache_ TF_GUARDED_BY(loading_result_cache_mu_);
308   std::unique_ptr<GraphExecutor> graph_executor_;
309   bool lazy_loading_enabled_ = false;
310 };
311 
312 }  // namespace tfrt_stub
313 }  // namespace tensorflow
314 
315 namespace tfrt {
316 
317 using SavedModel = ::tensorflow::tfrt_stub::SavedModel;
318 using SavedModelImpl = ::tensorflow::tfrt_stub::SavedModelImpl;
319 using TensorSpec = ::tensorflow::tfrt_stub::TensorSpec;
320 using FunctionMetadata = ::tensorflow::tfrt_stub::FunctionMetadata;
321 
322 namespace internal {
323 using Signature = ::tensorflow::tfrt_stub::internal::Signature;
324 }
325 
326 }  // namespace tfrt
327 
328 #endif  // TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_H_
329