1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_TFRT_UTILS_H_ 16 #define TENSORFLOW_CORE_TFRT_UTILS_H_ 17 18 #include <string> 19 20 #include "absl/status/status.h" 21 #include "tensorflow/core/framework/types.pb.h" 22 #include "tensorflow/core/lib/gtl/array_slice.h" 23 #include "tensorflow/core/platform/errors.h" 24 #include "tensorflow/core/platform/statusor.h" 25 #include "tensorflow/core/platform/strcat.h" 26 #include "tensorflow/core/tfrt/runtime/runtime.h" 27 #include "tfrt/bef/bef_buffer.h" // from @tf_runtime 28 #include "tfrt/dtype/dtype.h" // from @tf_runtime 29 #include "tfrt/support/forward_decls.h" // from @tf_runtime 30 31 namespace tensorflow { 32 class Device; 33 class EagerContext; 34 } // namespace tensorflow 35 36 namespace tfrt { 37 38 class BEFFile; 39 class ExecutionContext; 40 class HostContext; 41 42 typedef tensorflow::gtl::InlinedVector<tfrt::DType, 4> TfrtDataTypeVector; 43 typedef tensorflow::gtl::ArraySlice<tfrt::DType> TfrtDataTypeSlice; 44 45 // TODO(b/161370736): Have a formal method to convert between TF's and TFRT's 46 // device name. Currently TFRT adopts the suffix of TF's device name, 47 // e.g. CPU:0. 48 Expected<const char*> ConvertTfDeviceNameToTfrt( 49 const char* device_name, tensorflow::EagerContext* eager_context); 50 51 DType ConvertTfDTypeToTfrtDType(tensorflow::DataType dtype); 52 53 // Runs the runtime initialization function. A runtime initialization function 54 // is added by runtime/compiler workflow and is not present in the original 55 // savedmodel. 56 // 57 // TODO(b/178714905): We should avoid special handling on initialization by 58 // letting compiler to handle it. 59 tensorflow::Status RunRuntimeInitializer(const tfrt::ExecutionContext& exec_ctx, 60 tfrt::BEFFile* bef_file, 61 absl::string_view fallback_init_func); 62 63 // Creates dummy TF devices from the input device names. Currently this method 64 // is used to create the TPU_SYSTEM device for worker server. 65 void CreateDummyTfDevices( 66 const std::vector<std::string>& device_names, 67 std::vector<std::unique_ptr<tensorflow::Device>>* dummy_tf_devices); 68 69 // Creates and add dummy TFRT devices from the input device names. Currently 70 // this method is used to create the TPU_SYSTEM device for worker server. 71 void AddDummyTfrtDevices(const std::vector<std::string>& device_names, 72 tfrt::HostContext* host_ctx); 73 74 // Creates a BEF file from a BEF buffer. `runtime` is used to provide host 75 // context for opening `bef`. 76 tensorflow::StatusOr<RCReference<tfrt::BEFFile>> CreateBefFileFromBefBuffer( 77 const tensorflow::tfrt_stub::Runtime& runtime, const tfrt::BefBuffer& bef); 78 79 // Returns a unique integer within this process. 80 int64_t GetUniqueInt(); 81 82 // A list of macros similar to `TF_RETURN_IF_ERROR`, with additional model 83 // loading stage info. 84 #define RETURN_IF_ERROR_IN_IMPORT(...) \ 85 RETURN_IF_ERROR_WITH_STAGE_INFO("GraphDef proto -> MLIR", __VA_ARGS__) 86 87 #define RETURN_IF_ERROR_IN_COMPILE(...) \ 88 RETURN_IF_ERROR_WITH_STAGE_INFO( \ 89 "TF dialect -> TFRT dialect, compiler issue, please contact the TFRT " \ 90 "team", \ 91 __VA_ARGS__) 92 93 #define RETURN_IF_ERROR_IN_INIT(...) \ 94 RETURN_IF_ERROR_WITH_STAGE_INFO("Initialize TFRT", __VA_ARGS__) 95 96 #define RETURN_IF_ERROR_WITH_STAGE_INFO(stage, ...) \ 97 do { \ 98 ::tensorflow::Status _status = (__VA_ARGS__); \ 99 if (TF_PREDICT_FALSE(!_status.ok())) { \ 100 return ::tensorflow::errors::CreateWithUpdatedMessage( \ 101 _status, ::tensorflow::strings::StrCat(stage, ": ", \ 102 _status.error_message())); \ 103 } \ 104 } while (0) 105 106 // A list of macros similar to `TF_ASSIGN_OR_RETURN`, with additional model 107 // loading stage info. 108 #define ASSIGN_OR_RETURN_IN_IMPORT(lhs, rexpr) \ 109 ASSIGN_OR_RETURN_WITH_STAGE_INFO("GraphDef proto -> MLIR", lhs, rexpr) 110 111 #define ASSIGN_OR_RETURN_IN_COMPILE(lhs, rexpr) \ 112 ASSIGN_OR_RETURN_WITH_STAGE_INFO( \ 113 "TF dialect -> TFRT dialect, compiler issue, please contact the TFRT " \ 114 "team", \ 115 lhs, rexpr) 116 117 #define ASSIGN_OR_RETURN_IN_INIT(lhs, rexpr) \ 118 ASSIGN_OR_RETURN_WITH_STAGE_INFO("Initialize TFRT", lhs, rexpr) 119 120 #define ASSIGN_OR_RETURN_WITH_STAGE_INFO(stage, lhs, rexpr) \ 121 ASSIGN_OR_RETURN_WITH_STAGE_INFO_IMPL( \ 122 TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), stage, lhs, \ 123 rexpr) 124 125 #define ASSIGN_OR_RETURN_WITH_STAGE_INFO_IMPL(statusor, stage, lhs, rexpr) \ 126 auto statusor = (rexpr); \ 127 if (TF_PREDICT_FALSE(!statusor.ok())) { \ 128 const auto& _status = statusor.status(); \ 129 return ::tensorflow::errors::CreateWithUpdatedMessage( \ 130 _status, \ 131 ::tensorflow::strings::StrCat(stage, ": ", _status.error_message())); \ 132 } \ 133 lhs = std::move(statusor.ValueOrDie()) 134 135 } // namespace tfrt 136 137 #endif // TENSORFLOW_CORE_TFRT_UTILS_H_ 138