/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements RuntimeFallbackOpHandler, responsible for running TFRT
// ops on TensorFlow.

#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.h"

#include "llvm/Support/Casting.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Error.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h"
#include "tensorflow/core/runtime_fallback/util/type_util.h"
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/dispatch_utils.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_handler.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_invocation.h"  // from @tf_runtime
#include "tfrt/core_runtime/tensor_handle.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/device.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/kernel_utils.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime
#include "tfrt/tensor/conversion_registry.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/scalar_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/string_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_metadata.h"  // from @tf_runtime
// TODO(b/160798174): Avoid CUDA/ROCM macro.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tfrt/gpu/device/device.h"  // from @tf_runtime
#include "tfrt/gpu/device/device_util.h"  // from @tf_runtime
#include "tfrt/gpu/tensor/dense_gpu_tensor.h"  // from @tf_runtime
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {
namespace tfd {
// TODO(tfrt-devs): Rename it.
class RuntimeFallbackOpHandler : public tfrt::OpHandler {
 public:
  ~RuntimeFallbackOpHandler() override;

  llvm::Expected<tfrt::CoreRuntimeOp> MakeOp(
      tfrt::string_view op_name) override;

  tfrt::string_view DeviceName() const { return device_->name(); }

  const std::string& TfDeviceName() const { return tf_device_name_; }

  tfrt::RCReference<tfrt::Device> GetDeviceRef() { return device_; }

 private:
  explicit RuntimeFallbackOpHandler(tfrt::CoreRuntime* runtime,
                                    tfrt::RCReference<tfrt::Device> device,
                                    const std::string& tf_device_name);

  llvm::Error Initialize();

  friend llvm::Expected<tfrt::OpHandler*> CreateRuntimeFallbackOpHandler(
      tfrt::CoreRuntime* runtime, tfrt::string_view tf_device_name);

  tfrt::RCReference<tfrt::Device> device_;
  // TensorFlow device name, e.g., /device:CPU:0.
  std::string tf_device_name_;
};

namespace {

using tfrt::AsyncValue;
using tfrt::AsyncValueRef;
using tfrt::Chain;
using tfrt::CoreRuntime;
using tfrt::CoreRuntimeOp;
using tfrt::DenseHostTensor;
using tfrt::ExecutionContext;
using tfrt::Expected;
using tfrt::OpAttrsRef;
using tfrt::OpHandler;
using tfrt::OpInvocation;
using tfrt::OpMetadataFn;
using tfrt::raw_ostream;
using tfrt::RCReference;
using tfrt::string_view;
using tfrt::Tensor;
using tfrt::TensorMetadata;

using RuntimeFallbackDispatchFn = AsyncValueRef<Chain> (*)(
    const ExecutionContext& exec_ctx, const char* op_name,
    const char* device_name, llvm::ArrayRef<Tensor*> arguments,
    const OpAttrsRef& attrs,
    llvm::MutableArrayRef<RCReference<AsyncValue>> results);

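// Per-op dispatch entry. One of these is created for each op returned by
// RuntimeFallbackOpHandler::MakeOp() below and is captured by the resulting
// CoreRuntimeOp.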
struct RuntimeFallbackOpEntry {
  std::string op_name;
  OpMetadataFn metadata_fn = nullptr;
  // All ops use the same dispatch function.
  RuntimeFallbackDispatchFn dispatch_fn = &RuntimeFallbackExecute;
};

static Expected<tfrt::RCReference<tfrt::Device>> GetDeviceFromFallbackTensor(
    const RuntimeFallbackTensor& result_tensor,
    const ExecutionContext& exec_ctx) {
  tensorflow::Status status;
  // Obtain the device. Note that this is probably not the device the tensor
  // is physically located on; e.g. a GPU resource handle reports GPU as its
  // device even though the handle itself lives on the CPU.
  // We use this device because upper layers (e.g. a distribution strategy)
  // may use it for colocation, whereas the actual backing device is not
  // widely used there.
  // In the future, if we need the backing device in higher layers as well, we
  // can update the c_api_tfrt layer to get it directly from
  // tensorflow::TensorHandle.
  const char* tf_device_name =
      result_tensor.GetTensorHandle()->DeviceName(&status);
  if (!status.ok()) {
    return tfrt::MakeStringError(status.error_message());
  }

  // TODO(b/165872892): Unify device name for tests.
  auto device = exec_ctx.host()->GetDeviceManager()->GetDeviceRef<tfrt::Device>(
      tf_device_name);
  if (!device) {
    // Convert the device name to the short form, e.g. "GPU:0".
    const char* tfrt_device_name =
        ConvertTfDeviceNameToTfrtDefault(tf_device_name);
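    // For example, a fully qualified TF device name such as
    // "/job:localhost/replica:0/task:0/device:GPU:0" would be shortened to
    // "GPU:0" before the lookup is retried.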
    device = exec_ctx.host()->GetDeviceManager()->GetDeviceRef<tfrt::Device>(
        tfrt_device_name);
  }
  assert(device);
  return std::move(device);
}

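// Traits consumed by tfrt::ExecuteOnOpHandler<RuntimeFallbackOpHandlerTraits>
// (see dispatch_utils.h and MakeOp() below): Dispatch() forwards the op to
// RuntimeFallbackExecute, and GetResultDevice() resolves the TFRT device of
// each result tensor.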
struct RuntimeFallbackOpHandlerTraits {
  using InputTensorTy = Tensor;
  using OpEntryTy = RuntimeFallbackOpEntry;
  using OpHandlerInfoTy = RuntimeFallbackOpHandler*;

  static void Dispatch(const RuntimeFallbackOpEntry& op_entry,
                       RuntimeFallbackOpHandler* tf_op_handler,
                       llvm::ArrayRef<Tensor*> inputs, const OpAttrsRef& attrs,
                       llvm::ArrayRef<TensorMetadata> result_mds,
                       llvm::MutableArrayRef<RCReference<AsyncValue>> results,
                       AsyncValueRef<Chain>* chain,
                       const ExecutionContext& exec_ctx) {
    // Call RuntimeFallbackExecute.
    auto ch = op_entry.dispatch_fn(exec_ctx, op_entry.op_name.c_str(),
                                   tf_op_handler->TfDeviceName().c_str(),
                                   inputs, attrs, results);

    if (chain) *chain = std::move(ch);
  }

  // TODO(fishx): Remove this method.
  static tfrt::Variant<tfrt::RCReference<tfrt::Device>,
                       tfrt::AsyncValueRef<tfrt::RCReference<tfrt::Device>>>
  GetResultDevice(RuntimeFallbackOpHandler* tf_op_handler,
                  const tfrt::AsyncValueRef<tfrt::Tensor>& result_tensor_av,
                  const ExecutionContext& exec_ctx) {
    if (result_tensor_av.IsAvailable()) {
      if (result_tensor_av.IsError()) {
        return tfrt::AsyncValueRef<tfrt::RCReference<tfrt::Device>>(
            result_tensor_av.CopyRCRef());
      }
      auto expected_device = GetDeviceFromFallbackTensor(
          result_tensor_av.get<RuntimeFallbackTensor>(), exec_ctx);
      if (!expected_device) {
        return tfrt::AsyncValueRef<tfrt::RCReference<tfrt::Device>>(
            tfrt::MakeErrorAsyncValueRef(
                exec_ctx.host(), tfrt::StrCat(expected_device.takeError())));
      }
      return std::move(expected_device.get());
    }

    auto result_device =
        tfrt::MakeUnconstructedAsyncValueRef<tfrt::RCReference<tfrt::Device>>(
            exec_ctx.host());

    result_tensor_av.AndThen([result_tensor_av_ref = result_tensor_av.CopyRef(),
                              result_device = result_device.CopyRef(),
                              exec_ctx] {
      assert(result_tensor_av_ref.IsAvailable());
      if (result_tensor_av_ref.IsError()) {
        result_device.SetError(result_tensor_av_ref.GetError());
        return;
      }
      result_device.emplace(GetDeviceFromFallbackTensor(
          result_tensor_av_ref.get<RuntimeFallbackTensor>(), exec_ctx));
    });
    return std::move(result_device);
  }

  static tfrt::Variant<tfrt::RCReference<tfrt::Device>,
                       tfrt::AsyncValueRef<tfrt::RCReference<tfrt::Device>>>
  GetResultDevice(const RuntimeFallbackOpEntry& op_entry,
                  RuntimeFallbackOpHandler* tf_op_handler,
                  const tfrt::AsyncValueRef<tfrt::Tensor>& result_tensor_av,
                  int index, const ExecutionContext& exec_ctx) {
    return GetResultDevice(tf_op_handler, result_tensor_av, exec_ctx);
  }
};

}  // namespace

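// Creates a CoreRuntimeOp for `op_name`. The "tf." prefix is stripped (e.g.
// "tf.MatMul" becomes the TF op name "MatMul"); the returned op converts its
// arguments to RuntimeFallbackTensor and dispatches through
// RuntimeFallbackOpHandlerTraits above.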
Expected<CoreRuntimeOp> RuntimeFallbackOpHandler::MakeOp(string_view op_name) {
  // NOTE(fishx): Copying the op name string here adds extra overhead in graph
  // execution, because in the current implementation we need to prepare the op
  // before each execution.
  // TODO(fishx): Avoid this heap allocation by getting op registration
  // information from current TF.
  RuntimeFallbackOpEntry op_entry;
  if (!op_name.consume_front("tf."))
    return tfrt::MakeStringError(op_name, " does not start with 'tf.'");
  op_entry.op_name.assign(op_name.begin(), op_name.end());
  return CoreRuntimeOp(
      [op_entry = std::move(op_entry), this](const OpInvocation& invocation) {
        // If the op does not have outputs, then it is expected to output an
        // out chain.
        bool update_chain = invocation.results.empty();

        // Convert the argument tensors to RuntimeFallbackTensors.
        for (auto& argument : invocation.arguments) {
          argument = argument.TransferToSameDevice(
              invocation.exec_ctx, RuntimeFallbackTensor::kTensorType);
        }

        tfrt::ExecuteOnOpHandler<RuntimeFallbackOpHandlerTraits>(
            update_chain, invocation, std::move(op_entry), this);

// TODO(b/160798174): Avoid CUDA/ROCM macro.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
        // If the RuntimeFallbackTensor contains a tensorflow::TensorHandle
        // that holds a GPU tensor, convert it to tfrt::DenseGpuTensor, and
        // populate the correct device name to the result tfrt::TensorHandle.
        //
        // Note that if the GPU tensor contains a DataType that is not natively
        // supported by TFRT, e.g. Resource DataType, we skip the conversion.
        //
        // If the RuntimeFallbackTensor's tensorflow::TensorHandle holds a CPU
        // tensor, do not convert it to DenseHostTensor (it will be lazily
        // converted) for performance reasons.
        for (auto& result : invocation.results) {
          auto* host_ctx = invocation.exec_ctx.host();
          auto* result_tensor_av = result.GetAsyncTensor();

          if (!result_tensor_av->IsAvailable())
            host_ctx->Await(FormRef(result_tensor_av));

          if (result_tensor_av->IsError()) continue;

          auto result_tensor_tf_th =
              result_tensor_av->get<RuntimeFallbackTensor>().GetTensorHandle();

          // Check if we need to convert the RuntimeFallbackTensor.
          if (!(IsGpuTensorHandle(*result_tensor_tf_th) &&
                IsSupportedByTFRTGpu(result_tensor_tf_th->DataType())))
            continue;

          result = result.TransferToSameDevice(
              invocation.exec_ctx, tfrt::gpu::DenseGpuTensor::kTensorType);
        }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
      },
      // device and arg_tensor_type are not used in runtime fallback ops.
      /*is_fallback=*/true, /*device=*/device_);
}

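// A minimal usage sketch for the factory below (illustrative only; the
// surrounding CoreRuntime setup and the device name are assumptions, not
// requirements of this file):
//
//   tfrt::CoreRuntime* runtime = ...;  // assumed to be initialized elsewhere
//   llvm::Expected<tfrt::OpHandler*> handler =
//       CreateRuntimeFallbackOpHandler(runtime, "/device:CPU:0");
//   if (!handler) { /* handle the llvm::Error */ }
//   llvm::Expected<tfrt::CoreRuntimeOp> op = (*handler)->MakeOp("tf.MatMul");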
llvm::Expected<tfrt::OpHandler*> CreateRuntimeFallbackOpHandler(
    tfrt::CoreRuntime* runtime, tfrt::string_view tf_device_name) {
  // TODO(fishx): Remove the device field from fallback op handler.
  std::unique_ptr<RuntimeFallbackOpHandler> op_handler(
      new RuntimeFallbackOpHandler(
          runtime, runtime->GetHostContext()->GetHostDeviceRef(),
          tf_device_name.str()));
  if (auto error = op_handler->Initialize()) {
    return std::move(error);
  }
  auto op_handler_ptr = op_handler.get();
  runtime->TakeOpHandler(std::move(op_handler));
  return op_handler_ptr;
}

RuntimeFallbackOpHandler::RuntimeFallbackOpHandler(
    CoreRuntime* runtime, tfrt::RCReference<tfrt::Device> device,
    const std::string& tf_device_name)
    : OpHandler("tf", runtime, nullptr),
      device_(std::move(device)),
      tf_device_name_(tf_device_name) {}

RuntimeFallbackOpHandler::~RuntimeFallbackOpHandler() {}

llvm::Error RuntimeFallbackOpHandler::Initialize() {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  Status status = InjectTfGpuResources();
  if (!status.ok()) {
    return tfrt::MakeStringError(tfrt::StrCat("error injecting GPU resources: ",
                                              status.error_message()));
  }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

  return llvm::Error::success();
}

}  // namespace tfd
}  // namespace tensorflow