/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements RuntimeFallbackOpHandler, responsible for running TFRT
// ops on TensorFlow.

#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.h"

#include "llvm/Support/Casting.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Error.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h"
#include "tensorflow/core/runtime_fallback/util/type_util.h"
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/dispatch_utils.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_handler.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_invocation.h"  // from @tf_runtime
#include "tfrt/core_runtime/tensor_handle.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/device.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/kernel_utils.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime
#include "tfrt/tensor/conversion_registry.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/scalar_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/string_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_metadata.h"  // from @tf_runtime
// TODO(b/160798174): Avoid CUDA/ROCM macro.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tfrt/gpu/device/device.h"  // from @tf_runtime
#include "tfrt/gpu/device/device_util.h"  // from @tf_runtime
#include "tfrt/gpu/tensor/dense_gpu_tensor.h"  // from @tf_runtime
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {
namespace tfd {
// TODO(tfrt-devs): Rename it.
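// OpHandler that executes TFRT core runtime ops by falling back to the
// TensorFlow runtime via RuntimeFallbackExecute.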
class RuntimeFallbackOpHandler : public tfrt::OpHandler {
 public:
  ~RuntimeFallbackOpHandler() override;

  llvm::Expected<tfrt::CoreRuntimeOp> MakeOp(
      tfrt::string_view op_name) override;

  tfrt::string_view DeviceName() const { return device_->name(); }

  const std::string& TfDeviceName() const { return tf_device_name_; }

  tfrt::RCReference<tfrt::Device> GetDeviceRef() { return device_; }

 private:
  explicit RuntimeFallbackOpHandler(tfrt::CoreRuntime* runtime,
                                    tfrt::RCReference<tfrt::Device> device,
                                    const std::string& tf_device_name);

  llvm::Error Initialize();

  friend llvm::Expected<tfrt::OpHandler*> CreateRuntimeFallbackOpHandler(
      tfrt::CoreRuntime* runtime, tfrt::string_view tf_device_name);

  tfrt::RCReference<tfrt::Device> device_;
  // Tensorflow device name, e.g., /device:CPU:0.
  std::string tf_device_name_;
};

namespace {

using tfrt::AsyncValue;
using tfrt::AsyncValueRef;
using tfrt::Chain;
using tfrt::CoreRuntime;
using tfrt::CoreRuntimeOp;
using tfrt::DenseHostTensor;
using tfrt::ExecutionContext;
using tfrt::Expected;
using tfrt::OpAttrsRef;
using tfrt::OpHandler;
using tfrt::OpInvocation;
using tfrt::OpMetadataFn;
using tfrt::raw_ostream;
using tfrt::RCReference;
using tfrt::string_view;
using tfrt::Tensor;
using tfrt::TensorMetadata;

using RuntimeFallbackDispatchFn = AsyncValueRef<Chain> (*)(
    const ExecutionContext& exec_ctx, const char* op_name,
    const char* device_name, llvm::ArrayRef<Tensor*> arguments,
    const OpAttrsRef& attrs,
    llvm::MutableArrayRef<RCReference<AsyncValue>> results);

struct RuntimeFallbackOpEntry {
  std::string op_name;
  OpMetadataFn metadata_fn = nullptr;
  // All ops use the same dispatch function.
  RuntimeFallbackDispatchFn dispatch_fn = &RuntimeFallbackExecute;
};

static Expected<tfrt::RCReference<tfrt::Device>> GetDeviceFromFallbackTensor(
    const RuntimeFallbackTensor& result_tensor,
    const ExecutionContext& exec_ctx) {
  tensorflow::Status status;
  // Obtain the device. Please note that this device is probably not the
  // device that the TensorHandle is located on. E.g. for a GPU resource,
  // its device is GPU but it is physically located on CPU.
  // We use this device because upper layers (e.g. distributed strategy) may
  // use it for colocation. On the other hand, the actual device is not widely
  // used in upper layers.
  // In the future, if we need BackingDevice in higher layers as well, we can
  // update the c_api_tfrt layer to get it directly from
  // tensorflow::TensorHandle.
  const char* tf_device_name =
      result_tensor.GetTensorHandle()->DeviceName(&status);
  if (!status.ok()) {
    return tfrt::MakeStringError(status.error_message());
  }

  // TODO(b/165872892): Unify device name for tests.
  auto device = exec_ctx.host()->GetDeviceManager()->GetDeviceRef<tfrt::Device>(
      tf_device_name);
  if (!device) {
    // Convert device name to the short form, e.g. "GPU:0".
    const char* tfrt_device_name =
        ConvertTfDeviceNameToTfrtDefault(tf_device_name);
    device = exec_ctx.host()->GetDeviceManager()->GetDeviceRef<tfrt::Device>(
        tfrt_device_name);
  }
  assert(device);
  return std::move(device);
}

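// Traits consumed by tfrt::ExecuteOnOpHandler to dispatch runtime fallback
// ops and to resolve the device of each result tensor.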
struct RuntimeFallbackOpHandlerTraits {
  using InputTensorTy = Tensor;
  using OpEntryTy = RuntimeFallbackOpEntry;
  using OpHandlerInfoTy = RuntimeFallbackOpHandler*;

  static void Dispatch(const RuntimeFallbackOpEntry& op_entry,
                       RuntimeFallbackOpHandler* tf_op_handler,
                       llvm::ArrayRef<Tensor*> inputs, const OpAttrsRef& attrs,
                       llvm::ArrayRef<TensorMetadata> result_mds,
                       llvm::MutableArrayRef<RCReference<AsyncValue>> results,
                       AsyncValueRef<Chain>* chain,
                       const ExecutionContext& exec_ctx) {
    // Call RuntimeFallbackExecute.
    auto ch = op_entry.dispatch_fn(exec_ctx, op_entry.op_name.c_str(),
                                   tf_op_handler->TfDeviceName().c_str(),
                                   inputs, attrs, results);

    if (chain) *chain = std::move(ch);
  }

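  // Returns the device of a result tensor. If the tensor is not yet
  // available, the device is resolved asynchronously once the tensor becomes
  // available.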
  // TODO(fishx): Remove this method.
  static tfrt::Variant<tfrt::RCReference<tfrt::Device>,
                       tfrt::AsyncValueRef<tfrt::RCReference<tfrt::Device>>>
  GetResultDevice(RuntimeFallbackOpHandler* tf_op_handler,
                  const tfrt::AsyncValueRef<tfrt::Tensor>& result_tensor_av,
                  const ExecutionContext& exec_ctx) {
    if (result_tensor_av.IsAvailable()) {
      if (result_tensor_av.IsError()) {
        return tfrt::AsyncValueRef<tfrt::RCReference<tfrt::Device>>(
            result_tensor_av.CopyRCRef());
      }
      auto expected_device = GetDeviceFromFallbackTensor(
          result_tensor_av.get<RuntimeFallbackTensor>(), exec_ctx);
      if (!expected_device) {
        return tfrt::AsyncValueRef<tfrt::RCReference<tfrt::Device>>(
            tfrt::MakeErrorAsyncValueRef(
                exec_ctx.host(), tfrt::StrCat(expected_device.takeError())));
      }
      return std::move(expected_device.get());
    }

    auto result_device =
        tfrt::MakeUnconstructedAsyncValueRef<tfrt::RCReference<tfrt::Device>>(
            exec_ctx.host());

    result_tensor_av.AndThen([result_tensor_av_ref = result_tensor_av.CopyRef(),
                              result_device = result_device.CopyRef(),
                              exec_ctx] {
      assert(result_tensor_av_ref.IsAvailable());
      if (result_tensor_av_ref.IsError()) {
        result_device.SetError(result_tensor_av_ref.GetError());
        return;
      }
      // Resolve the device now that the result tensor is available. A failed
      // lookup is forwarded as an error on `result_device`.
      result_device.emplace(GetDeviceFromFallbackTensor(
          result_tensor_av_ref.get<RuntimeFallbackTensor>(), exec_ctx));
    });
    return std::move(result_device);
  }

  static tfrt::Variant<tfrt::RCReference<tfrt::Device>,
                       tfrt::AsyncValueRef<tfrt::RCReference<tfrt::Device>>>
  GetResultDevice(const RuntimeFallbackOpEntry& op_entry,
                  RuntimeFallbackOpHandler* tf_op_handler,
                  const tfrt::AsyncValueRef<tfrt::Tensor>& result_tensor_av,
                  int index, const ExecutionContext& exec_ctx) {
    return GetResultDevice(tf_op_handler, result_tensor_av, exec_ctx);
  }
};

}  // namespace

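// Builds a CoreRuntimeOp that strips the "tf." prefix from `op_name` and
// executes the op through the runtime fallback dispatch function.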
Expected<CoreRuntimeOp> RuntimeFallbackOpHandler::MakeOp(string_view op_name) {
  // NOTE(fishx): Copying the string here adds extra overhead to graph
  // execution, because in the current implementation we need to prepare the
  // op before each execution.
  // TODO(fishx): Avoid this heap allocation by getting op registration
  // information from current TF.
  RuntimeFallbackOpEntry op_entry;
  if (!op_name.consume_front("tf."))
    return tfrt::MakeStringError(op_name, " does not start with 'tf.'");
  op_entry.op_name.assign(op_name.begin(), op_name.end());
  return CoreRuntimeOp(
      [op_entry = std::move(op_entry), this](const OpInvocation& invocation) {
        // If the op does not have outputs, then it is expected to output an
        // out chain.
        bool update_chain = invocation.results.empty();

        // Convert the argument tensors to RuntimeFallbackTensors.
        for (auto& argument : invocation.arguments) {
          argument = argument.TransferToSameDevice(
              invocation.exec_ctx, RuntimeFallbackTensor::kTensorType);
        }

        tfrt::ExecuteOnOpHandler<RuntimeFallbackOpHandlerTraits>(
            update_chain, invocation, std::move(op_entry), this);

// TODO(b/160798174): Avoid CUDA/ROCM macro.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
        // If the RuntimeFallbackTensor contains a tensorflow::TensorHandle
        // that holds a GPU tensor, convert it to tfrt::DenseGpuTensor, and
        // populate the correct device name in the result tfrt::TensorHandle.
        //
        // Note that if the GPU tensor contains a DataType that is not natively
        // supported by TFRT, e.g. the Resource DataType, we skip the
        // conversion.
        //
        // If the RuntimeFallbackTensor's tensorflow::TensorHandle holds a CPU
        // tensor, do not convert it to DenseHostTensor (it will be lazily
        // converted) for performance reasons.
        for (auto& result : invocation.results) {
          auto* host_ctx = invocation.exec_ctx.host();
          auto* result_tensor_av = result.GetAsyncTensor();

          if (!result_tensor_av->IsAvailable())
            host_ctx->Await(FormRef(result_tensor_av));

          if (result_tensor_av->IsError()) continue;

          auto result_tensor_tf_th =
              result_tensor_av->get<RuntimeFallbackTensor>().GetTensorHandle();

          // Check if we need to convert the RuntimeFallbackTensor.
          if (!(IsGpuTensorHandle(*result_tensor_tf_th) &&
                IsSupportedByTFRTGpu(result_tensor_tf_th->DataType())))
            continue;

          result = result.TransferToSameDevice(
              invocation.exec_ctx, tfrt::gpu::DenseGpuTensor::kTensorType);
        }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
      },
      // device and arg_tensor_type are not used in runtime fallback ops.
      /*is_fallback=*/true, /*device=*/device_);
}

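// Creates a RuntimeFallbackOpHandler for `tf_device_name` and transfers its
// ownership to the CoreRuntime; returns a non-owning pointer to the handler.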
llvm::Expected<tfrt::OpHandler*> CreateRuntimeFallbackOpHandler(
    tfrt::CoreRuntime* runtime, tfrt::string_view tf_device_name) {
  // TODO(fishx): Remove the device field from fallback op handler.
  std::unique_ptr<RuntimeFallbackOpHandler> op_handler(
      new RuntimeFallbackOpHandler(
          runtime, runtime->GetHostContext()->GetHostDeviceRef(),
          tf_device_name.str()));
  if (auto error = op_handler->Initialize()) {
    return std::move(error);
  }
  auto op_handler_ptr = op_handler.get();
  runtime->TakeOpHandler(std::move(op_handler));
  return op_handler_ptr;
}

RuntimeFallbackOpHandler::RuntimeFallbackOpHandler(
    CoreRuntime* runtime, tfrt::RCReference<tfrt::Device> device,
    const std::string& tf_device_name)
    : OpHandler("tf", runtime, nullptr),
      device_(std::move(device)),
      tf_device_name_(tf_device_name) {}

RuntimeFallbackOpHandler::~RuntimeFallbackOpHandler() {}

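// Initialization injects TF GPU resources when GPU support is compiled in;
// otherwise this is a no-op.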
llvm::Error RuntimeFallbackOpHandler::Initialize() {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  Status status = InjectTfGpuResources();
  if (!status.ok()) {
    return tfrt::MakeStringError(tfrt::StrCat("error injecting GPU resources: ",
                                              status.error_message()));
  }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

  return llvm::Error::success();
}

}  // namespace tfd
}  // namespace tensorflow