1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_H_ 16 #define TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_H_ 17 18 #include <functional> 19 #include <limits> 20 #include <memory> 21 #include <optional> 22 #include <string> 23 #include <unordered_set> 24 #include <utility> 25 #include <vector> 26 27 #include "absl/container/flat_hash_map.h" 28 #include "absl/strings/string_view.h" 29 #include "absl/types/span.h" 30 #include "tensorflow/core/framework/graph.pb.h" 31 #include "tensorflow/core/platform/thread_annotations.h" 32 #include "tensorflow/core/protobuf/meta_graph.pb.h" 33 #include "tensorflow/core/tfrt/fallback/fallback_state.h" 34 #include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h" 35 #include "tensorflow/core/tfrt/graph_executor/graph_executor.h" 36 #include "tensorflow/core/tfrt/runtime/runtime.h" 37 #include "tensorflow/core/tfrt/tpu/tpu_resources.h" // NOLINT(unused-includes): For tfrt::tpu::TpuModelResource 38 #include "tfrt/host_context/function.h" // from @tf_runtime 39 #include "tfrt/host_context/request_deadline_tracker.h" // from @tf_runtime 40 #include "tfrt/host_context/resource_context.h" // from @tf_runtime 41 42 namespace tfrt { 43 class BEFFile; 44 class HostContext; 45 } // namespace tfrt 46 47 namespace tensorflow { 48 namespace 
tfrt_stub {

// A (dtype, shape) pair describing one input or output tensor of a signature.
//
// TODO(tfrt-dev): Replace tfrt::TensorSpec with tensorflow::TensorSpec once the
// latter is checked in.
struct TensorSpec {
  tensorflow::DataType dtype;
  // Partial shape: individual dimensions (or the whole rank) may be unknown.
  tensorflow::PartialTensorShape shape;

  explicit TensorSpec(tensorflow::DataType dtype) : dtype(dtype) {}
  // Sink parameter: `shape` is taken by value and moved into the member.
  TensorSpec(tensorflow::DataType dtype, tensorflow::PartialTensorShape shape)
      : dtype(dtype), shape(std::move(shape)) {}
};

// Two specs are equal iff dtypes match and the partial shapes are identical
// (same known/unknown dimensions, per PartialTensorShape::IsIdenticalTo).
inline bool operator==(const TensorSpec& a, const TensorSpec& b) {
  return a.dtype == b.dtype && a.shape.IsIdenticalTo(b.shape);
}

namespace internal {

// Internal representation of a loaded SignatureDef: captured tensors plus
// parallel arrays describing the signature's inputs and outputs.
struct Signature {
  // Tensors captured by the function (eg. variables), passed alongside the
  // user-provided inputs.
  std::vector<tensorflow::Tensor> captures;

  // The following three fields should have the same size.
  std::vector<std::string> input_names;
  std::vector<TensorSpec> input_specs;
  std::vector<std::string> input_devices;

  // The following two fields should have the same size.
  std::vector<std::string> output_names;
  std::vector<TensorSpec> output_specs;
};

}  // namespace internal

// Read-only view over one signature's input/output metadata. Does NOT own the
// underlying `internal::Signature`; the pointed-to signature must outlive this
// object (in practice it is owned by the SavedModel implementation).
class FunctionMetadata {
 public:
  // `signature` must be non-null; checked with assert in debug builds only.
  explicit FunctionMetadata(const internal::Signature* signature)
      : signature_(signature) {
    assert(signature);
  }

  const std::vector<std::string>& GetInputNames() const {
    return signature_->input_names;
  }

  const std::vector<TensorSpec>& GetInputSpecs() const {
    return signature_->input_specs;
  }

  const std::vector<std::string>& GetOutputNames() const {
    return signature_->output_names;
  }

  const std::vector<TensorSpec>& GetOutputSpecs() const {
    return signature_->output_specs;
  }

 private:
  friend class SavedModelImpl;

  const internal::Signature* signature_ = nullptr;
};

// SavedModel represents the in-memory states (graphs and variables) loaded from
// a tensorflow saved model directory.
model directory. 113 class SavedModel { 114 public: 115 struct Options { OptionsOptions116 explicit Options(const Runtime* rt) : graph_execution_options(rt) {} 117 118 // If the number of signagures is greater than the threshold, the loading of 119 // any signature (or signature combination) will be deferred until the first 120 // corresponding invocationof running. Otherwise, the individual signatures 121 // will be loaded along with the saved model. 122 int32_t lazy_loading_threshold = std::numeric_limits<int32_t>::max(); 123 124 // If true, we'll attempt to find MLArchive within the given loading path. 125 // If not found, will use the path as a normal SavedModel directory. 126 bool maybe_load_from_mla = false; 127 128 GraphExecutionOptions graph_execution_options; 129 }; 130 131 // Per-request options. 132 using RunOptions = GraphExecutionRunOptions; 133 SavedModel(const Runtime * runtime)134 explicit SavedModel(const Runtime* runtime) : runtime_(runtime) { 135 DCHECK(runtime_); 136 } 137 virtual ~SavedModel(); 138 runtime()139 const Runtime& runtime() const { 140 DCHECK(runtime_); 141 return *runtime_; 142 } 143 tfrt::HostContext* GetHostContext() const; 144 145 // Returns meta graph def. Note that the graph_def field in the MetaGraphDef 146 // has already been removed. 147 // 148 // TODO(b/191931702): Change the method to return SignatureDefs instead. 149 virtual const tensorflow::MetaGraphDef& GetMetaGraphDef() const = 0; 150 151 // Returns all the function names. 152 virtual std::vector<std::string> GetFunctionNames() const = 0; 153 154 // Returns the `FunctionMetadata` for a function. If the function is not 155 // found, returns nullopt instead. 156 virtual std::optional<FunctionMetadata> GetFunctionMetadata( 157 absl::string_view func_name) const = 0; 158 159 // Runs the signature specified by `name`. Both `inputs` and `outputs` 160 // are all host tensors. The `outputs` must be non-null. If the returned 161 // status is non-OK, the `outputs` are invalid. 
162 virtual tensorflow::Status Run(const RunOptions& run_options, 163 absl::string_view name, 164 absl::Span<const tensorflow::Tensor> inputs, 165 std::vector<tensorflow::Tensor>* outputs) = 0; 166 167 // Runs the signatures specified by `names`. Both `inputs` and `outputs` are 168 // all host tensors. The `outputs` must be non-null. If the returned status is 169 // non-OK, the `outputs` are invalid. 170 // 171 // NOTE: If the given signatures have overlapping input nodes, the input 172 // tensors for these overlapping nodes must be the same. Having different 173 // input tensors for overlapping nodes results UNDEFINED BEHAVIOR. 174 // 175 // NOTE: The input/output tensors can only be dense tensors (as opposed to 176 // sparse tensors or composite tensors). 177 virtual tensorflow::Status RunMultipleSignatures( 178 const RunOptions& run_options, absl::Span<const std::string> names, 179 absl::Span<const std::vector<tensorflow::Tensor>> multi_inputs, 180 std::vector<std::vector<tensorflow::Tensor>>* multi_outputs) = 0; 181 182 // Runs the graphs specified by the tensor names terminal tensors (eg. feed 183 // tensors, fetch tesnors) in the graph. 184 virtual tensorflow::Status RunByTensorNames( 185 const RunOptions& run_options, 186 absl::Span<const std::pair<std::string, tensorflow::Tensor>> inputs, 187 absl::Span<const std::string> output_tensor_names, 188 absl::Span<const std::string> target_node_names, 189 std::vector<tensorflow::Tensor>* outputs) = 0; 190 191 private: 192 const Runtime* runtime_ = nullptr; 193 }; 194 195 class SavedModelImpl final : public SavedModel { 196 public: 197 struct JoinedSignature; 198 199 // Loads all SignatureDefs in a MetaGraphDef that matches the `tags` in the 200 // tensorflow saved model from `saved_model_dir`. Refer to 201 // http://g3doc/learning/serving/g3doc/saved_model/overview.md 202 // for explanations on SavedModel. 203 // 204 // If `options.maybe_load_from_mla` is true, tries opening `saved_model_dir` 205 // as an MLA. 
If it's not an MLA, uses it as a normal SavedModel directory. 206 static std::unique_ptr<SavedModel> LoadSavedModel( 207 Options options, absl::string_view saved_model_dir, 208 const std::unordered_set<std::string>& tags, tensorflow::Status* status); 209 210 SavedModelImpl( 211 Options options, tensorflow::MetaGraphDef meta_graph_def, 212 tfrt::BefBuffer bef, tfrt::RCReference<tfrt::BEFFile> bef_file, 213 absl::flat_hash_map<std::string, internal::Signature> signatures, 214 std::unique_ptr<FallbackState> fallback_state, 215 std::unique_ptr<tfrt::tpu::TpuModelResource> tpu_model_resource, 216 std::unique_ptr<tfrt::ResourceContext> resource_context, 217 std::unique_ptr<GraphExecutor> graph_executor); 218 219 ~SavedModelImpl() override; 220 221 SavedModelImpl(const SavedModelImpl&) = delete; 222 SavedModelImpl& operator=(const SavedModelImpl&) = delete; 223 224 const tensorflow::MetaGraphDef& GetMetaGraphDef() const override; 225 226 std::vector<std::string> GetFunctionNames() const override; 227 228 std::optional<FunctionMetadata> GetFunctionMetadata( 229 absl::string_view func_name) const override; 230 231 tensorflow::Status Run(const RunOptions& run_options, absl::string_view name, 232 absl::Span<const tensorflow::Tensor> inputs, 233 std::vector<tensorflow::Tensor>* outputs) override; 234 235 tensorflow::Status RunMultipleSignatures( 236 const RunOptions& run_options, absl::Span<const std::string> names, 237 absl::Span<const std::vector<tensorflow::Tensor>> multi_inputs, 238 std::vector<std::vector<tensorflow::Tensor>>* multi_outputs) override; 239 240 tensorflow::Status RunByTensorNames( 241 const RunOptions& run_options, 242 absl::Span<const std::pair<std::string, tensorflow::Tensor>> inputs, 243 absl::Span<const std::string> output_tensor_names, 244 absl::Span<const std::string> target_node_names, 245 std::vector<tensorflow::Tensor>* outputs) override; 246 247 private: 248 // The result of loading signature(s). 
249 struct LoadingResult { 250 std::string name; 251 tfrt::BefBuffer bef; 252 tfrt::RCReference<tfrt::BEFFile> bef_file; 253 std::unique_ptr<tfrt::ResourceContext> resource_context; 254 }; 255 256 // Imports a subgraph as an MLIR module with the specified `input_nodes`, 257 // `output_nodes`. 258 tensorflow::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportSubgraph( 259 mlir::MLIRContext* context, 260 const tensorflow::GraphImportConfig::InputArrays& input_nodes, 261 const std::vector<std::string>& output_nodes, 262 const std::vector<std::string>& target_nodes); 263 264 // Given the joined signature, loads the subgraph and returns loading result. 265 tensorflow::StatusOr< 266 std::reference_wrapper<const SavedModelImpl::LoadingResult>> 267 LoadJoinedSignature(const JoinedSignature& joined_signature) 268 TF_EXCLUSIVE_LOCKS_REQUIRED(loading_result_cache_mu_); 269 270 // Returns the loading result given the signature names. 271 tensorflow::StatusOr< 272 std::reference_wrapper<const SavedModelImpl::LoadingResult>> 273 GetOrCreateLoadingResult(absl::Span<const std::string> names) 274 TF_LOCKS_EXCLUDED(loading_result_cache_mu_); 275 276 // Runs `func` with the given inputs, and outputs the result. 277 tensorflow::Status RunInternal(const RunOptions& run_options, 278 absl::string_view signature_name, 279 const tfrt::Function& func, 280 absl::Span<const tensorflow::Tensor> inputs, 281 absl::Span<const tensorflow::Tensor> captures, 282 std::vector<tensorflow::Tensor>* outputs, 283 tfrt::ResourceContext* resource_context); 284 285 Options options_; 286 // `meta_graph_def_` only contains metadata of the model. The graph_def field 287 // is removed. 288 // 289 // TODO(b/191931702): We should only keep content that are actually used 290 // (eg. SignatureDefs), instead of keeping the whole saved model, to avoid 291 // unnecessary memory usage. 
292 tensorflow::MetaGraphDef meta_graph_def_; 293 tfrt::BefBuffer bef_; 294 tfrt::RCReference<tfrt::BEFFile> bef_file_; 295 tfrt::RequestDeadlineTracker req_deadline_tracker_; 296 absl::flat_hash_map<std::string, internal::Signature> signatures_; 297 std::unique_ptr<FallbackState> fallback_state_; 298 // TODO(b/178227859): Change the hardcoding of this specific TPU resource 299 // (TpuModelResource) to a general and plugable interface. 300 std::unique_ptr<tfrt::tpu::TpuModelResource> tpu_model_resource_; 301 std::unique_ptr<tfrt::ResourceContext> resource_context_; 302 tensorflow::mutex loading_result_cache_mu_; 303 // For pointer stability of values in `absl::flat_hash_map<>`, additional 304 // `std::unique_ptr<>` is necessary. (See https://abseil.io/tips/136.) 305 absl::flat_hash_map<std::string /*joined_name*/, 306 std::unique_ptr<LoadingResult>> 307 loading_result_cache_ TF_GUARDED_BY(loading_result_cache_mu_); 308 std::unique_ptr<GraphExecutor> graph_executor_; 309 bool lazy_loading_enabled_ = false; 310 }; 311 312 } // namespace tfrt_stub 313 } // namespace tensorflow 314 315 namespace tfrt { 316 317 using SavedModel = ::tensorflow::tfrt_stub::SavedModel; 318 using SavedModelImpl = ::tensorflow::tfrt_stub::SavedModelImpl; 319 using TensorSpec = ::tensorflow::tfrt_stub::TensorSpec; 320 using FunctionMetadata = ::tensorflow::tfrt_stub::FunctionMetadata; 321 322 namespace internal { 323 using Signature = ::tensorflow::tfrt_stub::internal::Signature; 324 } 325 326 } // namespace tfrt 327 328 #endif // TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_H_ 329