xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tfrt/utils/utils.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_UTILS_H_
16 #define TENSORFLOW_CORE_TFRT_UTILS_H_
17 
18 #include <string>
19 
20 #include "absl/status/status.h"
21 #include "tensorflow/core/framework/types.pb.h"
22 #include "tensorflow/core/lib/gtl/array_slice.h"
23 #include "tensorflow/core/platform/errors.h"
24 #include "tensorflow/core/platform/statusor.h"
25 #include "tensorflow/core/platform/strcat.h"
26 #include "tensorflow/core/tfrt/runtime/runtime.h"
27 #include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
28 #include "tfrt/dtype/dtype.h"  // from @tf_runtime
29 #include "tfrt/support/forward_decls.h"  // from @tf_runtime
30 
31 namespace tensorflow {
32 class Device;
33 class EagerContext;
34 }  // namespace tensorflow
35 
36 namespace tfrt {
37 
38 class BEFFile;
39 class ExecutionContext;
40 class HostContext;
41 
42 typedef tensorflow::gtl::InlinedVector<tfrt::DType, 4> TfrtDataTypeVector;
43 typedef tensorflow::gtl::ArraySlice<tfrt::DType> TfrtDataTypeSlice;
44 
// Converts a TF device name to its TFRT counterpart.
//
// TODO(b/161370736): Have a formal method to convert between TF's and TFRT's
// device name. Currently TFRT adopts the suffix of TF's device name,
// e.g. CPU:0.
//
// `eager_context` is used to resolve `device_name`; returns an error via
// `Expected` on failure. NOTE(review): the lifetime of the returned
// `const char*` is not visible from this header — presumably it aliases
// storage owned by `eager_context`; confirm against the implementation.
Expected<const char*> ConvertTfDeviceNameToTfrt(
    const char* device_name, tensorflow::EagerContext* eager_context);
50 
51 DType ConvertTfDTypeToTfrtDType(tensorflow::DataType dtype);
52 
// Runs the runtime initialization function. A runtime initialization function
// is added by runtime/compiler workflow and is not present in the original
// savedmodel.
//
// `fallback_init_func` names the function to look up and run in `bef_file`,
// executing within `exec_ctx`. NOTE(review): error conditions (e.g. function
// missing vs. execution failure) are not visible from this header — confirm
// against the .cc implementation.
//
// TODO(b/178714905): We should avoid special handling on initialization by
// letting compiler to handle it.
tensorflow::Status RunRuntimeInitializer(const tfrt::ExecutionContext& exec_ctx,
                                         tfrt::BEFFile* bef_file,
                                         absl::string_view fallback_init_func);
62 
// Creates dummy TF devices from the input device names. Currently this method
// is used to create the TPU_SYSTEM device for worker server.
//
// The created devices are appended to `dummy_tf_devices`, which must be
// non-null; ownership of each device is transferred to the vector.
void CreateDummyTfDevices(
    const std::vector<std::string>& device_names,
    std::vector<std::unique_ptr<tensorflow::Device>>* dummy_tf_devices);
68 
// Creates and adds dummy TFRT devices from the input device names. Currently
// this method is used to create the TPU_SYSTEM device for worker server.
//
// The devices are registered on `host_ctx`, which must outlive them.
void AddDummyTfrtDevices(const std::vector<std::string>& device_names,
                         tfrt::HostContext* host_ctx);
73 
// Creates a BEF file from a BEF buffer. `runtime` is used to provide host
// context for opening `bef`.
//
// Returns a ref-counted handle to the opened BEF file, or a non-OK status on
// failure. NOTE(review): `bef` is taken by const reference — presumably the
// buffer must outlive the returned BEFFile; confirm in the implementation.
tensorflow::StatusOr<RCReference<tfrt::BEFFile>> CreateBefFileFromBefBuffer(
    const tensorflow::tfrt_stub::Runtime& runtime, const tfrt::BefBuffer& bef);
78 
// Returns a unique integer within this process.
// NOTE(review): thread-safety of the underlying counter is not visible from
// this header — confirm in the implementation before calling concurrently.
int64_t GetUniqueInt();
81 
// A list of macros similar to `TF_RETURN_IF_ERROR`, with additional model
// loading stage info. Each evaluates an expression yielding a
// `tensorflow::Status` and, if it is not OK, returns from the enclosing
// function with the stage name prepended to the error message.
// (Comments are kept outside the macro bodies: a `//` comment inside a
// backslash-continued definition would swallow the rest of the spliced line.)

// Tags errors from the GraphDef proto -> MLIR import stage.
#define RETURN_IF_ERROR_IN_IMPORT(...) \
  RETURN_IF_ERROR_WITH_STAGE_INFO("GraphDef proto -> MLIR", __VA_ARGS__)

// Tags errors from the TF dialect -> TFRT dialect compilation stage.
#define RETURN_IF_ERROR_IN_COMPILE(...)                                      \
  RETURN_IF_ERROR_WITH_STAGE_INFO(                                           \
      "TF dialect -> TFRT dialect, compiler issue, please contact the TFRT " \
      "team",                                                                \
      __VA_ARGS__)

// Tags errors from the TFRT initialization stage.
#define RETURN_IF_ERROR_IN_INIT(...) \
  RETURN_IF_ERROR_WITH_STAGE_INFO("Initialize TFRT", __VA_ARGS__)

// Workhorse: evaluates `__VA_ARGS__` exactly once; on error, returns a status
// whose message is rewritten to "<stage>: <original message>". Wrapped in
// do-while(0) so it acts as a single statement (safe in unbraced if/else).
#define RETURN_IF_ERROR_WITH_STAGE_INFO(stage, ...)                         \
  do {                                                                      \
    ::tensorflow::Status _status = (__VA_ARGS__);                           \
    if (TF_PREDICT_FALSE(!_status.ok())) {                                  \
      return ::tensorflow::errors::CreateWithUpdatedMessage(                \
          _status, ::tensorflow::strings::StrCat(stage, ": ",               \
                                                 _status.error_message())); \
    }                                                                       \
  } while (0)
105 
// A list of macros similar to `TF_ASSIGN_OR_RETURN`, with additional model
// loading stage info. Each evaluates `rexpr` (a `StatusOr`-yielding
// expression) exactly once and either assigns the contained value to `lhs`
// or returns from the enclosing function with the stage name prepended to
// the error message.

// Tags errors from the GraphDef proto -> MLIR import stage.
#define ASSIGN_OR_RETURN_IN_IMPORT(lhs, rexpr) \
  ASSIGN_OR_RETURN_WITH_STAGE_INFO("GraphDef proto -> MLIR", lhs, rexpr)

// Tags errors from the TF dialect -> TFRT dialect compilation stage.
#define ASSIGN_OR_RETURN_IN_COMPILE(lhs, rexpr)                              \
  ASSIGN_OR_RETURN_WITH_STAGE_INFO(                                          \
      "TF dialect -> TFRT dialect, compiler issue, please contact the TFRT " \
      "team",                                                                \
      lhs, rexpr)

// Tags errors from the TFRT initialization stage.
#define ASSIGN_OR_RETURN_IN_INIT(lhs, rexpr) \
  ASSIGN_OR_RETURN_WITH_STAGE_INFO("Initialize TFRT", lhs, rexpr)

// Generates a unique temporary name via __COUNTER__ so successive uses in
// the same scope do not collide.
#define ASSIGN_OR_RETURN_WITH_STAGE_INFO(stage, lhs, rexpr)                    \
  ASSIGN_OR_RETURN_WITH_STAGE_INFO_IMPL(                                       \
      TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), stage, lhs, \
      rexpr)

// NOTE: deliberately NOT wrapped in do-while(0): `lhs` may declare a new
// variable that must remain in scope after the macro. This therefore expands
// to multiple statements and must not be used as the sole body of an
// unbraced if/else/loop.
#define ASSIGN_OR_RETURN_WITH_STAGE_INFO_IMPL(statusor, stage, lhs, rexpr)    \
  auto statusor = (rexpr);                                                    \
  if (TF_PREDICT_FALSE(!statusor.ok())) {                                     \
    const auto& _status = statusor.status();                                  \
    return ::tensorflow::errors::CreateWithUpdatedMessage(                    \
        _status,                                                              \
        ::tensorflow::strings::StrCat(stage, ": ", _status.error_message())); \
  }                                                                           \
  lhs = std::move(statusor.ValueOrDie())
134 
135 }  // namespace tfrt
136 
137 #endif  // TENSORFLOW_CORE_TFRT_UTILS_H_
138