1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute.h"
16
17 #include <assert.h>
18
19 #include <memory>
20 #include <string>
21 #include <utility>
22
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_tensor.h"
27 #include "tensorflow/core/runtime_fallback/kernel/tfrt_op_kernel.h"
28 #include "tfrt/common/compat/eigen/thread_pool_device.h" // from @tf_runtime
29 #include "tfrt/core_runtime/op_attrs.h" // from @tf_runtime
30 #include "tfrt/host_context/async_value.h" // from @tf_runtime
31 #include "tfrt/host_context/host_context.h" // from @tf_runtime
32 #include "tfrt/support/forward_decls.h" // from @tf_runtime
33 #include "tfrt/support/ref_count.h" // from @tf_runtime
34
35 namespace tensorflow {
36 namespace tfd {
37
38 namespace {
39 using tensorflow::KernelFallbackTensor;
40 using tfrt::AsyncValue;
41 using tfrt::RCReference;
42 } // namespace
43
SetError(const tfrt::ExecutionContext & exec_ctx,llvm::SmallVector<RCReference<AsyncValue>,4> * results,tfrt::string_view message)44 void SetError(const tfrt::ExecutionContext& exec_ctx,
45 llvm::SmallVector<RCReference<AsyncValue>, 4>* results,
46 tfrt::string_view message) {
47 RCReference<tfrt::ErrorAsyncValue> error = EmitErrorAsync(exec_ctx, message);
48 for (auto& result : *results) {
49 result->SetError(error->GetError());
50 }
51 }
52
KernelFallbackExecute(const tfrt::ExecutionContext & exec_ctx,tfrt::string_view op_name,tfrt::ArrayRef<AsyncValue * > arguments,tfrt::MutableArrayRef<RCReference<AsyncValue>> results,const tfrt::OpAttrsRef & attrs,KernelFallbackOutputType output_type)53 bool KernelFallbackExecute(
54 const tfrt::ExecutionContext& exec_ctx, tfrt::string_view op_name,
55 tfrt::ArrayRef<AsyncValue*> arguments,
56 tfrt::MutableArrayRef<RCReference<AsyncValue>> results,
57 const tfrt::OpAttrsRef& attrs, KernelFallbackOutputType output_type) {
58 // Remove tf. prefix.
59 op_name.consume_front("tf.");
60 std::string op_name_str = op_name.str();
61
62 llvm::SmallVector<RCReference<AsyncValue>, 4> inputs;
63 inputs.reserve(arguments.size());
64 for (AsyncValue* input : arguments) {
65 inputs.push_back(FormRef(input));
66 }
67 llvm::SmallVector<RCReference<AsyncValue>, 4> outputs(results.begin(),
68 results.end());
69
70 // Always run TFRTOpKernel::Compute on the blocking thread pool to
71 // avoid deadlock. Many TF kernels block until their intra-op closures
72 // complete.
73 bool work_enqueued = EnqueueBlockingWork(
74 exec_ctx.host(),
75 [exec_ctx, inputs = std::move(inputs), outputs = std::move(outputs),
76 op_name_str = std::move(op_name_str), attrs = attrs.freeze(),
77 output_type = output_type]() mutable {
78 TFRTOpKernelConstruction op_kernel_construction(attrs);
79 std::unique_ptr<TFRTOpKernel> op =
80 tfrt_forwarding_kernel_factories->CreateKernel(
81 op_name_str, &op_kernel_construction);
82
83 // Forward kernel construction error.
84 if (op_kernel_construction.error().has_value()) {
85 SetError(exec_ctx, &outputs,
86 op_kernel_construction.error().getValue());
87 return;
88 }
89
90 const TFRTOpMeta* op_meta =
91 tfrt_forwarding_op_meta_map->GetOpMeta(op_name_str);
92 if (op_meta == nullptr) {
93 SetError(exec_ctx, &outputs,
94 tfrt::StrCat("No TFRTOpMeta for op_name ", op_name_str));
95 return;
96 }
97
98 TFRTOpKernelContext op_kernel_context(inputs, outputs.size(), op_meta,
99 exec_ctx.host());
100 op->Compute(&op_kernel_context);
101
102 // Forward the context's error or outputs to raii_frame.
103 if (op_kernel_context.error().has_value()) {
104 SetError(exec_ctx, &outputs, op_kernel_context.error().getValue());
105 return;
106 } else {
107 for (int i = 0, e = outputs.size(); i != e; ++i) {
108 // Expected result could be either a tensorflow::Tensor
109 // (in case we call kernel directly), or KernelFallbackTensor
110 // (if called from OpHandler).
111 if (output_type == KernelFallbackOutputType::TENSOR) {
112 outputs[i]->emplace<tensorflow::Tensor>(
113 op_kernel_context.output(i));
114 } else {
115 assert(output_type ==
116 KernelFallbackOutputType::KERNEL_FALLBACK_TENSOR);
117 outputs[i]->emplace<KernelFallbackTensor>(
118 KernelFallbackTensor::Create(op_kernel_context.output(i)));
119 }
120 }
121 }
122 });
123
124 return work_enqueued;
125 }
126 } // namespace tfd
127 } // namespace tensorflow
128