xref: /aosp_15_r20/external/tensorflow/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute.h"
16 
17 #include <assert.h>
18 
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_tensor.h"
27 #include "tensorflow/core/runtime_fallback/kernel/tfrt_op_kernel.h"
28 #include "tfrt/common/compat/eigen/thread_pool_device.h"  // from @tf_runtime
29 #include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
30 #include "tfrt/host_context/async_value.h"  // from @tf_runtime
31 #include "tfrt/host_context/host_context.h"  // from @tf_runtime
32 #include "tfrt/support/forward_decls.h"  // from @tf_runtime
33 #include "tfrt/support/ref_count.h"  // from @tf_runtime
34 
35 namespace tensorflow {
36 namespace tfd {
37 
38 namespace {
39 using tensorflow::KernelFallbackTensor;
40 using tfrt::AsyncValue;
41 using tfrt::RCReference;
42 }  // namespace
43 
SetError(const tfrt::ExecutionContext & exec_ctx,llvm::SmallVector<RCReference<AsyncValue>,4> * results,tfrt::string_view message)44 void SetError(const tfrt::ExecutionContext& exec_ctx,
45               llvm::SmallVector<RCReference<AsyncValue>, 4>* results,
46               tfrt::string_view message) {
47   RCReference<tfrt::ErrorAsyncValue> error = EmitErrorAsync(exec_ctx, message);
48   for (auto& result : *results) {
49     result->SetError(error->GetError());
50   }
51 }
52 
KernelFallbackExecute(const tfrt::ExecutionContext & exec_ctx,tfrt::string_view op_name,tfrt::ArrayRef<AsyncValue * > arguments,tfrt::MutableArrayRef<RCReference<AsyncValue>> results,const tfrt::OpAttrsRef & attrs,KernelFallbackOutputType output_type)53 bool KernelFallbackExecute(
54     const tfrt::ExecutionContext& exec_ctx, tfrt::string_view op_name,
55     tfrt::ArrayRef<AsyncValue*> arguments,
56     tfrt::MutableArrayRef<RCReference<AsyncValue>> results,
57     const tfrt::OpAttrsRef& attrs, KernelFallbackOutputType output_type) {
58   // Remove tf. prefix.
59   op_name.consume_front("tf.");
60   std::string op_name_str = op_name.str();
61 
62   llvm::SmallVector<RCReference<AsyncValue>, 4> inputs;
63   inputs.reserve(arguments.size());
64   for (AsyncValue* input : arguments) {
65     inputs.push_back(FormRef(input));
66   }
67   llvm::SmallVector<RCReference<AsyncValue>, 4> outputs(results.begin(),
68                                                         results.end());
69 
70   // Always run TFRTOpKernel::Compute on the blocking thread pool to
71   // avoid deadlock. Many TF kernels block until their intra-op closures
72   // complete.
73   bool work_enqueued = EnqueueBlockingWork(
74       exec_ctx.host(),
75       [exec_ctx, inputs = std::move(inputs), outputs = std::move(outputs),
76        op_name_str = std::move(op_name_str), attrs = attrs.freeze(),
77        output_type = output_type]() mutable {
78         TFRTOpKernelConstruction op_kernel_construction(attrs);
79         std::unique_ptr<TFRTOpKernel> op =
80             tfrt_forwarding_kernel_factories->CreateKernel(
81                 op_name_str, &op_kernel_construction);
82 
83         // Forward kernel construction error.
84         if (op_kernel_construction.error().has_value()) {
85           SetError(exec_ctx, &outputs,
86                    op_kernel_construction.error().getValue());
87           return;
88         }
89 
90         const TFRTOpMeta* op_meta =
91             tfrt_forwarding_op_meta_map->GetOpMeta(op_name_str);
92         if (op_meta == nullptr) {
93           SetError(exec_ctx, &outputs,
94                    tfrt::StrCat("No TFRTOpMeta for op_name ", op_name_str));
95           return;
96         }
97 
98         TFRTOpKernelContext op_kernel_context(inputs, outputs.size(), op_meta,
99                                               exec_ctx.host());
100         op->Compute(&op_kernel_context);
101 
102         // Forward the context's error or outputs to raii_frame.
103         if (op_kernel_context.error().has_value()) {
104           SetError(exec_ctx, &outputs, op_kernel_context.error().getValue());
105           return;
106         } else {
107           for (int i = 0, e = outputs.size(); i != e; ++i) {
108             // Expected result could be either a tensorflow::Tensor
109             // (in case we call kernel directly), or KernelFallbackTensor
110             // (if called from OpHandler).
111             if (output_type == KernelFallbackOutputType::TENSOR) {
112               outputs[i]->emplace<tensorflow::Tensor>(
113                   op_kernel_context.output(i));
114             } else {
115               assert(output_type ==
116                      KernelFallbackOutputType::KERNEL_FALLBACK_TENSOR);
117               outputs[i]->emplace<KernelFallbackTensor>(
118                   KernelFallbackTensor::Create(op_kernel_context.output(i)));
119             }
120           }
121         }
122       });
123 
124   return work_enqueued;
125 }
126 }  // namespace tfd
127 }  // namespace tensorflow
128