/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements kernels for running TFRT ops/kernels via TF eager
// execution.

#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.h"

#include "absl/strings/str_split.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_tensor.h"
#include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h"
#include "tensorflow/core/runtime_fallback/util/attr_util.h"
#include "tensorflow/core/runtime_fallback/util/tensor_util.h"
#include "tensorflow/core/runtime_fallback/util/type_util.h"
#include "tensorflow/core/tfrt/utils/error_util.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tensorflow/core/tfrt/utils/tensor_util.h"
#include "tfrt/cpu/core_runtime/cpu_op_handler.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
#include "tfrt/core_runtime/execute_op_impl.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attr_type.h"  // from @tf_runtime
#include "tfrt/core_runtime/tensor_handle.h"  // from @tf_runtime
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/attribute_utils.h"  // from @tf_runtime
#include "tfrt/host_context/device.h"  // from @tf_runtime
#include "tfrt/host_context/diagnostic.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/host_buffer.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/kernel_frame.h"  // from @tf_runtime
#include "tfrt/host_context/kernel_utils.h"  // from @tf_runtime
#include "tfrt/host_context/resource_context.h"  // from @tf_runtime
#include "tfrt/host_context/sync_kernel_frame.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
#include "tfrt/tensor/conversion_registry.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/scalar_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/string_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_serialize_utils.h"  // from @tf_runtime
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_gpu_allocator.h"
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {
namespace tfd {
namespace {
constexpr char kHostContextPtrAttrName[] = "host_ptr";
constexpr char kDefaultCpuDevice[] =
    "/job:localhost/replica:0/task:0/device:CPU:0";

}  // namespace

using tfrt::AggregateAttr;
using tfrt::Argument;
using tfrt::AsyncValue;
using tfrt::AsyncValueRef;
using tfrt::BEFAttributeType;
using tfrt::Chain;
using tfrt::DenseAttr;
using tfrt::DenseHostTensor;
using tfrt::ExecutionContext;
using tfrt::Expected;
using tfrt::FuncAttr;
using tfrt::HostBuffer;
using tfrt::HostContext;
using tfrt::KernelErrorHandler;
using tfrt::OpAttrs;
using tfrt::OpAttrsRawEntry;
using tfrt::OpAttrsRef;
using tfrt::OpAttrType;
using tfrt::raw_ostream;
using tfrt::RCReference;
using tfrt::RemainingArguments;
using tfrt::RemainingAttributes;
using tfrt::RemainingResults;
using tfrt::Result;
using tfrt::ShapeAttr;
using tfrt::string_view;
using tfrt::StringAttr;
using tfrt::StringAttribute;
using tfrt::Tensor;
using tfrt::TensorShape;

#define TFD_REPORT_AND_RETURN_IF_ERROR(handler, status) \
  if (!status.ok()) {                                   \
    handler.ReportError(status.error_message());        \
    return;                                             \
  }

// Creates a RuntimeFallbackTensor from a tensorflow::TensorHandle. Takes
// ownership of the TensorHandle.
static AsyncValueRef<RuntimeFallbackTensor> CreateRuntimeFallbackTensor(
    TensorHandle* handle, HostContext* host) {
  OwnedTensorHandle th(handle);
  int rank;
  tensorflow::Status status = th->NumDims(&rank);
  if (!status.ok())
    return tfrt::MakeErrorAsyncValueRef(
        host, tfrt::StrCat("error getting rank from TF tensor handle: ",
                           status.error_message()));

  llvm::SmallVector<tfrt::Index, 4> dims;
  for (auto i = 0; i < rank; ++i) {
    int64_t dim;
    status = th->Dim(i, &dim);
    if (!status.ok())
      return tfrt::MakeErrorAsyncValueRef(
          host, tfrt::StrCat("error getting dimension from TFE tensor handle: ",
                             status.error_message()));
    dims.push_back(dim);
  }

  TensorShape shape{dims};
  DataType dtype = th->DataType();
  return tfrt::MakeAvailableAsyncValueRef<RuntimeFallbackTensor>(
      host, shape, GetTfrtDtype(dtype), std::move(th));
}

// Kernel for moving a DHT to a RuntimeFallbackTensor. Note that the buffer of
// the argument dht is moved into the returned RuntimeFallbackTensor.
//
// Example usage in MLIR:
//
// %tft, %c2 = "tfd.move_dht_to_tft"(%dht, %c1) :
//   (!dht.dense_host_tensor.i32.2, !hex.chain) -> (!tfd.tf_tensor, !hex.chain)
static std::pair<RuntimeFallbackTensor, Chain> TfdMoveDHTToTFT(
    Argument<DenseHostTensor> dht, Argument<Chain> in_chain,
    const ExecutionContext& exec_ctx) {
  return std::make_pair(
      MoveDHTToRuntimeFallbackTensor(std::move(dht.get()), exec_ctx.host()),
      in_chain.get());
}

// Kernel for converting a RuntimeFallbackTensor to a DHT.
//
// Example usage in MLIR:
//
// %dht, %c2 = "tfd.convert_tft_to_dht"(%tft, %c1) :
//   (!tfd.tf_tensor,!hex.chain) -> (!dht.dense_host_tensor.i32.2, !hex.chain)
static void TfdConvertTFTToDHT(Argument<RuntimeFallbackTensor> tft,
                               Argument<Chain> in_chain,
                               Result<DenseHostTensor> dht,
                               Result<Chain> out_chain,
                               KernelErrorHandler handler,
                               const ExecutionContext& exec_ctx) {
  dht.Set(tfrt::ConvertTensorOnHost(exec_ctx, tft.get(),
                                    DenseHostTensor::kTensorType)
              .ReleaseRCRef());
  out_chain.Set(in_chain);
}

// Kernel for printing a RuntimeFallbackTensor.
//
// Example usage in MLIR:
//
// %c2 = "tfd.print_tft"(%tft, %c1) : (!tfd.tf_tensor, !hex.chain) -> !hex.chain
// TODO(fishx): Remove this kernel and reuse dht.print_tensor.
static void TfdPrintTFT(Argument<RuntimeFallbackTensor> tft,
                        Argument<Chain> in_chain, Result<Chain> out_chain) {
  llvm::outs() << tft.get() << "\n";
  llvm::outs().flush();
  out_chain.Set(in_chain);
}

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

static tensorflow::Status InjectTfGpuResourcesHelper(
    tensorflow::EagerContext* ctx) {
  // Inject TF's GPU resources into the TFRT GpuOpHandler.
  // Note that for this to work, the RuntimeFallbackOpHandler must be created
  // and initialized before the tfrt::GpuOpHandler.

  auto tf_gpu_process_state = tensorflow::GPUProcessState::singleton();
  if (tf_gpu_process_state && tf_gpu_process_state->HasGPUDevice()) {
    constexpr char gpu_device_type[] = "GPU";
    int num_gpu = ctx->local_device_mgr()->NumDeviceType(gpu_device_type);
    for (int gpu_ordinal = 0; gpu_ordinal < num_gpu; gpu_ordinal++) {
      auto gpu_device_name = absl::StrCat(gpu_device_type, ":", gpu_ordinal);
      Device* device;
      TF_RETURN_IF_ERROR(
          ctx->local_device_mgr()->LookupDevice(gpu_device_name, &device));
      auto gpu_device = static_cast<tensorflow::BaseGPUDevice*>(device);
      if (!gpu_device)
        return tensorflow::errors::NotFound("TF BaseGPUDevice not found");
#if TENSORFLOW_USE_ROCM
      static_assert(
          false,
          "static_cast to GpuContext and CUstream are invalid for ROCm.");
#endif
      CUcontext gpu_context =
          static_cast<stream_executor::gpu::GpuContext*>(
              gpu_device->executor()->implementation()->GpuContextHack())
              ->context();

      // The TF GPU allocator is already created during TF device
      // initialization (tensorflow::DeviceFactory::AddDevices), so this
      // GetGPUAllocator call ignores the options and total_bytes passed in and
      // retrieves the allocator based on `tf_device_id`.
      TfDeviceId tf_device_id{gpu_ordinal};
      GPUOptions dummy_options;
      tensorflow::Allocator* tf_allocator =
          tf_gpu_process_state->GetGPUAllocator(dummy_options, tf_device_id,
                                                /*total_bytes=*/0,
                                                /*peer_gpu_ids=*/{});
      if (!tf_allocator)
        return tensorflow::errors::NotFound("TF allocator not found");
      auto accelerator_device_info =
          gpu_device->tensorflow_accelerator_device_info();
      if (!accelerator_device_info)
        return tensorflow::errors::NotFound(
            "accelerator_device_info not found");

      tfrt::gpu::GpuResources gpu_resources;
      gpu_resources.gpu_context = tfrt::gpu::wrapper::Context(gpu_context);
      gpu_resources.allocator_factory =
          CreateRuntimeFallbackGpuAllocatorFactory(tf_allocator);
      gpu_resources.stream = tfrt::gpu::wrapper::Stream(static_cast<CUstream>(
          accelerator_device_info->stream->implementation()->GpuStreamHack()));
      auto platform = tfrt::gpu::wrapper::Platform::CUDA;
      tfrt::gpu::SetTfrtGpuResources(
          tfrt::gpu::wrapper::Device(gpu_ordinal, platform), gpu_resources);
    }
  }
  return OkStatus();
}

tensorflow::Status InjectTfGpuResources() {
  // TODO(zhangqiaorjc): Use more direct, lower-level APIs to initialize GPU
  // resources instead of going through an EagerContext. Note that this
  // EagerContext is strictly locally scoped and an implementation detail of
  // injecting GPU resources; it is not the same EagerContext that is set in
  // the RequestContext.
  static bool already_injected_gpu_devices = false;
  static absl::Mutex* mutex = new absl::Mutex();

  absl::MutexLock lock(mutex);
  if (!already_injected_gpu_devices) {
    tfrt::Expected<OwnedEagerContext> ctx = InitEagerContext();
    if (!ctx) {
      return tensorflow::errors::Internal(
          tfrt::StrCat("error initializing eager context: ", ctx.takeError()));
    }

    // GPU resources should be injected only once per GPU ordinal.
    TF_RETURN_IF_ERROR(InjectTfGpuResourcesHelper(ctx->get()));
    already_injected_gpu_devices = true;
  }

  return OkStatus();
}

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Kernel for initializing the TF EagerContext.
//
// This kernel should be invoked at least once before any TF delegation kernels
// are invoked. Redundant calls to initialize the eager context are skipped.
//
// Example usage in MLIR:
//
// %c2 = "tfd.init_eager_context"(%c1): (!hex.chain) -> !hex.chain
//
static void TfdInitEagerContext(Argument<Chain> in_chain,
                                Result<Chain> out_chain,
                                KernelErrorHandler handler,
                                const ExecutionContext& exec_ctx) {
  tfrt::ResourceContext* resource_context = exec_ctx.resource_context();
  tensorflow::tfd::EagerContextResource* eager_context_resource =
      resource_context
          ->GetOrCreateResource<tensorflow::tfd::EagerContextResource>(
              tensorflow::tfd::kEagerContextResourceName);
  (void)eager_context_resource;

  // TODO(zhangqiaorjc): Inject GPU resources to GPU kernels.
  out_chain.Set(in_chain);
}

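// Wraps the buffer of a DenseHostTensor into a TF_Tensor without copying the
// data. The HostBuffer reference released here is dropped by the TF_Tensor's
// deallocator when the TF_Tensor is destroyed.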
OwnedTFTensor MoveDHTToTFTensor(DenseHostTensor&& dht, HostContext* host) {
  llvm::SmallVector<tfrt::Index, 4> dims;
  dht.shape().GetDimensions(&dims);

  HostBuffer* host_buffer = dht.ReleaseBuffer().release();
  auto deallocator = [](void* data, size_t len, void* arg) {
    auto* host_buffer = reinterpret_cast<HostBuffer*>(arg);
    host_buffer->DropRef();
  };

  CheckBoolCompatibility();
  OwnedTFTensor tf_tensor{
      TF_NewTensor(static_cast<TF_DataType>(GetTfDataType(dht.dtype())),
                   dims.data(), dims.size(), host_buffer->data(),
                   host_buffer->size(), deallocator, host_buffer)};
  return tf_tensor;
}

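// Decodes a BEF DenseAttr into a tensorflow::TensorInterface by first
// deserializing it into a DenseHostTensor and then wrapping that buffer in a
// tensorflow::Tensor.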
static tensorflow::Status DecodeDenseAttrToTensorInterface(
    const DenseAttr& dense_attr, HostContext* host,
    tensorflow::TensorInterface* result) {
  Expected<DenseHostTensor> dht =
      tfrt::DeserializeDenseHostTensorFromDenseAttr(dense_attr, host);
  if (!dht)
    return tensorflow::errors::Internal(tfrt::StrCat(
        "cannot create DenseHostTensor in DecodeDenseAttrToTensorInterface:",
        dht.takeError()));
  OwnedTFTensor tf_tensor = MoveDHTToTFTensor(std::move(*dht), host);
  tensorflow::Tensor t;
  TF_RETURN_IF_ERROR(TF_TensorToTensor(tf_tensor.get(), &t));
  *result = tensorflow::TensorInterface(std::move(t));
  return OkStatus();
}

// Handle attributes.
//
// Refer to tensorflow/core/framework/attr_value.proto and
// tensorflow/c/eager/c_api.h.
//
// Note that we currently do not support the following attribute value types:
// TFE_OpSetAttrFunction
// TFE_OpSetAttrFunctionName
static tensorflow::Status PrepareAttributes(EagerOperation* eager_op,
                                            const OpAttrsRef& attrs,
                                            HostContext* host,
                                            EagerContext* eager_ctx) {
  tensorflow::Status status;
  attrs.IterateEntries([eager_op, eager_ctx, status_ptr = &status, host,
                        &attrs](const OpAttrsRawEntry& entry) {
    // TFE does not expect a device attribute.
    assert(strcmp(entry.name, "device") != 0);
    if (IsUnusedAttribute(entry.name)) {
      return;
    } else if (entry.IsArray()) {
      if (entry.element_count == 0) {
        if (entry.type == OpAttrType::CHAR) {
          // Empty string.
          std::string empty_str;
          *status_ptr = eager_op->SetAttrString(entry.name, empty_str.data(),
                                                empty_str.size());
        } else {
          // Empty array of other types.
          AttrValue empty_attr_value;
          eager_op->MutableAttrs()->Set(entry.name, empty_attr_value);
        }
      } else if (entry.type == OpAttrType::CHAR) {
        string_view attr_value = attrs.GetStringAsserting(entry.name);
        *status_ptr = eager_op->SetAttrString(entry.name, attr_value.data(),
                                              attr_value.size());
      } else if (entry.type == OpAttrType::FUNC) {
        string_view attr_value = attrs.GetFuncNameAsserting(entry.name);
        *status_ptr = eager_op->SetAttrFunctionName(
            entry.name, attr_value.data(), attr_value.size());
      } else if (entry.type == OpAttrType::I64) {
        llvm::ArrayRef<int64_t> int_array =
            attrs.GetArrayAsserting<int64_t>(entry.name);
        *status_ptr = eager_op->SetAttrIntList(entry.name, int_array.data(),
                                               int_array.size());
      } else if (entry.type == OpAttrType::F32) {
        llvm::ArrayRef<float> float_array =
            attrs.GetArrayAsserting<float>(entry.name);
        *status_ptr = eager_op->SetAttrFloatList(entry.name, float_array.data(),
                                                 float_array.size());
      } else if (entry.type == OpAttrType::BOOL) {
        llvm::ArrayRef<bool> bool_array =
            attrs.GetArrayAsserting<bool>(entry.name);
        // SetAttrBoolList expects const unsigned char*, not const bool*.
        std::vector<unsigned char> bool_char_array(bool_array.begin(),
                                                   bool_array.end());
        *status_ptr = eager_op->SetAttrBoolList(
            entry.name, bool_char_array.data(), bool_char_array.size());
      } else if (entry.type == OpAttrType::DTYPE) {
        const auto& op_attr = attrs.GetRawAsserting(entry.name);
        assert(op_attr.IsArray());

        // DTypes in BEF attributes are tfrt::DType enums, so we need to
        // convert them to tensorflow data types first.
        auto bef_dtypes = llvm::makeArrayRef(
            static_cast<const tfrt::DType*>(op_attr.GetData()),
            op_attr.element_count);

        llvm::SmallVector<tensorflow::DataType, 4> tf_dtypes;
        tf_dtypes.reserve(bef_dtypes.size());
        for (auto bef_dtype : bef_dtypes) {
          tf_dtypes.push_back(ConvertBefAttrTypeToTfDataType(bef_dtype));
        }

        *status_ptr = eager_op->SetAttrTypeList(entry.name, tf_dtypes.data(),
                                                tf_dtypes.size());
      } else {
        *status_ptr =
            tensorflow::errors::Internal("unsupported array attribute type");
      }
    } else {
      if (entry.type == OpAttrType::I64) {
        int64_t attr_value = attrs.GetAsserting<int64_t>(entry.name);
        *status_ptr = eager_op->SetAttrInt(entry.name, attr_value);
      } else if (entry.type == OpAttrType::F32) {
        float attr_value = attrs.GetAsserting<float>(entry.name);
        *status_ptr = eager_op->SetAttrFloat(entry.name, attr_value);
      } else if (entry.type == OpAttrType::BOOL) {
        bool attr_value = attrs.GetAsserting<bool>(entry.name);
        *status_ptr = eager_op->SetAttrBool(entry.name, attr_value);
      } else if (entry.type == OpAttrType::DTYPE) {
        OpAttrType op_attr_type = attrs.GetAsserting<OpAttrType>(entry.name);
        DataType tf_dtype = ConvertToTfDataType(op_attr_type);
        *status_ptr = eager_op->SetAttrType(entry.name, tf_dtype);
      } else if (entry.type == OpAttrType::SHAPE) {
        tfrt::ShapeAttr shape_attr =
            attrs.GetAsserting<tfrt::ShapeAttr>(entry.name);
        if (shape_attr.HasRank()) {
          *status_ptr = eager_op->SetAttrShape(
              entry.name, shape_attr.GetShape().data(), shape_attr.GetRank());
        } else {
          *status_ptr = eager_op->SetAttrShape(entry.name, /*dims=*/nullptr,
                                               /*num_dims=*/-1);
        }
      } else if (entry.type == OpAttrType::DENSE) {
        DenseAttr dense_attr = attrs.GetAsserting<DenseAttr>(entry.name);
        tensorflow::TensorInterface interface;
        *status_ptr =
            DecodeDenseAttrToTensorInterface(dense_attr, host, &interface);
        if (!status_ptr->ok()) return;
        *status_ptr = eager_op->SetAttrTensor(entry.name, &interface);
      } else if (entry.type == OpAttrType::AGGREGATE) {
        AggregateAttr list_attr = attrs.GetAsserting<AggregateAttr>(entry.name);
        int num_values = list_attr.GetNumElements();

        // Insert a dummy list attribute into the NodeDef if the aggregate attr
        // is empty. This is needed because the ValidateNodeDef method checks
        // the encoded_attr_ map for expected attributes, specified in the
        // OpDef.
        if (num_values == 0) {
          // The int type is just a placeholder and doesn't matter.
          std::vector<int> dummy_attr;
          eager_op->MutableAttrs()->Set(
              entry.name, gtl::ArraySlice<const int>(dummy_attr.data(), 0));
          return;
        }

        // It is guaranteed that items in one list attribute have the same
        // type, though their sizes can be different. In particular,
        // list(TensorShape) and list(Tensor) attribute types have to be
        // encoded as AggregateAttr.
        auto attr_base = list_attr.GetAttribute(0);
        if (IsDataTypeAttribute(attr_base.type()) &&
            GetDataType(attr_base.type()) == tfrt::DType::String) {
          // Handle list(string).
          llvm::SmallVector<const void*, 8> values;
          llvm::SmallVector<size_t, 8> lengths;
          values.reserve(num_values);
          lengths.reserve(num_values);
          for (int i = 0; i < num_values; ++i) {
            auto string_attr = list_attr.GetAttributeOfType<StringAttr>(i);
            values.push_back(string_attr.GetValue().data());
            lengths.push_back(string_attr.GetValue().size());
          }
          *status_ptr = eager_op->SetAttrStringList(entry.name, values.data(),
                                                    lengths.data(), num_values);
        } else if (IsFuncAttribute(attr_base.type())) {
          std::vector<const AbstractOperation*> funcs(num_values);
          for (int i = 0; i < num_values; ++i) {
            auto func_attr = list_attr.GetAttributeOfType<FuncAttr>(i);
            // TODO(chuanhao): Creating an EagerOperation here is expensive.
            // Consider using AttrBuilder to set the attribute directly.
            ImmediateExecutionOperation* new_op = eager_ctx->CreateOperation();
            auto func_name = func_attr.GetFunctionName();
            *status_ptr = new_op->Reset(func_name.str().c_str(),
                                        /*raw_device_name=*/nullptr);
            funcs[i] = new_op;
          }
          *status_ptr =
              eager_op->SetAttrFunctionList(entry.name, absl::MakeSpan(funcs));
        } else if (attr_base.type() == BEFAttributeType::kShape) {
          // Handle list(TensorShape).
          llvm::SmallVector<int, 8> ranks;
          llvm::SmallVector<const int64_t*, 8> dims;
          ranks.reserve(num_values);
          dims.reserve(num_values);
          for (int i = 0; i < num_values; ++i) {
            auto shape_attr = list_attr.GetAttributeOfType<ShapeAttr>(i);
            if (shape_attr.HasRank()) {
              ranks.push_back(shape_attr.GetRank());
              dims.push_back(shape_attr.GetShape().data());
            } else {
              ranks.push_back(-1);
              dims.push_back(nullptr);
            }
          }
          *status_ptr = eager_op->SetAttrShapeList(entry.name, dims.data(),
                                                   ranks.data(), num_values);
        } else {
          *status_ptr =
              tensorflow::errors::Internal("unsupported list attribute type");
        }
      } else {
        *status_ptr =
            tensorflow::errors::Internal("unsupported scalar attribute type");
      }
    }
  });
  return status;
}

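// Builds a TF EagerOperation for `op_name` on `device_name`, adds the given
// input TensorHandles and attributes, and executes it, filling
// `result_tensor_handles` with the outputs.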
Status CallEagerExecute(const tfrt::ExecutionContext& exec_ctx,
                        EagerContext* eager_ctx, const char* op_name,
                        const char* device_name,
                        llvm::ArrayRef<TensorHandle*> input_tensor_handles,
                        const OpAttrsRef& attrs,
                        llvm::MutableArrayRef<tensorflow::AbstractTensorHandle*>
                            result_tensor_handles) {
  assert(eager_ctx != nullptr && "EagerContext is NULL");

  // Create TF EagerOperation.
  OwnedEagerOperation eager_op{new EagerOperation(eager_ctx)};
  TF_RETURN_IF_ERROR(eager_op->Reset(op_name, device_name));

  // Handle inputs.
  for (TensorHandle* input_tensor : input_tensor_handles) {
    TF_RETURN_IF_ERROR(eager_op->AddInput(input_tensor));
  }

  // Handle attributes.
  auto* host = exec_ctx.host();
  TF_RETURN_IF_ERROR(PrepareAttributes(eager_op.get(), attrs, host, eager_ctx));

  int num_retvals = result_tensor_handles.size();
  TF_RETURN_IF_ERROR(eager_op->Execute(
      absl::MakeSpan(result_tensor_handles.data(), num_retvals), &num_retvals));

  return OkStatus();
}

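// Returns true if the op requires the TFRT HostContext pointer to be passed to
// it via a `host_ptr` attribute (currently only TFRTMakeIterator).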
static bool ShouldAddHostContextAttr(const char* op_name) {
  // NOTE(rachelim): In the future, if more ops require this, instead of
  // checking against a whitelist of op names, we could check whether the op
  // contains an attribute called `host_ptr`.
  return strcmp(op_name, "TFRTMakeIterator") == 0;
}

// TODO(zhangqiaorjc): Unify implementation with RuntimeFallbackKernel.
AsyncValueRef<Chain> RuntimeFallbackExecute(
    const tfrt::ExecutionContext& exec_ctx, EagerContext* eager_ctx,
    const char* op_name, const char* device_name,
    llvm::ArrayRef<Tensor*> arguments, const OpAttrsRef& attrs,
    llvm::MutableArrayRef<RCReference<AsyncValue>> results) {
  auto emit_error = [&exec_ctx, results](const tensorflow::Status& status) {
    // Set the correct TFRT error code according to the error propagated from
    // runtime fallback execution.
    auto error =
        EmitErrorAsync(exec_ctx, status.error_message(),
                       tfrt::ConvertTfErrorCodeToTfrtErrorCode(status));
    // Set all results to error.
    std::fill(results.begin(), results.end(), error);
    return error;
  };

  llvm::SmallVector<TensorHandle*, 4> input_tensor_handles;
  input_tensor_handles.reserve(arguments.size());
  for (Tensor* input_tensor : arguments) {
    input_tensor_handles.push_back(
        llvm::cast<RuntimeFallbackTensor>(input_tensor)->GetTensorHandle());
  }

  int num_retvals = results.size();
  llvm::SmallVector<tensorflow::AbstractTensorHandle*, 4> result_tensor_handles(
      num_retvals);
  Status status;
  if (!ShouldAddHostContextAttr(op_name)) {
    status =
        CallEagerExecute(exec_ctx, eager_ctx, op_name, device_name,
                         input_tensor_handles, attrs, result_tensor_handles);
  } else {
    // Wrap the HostContext pointer in an attribute. This is necessary for
    // TF ops that require the TFRT HostContext to function. These kernels
    // should not create their own HostContexts.
    // TODO(rachelim): Support copying over non-host_ptr attrs, if there are
    // any.
    assert(attrs.GetNumEntries() == 1);
    OpAttrs updated;

    updated.Set(kHostContextPtrAttrName,
                reinterpret_cast<int64_t>(exec_ctx.host()));
    status = CallEagerExecute(
        exec_ctx, eager_ctx, op_name, device_name, input_tensor_handles,
        OpAttrsRef(std::move(updated)), result_tensor_handles);
  }
  if (!status.ok()) return emit_error(status);

  auto host = exec_ctx.host();
  for (int i = 0; i < num_retvals; ++i) {
    auto expected_fallback_tensor =
        CreateRuntimeFallbackTensorFromTfTensorHandle(
            OwnedTensorHandle{
                TensorHandleFromInterface(result_tensor_handles[i])},
            host);
    if (!expected_fallback_tensor)
      results[i] = EmitErrorAsync(
          exec_ctx, tfrt::StrCat(expected_fallback_tensor.takeError()));
    else
      results[i] = tfrt::MakeAvailableAsyncValueRef<RuntimeFallbackTensor>(
          host, std::move(*expected_fallback_tensor));
  }

  return tfrt::GetReadyChain();
}

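// Overload that looks up the EagerContext from the execution context before
// delegating to the RuntimeFallbackExecute overload above.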
AsyncValueRef<Chain> RuntimeFallbackExecute(
    const tfrt::ExecutionContext& exec_ctx, const char* op_name,
    const char* device_name, llvm::ArrayRef<Tensor*> arguments,
    const OpAttrsRef& attrs,
    llvm::MutableArrayRef<RCReference<AsyncValue>> results) {
  // Get EagerContext.
  auto eager_ctx_expected = GetEagerContext(exec_ctx);
  if (!eager_ctx_expected) {
    auto error = EmitErrorAsync(exec_ctx, eager_ctx_expected.takeError(),
                                tfrt::ErrorCode::kUnknown);
    // Set all results to error.
    std::fill(results.begin(), results.end(), error);
    return std::move(error);
  }
  EagerContext* eager_ctx = eager_ctx_expected.get();

  return RuntimeFallbackExecute(exec_ctx, eager_ctx, op_name, device_name,
                                arguments, attrs, results);
}

// Kernel to delegate to the current TF runtime kernel.
//
// Example usage in MLIR:
//
// %c2, %tft_c = "tfd.delegate_kernel"(%c1, %tft_a, %tft_b) {op_name = "MatMul"}
// : (!hex.chain, !tfd.tf_tensor, !tfd.tf_tensor) -> (!hex.chain,
// !tfd.tf_tensor)
// TODO(jingdong): Enqueue the TFE kernel execution as a blocking task to the
// ConcurrentWorkQueue.
static void RuntimeFallbackKernel(
    Argument<Chain> in_chain, RemainingArguments input_tensors,
    Result<Chain> out_chain, RemainingResults output_tensors,
    StringAttribute op_name, RemainingAttributes remaining_attributes,
    KernelErrorHandler handler, const ExecutionContext& exec_ctx) {
  HostContext* host = exec_ctx.host();
  tfrt::ResourceContext* resource_context = exec_ctx.resource_context();
  EagerContextResource* eager_context_resource =
      resource_context->GetOrCreateResource<EagerContextResource>(
          tensorflow::tfd::kEagerContextResourceName);
  tfrt::Expected<EagerContext*> eager_ctx_expected =
      eager_context_resource->GetTFEagerContext();
  if (!eager_ctx_expected) {
    handler.ReportError(tfrt::StrCat(eager_ctx_expected.takeError()));
    return;
  }
  EagerContext* eager_ctx = eager_ctx_expected.get();

  // Construct TF EagerOperation.
  // Need to copy op_name to a std::string to ensure the string is
  // null-terminated.
  std::string op_name_str = [&] {
    auto view = op_name.get();
    view.consume_front("tf.");
    return view.str();
  }();

  OwnedEagerOperation eager_op{new EagerOperation(eager_ctx)};
  TFD_REPORT_AND_RETURN_IF_ERROR(
      handler,
      eager_op->Reset(op_name_str.c_str(), /*raw_device_name=*/nullptr));

  // Handle inputs.
  for (AsyncValue* input_tensor_av : input_tensors.values()) {
    auto input_tensor_handle =
        input_tensor_av->get<RuntimeFallbackTensor>().GetTensorHandle();
    TFD_REPORT_AND_RETURN_IF_ERROR(handler,
                                   eager_op->AddInput(input_tensor_handle));
  }

  // Handle TF op attributes.
  // TODO(zhangqiaorjc): Encode TF attributes using native MLIR attribute types.
  assert(remaining_attributes.size() % 2 == 0);
  int num_tf_attrs = remaining_attributes.size() / 2;
  for (int i = 0; i < num_tf_attrs; ++i) {
    // Each TF attribute is represented as a pair of name and value strings.
    // Make a copy for `attr_name` to ensure null-termination.
    std::string attr_name =
        remaining_attributes.GetStringAttribute(i * 2).str();
    absl::string_view attr_value = ToAbslStringView(
        remaining_attributes.GetStringAttribute(i * 2 + 1).get());
    std::vector<absl::string_view> value_split =
        tfd::AttrValueSplit(attr_value);

    // Handle different TF attribute types.
    if (value_split[0] == "string") {
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler,
          eager_op->SetAttrString(attr_name.c_str(), value_split[1].data(),
                                  value_split[1].size()));
    } else if (value_split[0] == "bool") {
      bool bool_val;
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler, ParseBoolAttrValue(value_split[1], &bool_val));
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler, eager_op->SetAttrBool(attr_name.c_str(), bool_val));
    } else if (value_split[0] == "int") {
      int64_t int_val;
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler, ParseIntAttrValue(value_split[1], &int_val));
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler, eager_op->SetAttrInt(attr_name.c_str(), int_val));
    } else if (value_split[0] == "tftensor") {
      tensorflow::Tensor t;
      TFD_REPORT_AND_RETURN_IF_ERROR(handler,
                                     ParseTensorAttrValue(value_split[1], &t));
      tensorflow::TensorInterface interface(t);
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler, eager_op->SetAttrTensor(attr_name.c_str(), &interface));
    } else if (value_split[0] == "tfdtype") {
      DataType dtype;
      TFD_REPORT_AND_RETURN_IF_ERROR(handler,
                                     ParseTfDataType(value_split[1], &dtype));
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler, eager_op->SetAttrType(attr_name.c_str(), dtype));
    } else if (value_split[0] == "tfshape") {
      std::vector<int64_t> dims;
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler, ParseTensorShapeAttrValue(value_split[1], &dims));
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler,
          eager_op->SetAttrShape(attr_name.c_str(), dims.data(), dims.size()));
    } else {
      handler.ReportError("attribute type not yet supported");
      return;
    }
  }

  // Invoke the TF EagerOperation.
  int num_retvals = output_tensors.size();
  llvm::SmallVector<tensorflow::AbstractTensorHandle*, 4> retvals(num_retvals);

  tensorflow::Status status = eager_op->Execute(
      absl::MakeSpan(retvals.data(), num_retvals), &num_retvals);
  TFD_REPORT_AND_RETURN_IF_ERROR(handler, status);

  // Handle outputs.
  if (num_retvals != output_tensors.size()) {
    handler.ReportError("Incorrect number of output values");
    return;
  }
  for (int i = 0; i < num_retvals; ++i) {
    OwnedTensorHandle owned_th{TensorHandleFromInterface(retvals[i])};
    if (!owned_th) handler.ReportError("TensorHandleFromInterface failed");
    auto fallback_tensor = CreateRuntimeFallbackTensorFromTfTensorHandle(
        std::move(owned_th), host);
    if (!fallback_tensor) {
      output_tensors[i] = tfrt::MakeErrorAsyncValueRef(
          host, tfrt::StrCat(fallback_tensor.takeError()));
    } else {
      output_tensors[i] =
          tfrt::MakeAvailableAsyncValueRef<RuntimeFallbackTensor>(
              host, std::move(*fallback_tensor));
    }
  }
  out_chain.Set(in_chain);
}

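// Emits an error async value for the given diagnostic and sets every result to
// that error.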
static void EmitErrorAndSetInResults(
    const tfrt::ExecutionContext& exec_ctx,
    const tfrt::DecodedDiagnostic& error,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results) {
  auto error_av = tfrt::EmitErrorAsync(exec_ctx, error.message, error.code);
  std::fill(results.begin(), results.end(), error_av);
}

// Converts the tfrt::TensorHandles to tensorflow::Tensors. `device` is the
// target device for the converted tensorflow::Tensors.
//
// TODO(tfrt-devs): Currently the target device can only be CPU. We need to add
// support for more devices.
void CoreRTTensorHandleToFallbackTensorInternal(
    llvm::ArrayRef<tfrt::AsyncValue*> tensorhandle_args,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>>
        tf_tensor_results,
    tfrt::string_view device, const tfrt::ExecutionContext& exec_ctx) {
  assert(tensorhandle_args.size() == tf_tensor_results.size());

  auto set_result = [&](tfrt::RCReference<tfrt::AsyncValue>& result,
                        llvm::Expected<tensorflow::Tensor> tf_tensor) {
    auto result_ref = tfrt::MakeUnconstructedAsyncValueRef<
        tensorflow::tfrt_stub::FallbackTensor>();
    if (!tf_tensor) {
      result_ref.SetError(tfrt::StrCat(tf_tensor.takeError()));
    } else {
      result_ref.emplace(std::move(tf_tensor.get()));
    }
    result = std::move(result_ref);
  };

  auto maybe_convert_runtime_fallback_tensor =
      [&exec_ctx](
          tfrt::AsyncValueRef<Tensor> tensor_avref,
          const tfrt::Device& src_device,
          const tfrt::Device& dst_device) -> tfrt::AsyncValueRef<tfrt::Tensor> {
    // TODO(tfrt-devs): Remove the implicit conversion in this kernel since it
    // adds extra overhead.
    // Convert RuntimeFallbackTensor to KernelFallbackTensor before calling
    // into kernel fallback. This is needed because kernel fallback cannot read
    // a tensor from a runtime fallback tensor today, since we don't want
    // kernel fallback to depend on runtime fallback.
    assert(tensor_avref.IsAvailable());
    assert(!tensor_avref.IsError());
    auto& tensor = tensor_avref.get();
    if (!tensor.IsTensorType(DenseHostTensor::kTensorType) ||
        !src_device.IsDeviceType(tfrt::CpuDevice::kDeviceType) ||
        !dst_device.IsDeviceType(tfrt::CpuDevice::kDeviceType)) {
      return tfrt::ConvertTensor(exec_ctx, tensor,
                                 /*src=*/src_device,
                                 /*dst=*/dst_device,
                                 KernelFallbackTensor::kTensorType);
    }
    return tensor_avref;
  };

  auto dst_device =
      exec_ctx.host()->GetDeviceManager()->GetDeviceRef<tfrt::Device>(device);

  // Retrieve the underlying pointer of tfrt::Tensor. We don't need to do
  // extra ownership management here because KernelFallbackExecuteCompat()
  // will always convert it to tensorflow::Tensor which is itself refcounted.
  for (int i = 0; i < tensorhandle_args.size(); ++i) {
    if (!dst_device) {
      tf_tensor_results[i] = tfrt::MakeErrorAsyncValueRef(
          tfrt::StrCat("Failed to find device with name ", device));
      continue;
    }
    auto& tensor_handle = tensorhandle_args[i]->get<tfrt::TensorHandle>();
    assert(tensor_handle.IsDeviceAvailable());
    assert(!tensor_handle.IsDeviceError());

    auto* tensor_av = tensor_handle.GetAsyncTensor();
    auto tensor_avref = tfrt::AsyncValueRef<Tensor>(FormRef(tensor_av));

    auto& src_device = *tensor_handle.GetAvailableDevice();
    AsyncValueRef<Tensor> knfb_tensor;
    if (!tensor_av->IsAvailable()) {
      auto ind_av = tfrt::MakeIndirectAsyncValue(exec_ctx.host());
      knfb_tensor = AsyncValueRef<Tensor>(ind_av.CopyRef());
      tensor_av->AndThen(
          [tensor_avref = std::move(tensor_avref), ind_av = std::move(ind_av),
           &src_device, dst_device = dst_device.CopyRef(),
           maybe_convert_runtime_fallback_tensor, exec_ctx]() mutable {
            ind_av->ForwardTo(maybe_convert_runtime_fallback_tensor(
                std::move(tensor_avref), src_device, *dst_device));
          });
    } else {
      knfb_tensor = maybe_convert_runtime_fallback_tensor(
          std::move(tensor_avref), src_device, *dst_device);
    }

    if (!knfb_tensor.IsAvailable()) {
      auto result_ref = tfrt::MakeIndirectAsyncValue(exec_ctx.host());
      tf_tensor_results[i] = result_ref;
      auto knfb_tensor_av = knfb_tensor.GetAsyncValue();
      knfb_tensor_av->AndThen([knfb_tensor = std::move(knfb_tensor),
                               result_ref = std::move(result_ref),
                               dst_device = dst_device.CopyRef(),
                               exec_ctx]() mutable {
        if (knfb_tensor.IsError()) {
          result_ref->ForwardTo(std::move(knfb_tensor));
          return;
        }
        auto expected_tf_tensor =
            TFRTTensorToTFTensor(knfb_tensor.get(), exec_ctx.host());
        if (!expected_tf_tensor) {
          auto error =
              tfrt::EmitErrorAsync(exec_ctx, expected_tf_tensor.takeError());
          result_ref->ForwardTo(std::move(error));
        } else {
          auto tf_tensor_ref = tfrt::MakeAvailableAsyncValueRef<
              tensorflow::tfrt_stub::FallbackTensor>(
              std::move(expected_tf_tensor.get()));
          result_ref->ForwardTo(std::move(tf_tensor_ref));
        }
      });
    } else {
      set_result(tf_tensor_results[i],
                 TFRTTensorToTFTensor(knfb_tensor.get(), exec_ctx.host()));
    }
  }
}

// Returns true if the tensorflow::DataType is trivially copyable.
static bool IsTriviallyCopyableTensorflowDataType(tensorflow::DataType dtype) {
  static const auto* const non_trivially_copyable_dtypes =
      new absl::flat_hash_set<tensorflow::DataType>{
          tensorflow::DataType::DT_STRING, tensorflow::DataType::DT_RESOURCE,
          tensorflow::DataType::DT_VARIANT};
  return !non_trivially_copyable_dtypes->contains(dtype);
}

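// Creates a constant fallback tensor from a BEF dense attribute by copying the
// attribute's raw bytes into a tensorflow::Tensor of the corresponding dtype
// and shape. Only trivially copyable dtypes are supported.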
static llvm::Expected<tensorflow::tfrt_stub::FallbackTensor> ConstDenseTensor(
    tfrt::DenseAttr value, const tfrt::ExecutionContext& context) {
  auto dtype = GetTfDataType(tfrt::DType(value.dtype()));
  // The data type must be trivially copyable so that we can use memcpy.
  DCHECK(IsTriviallyCopyableTensorflowDataType(dtype));
  tensorflow::Tensor tensor(dtype, tensorflow::TensorShape(value.shape()));
  std::memcpy(tensor.data(), value.GetElements(), tensor.TotalBytes());
  return tensorflow::tfrt_stub::FallbackTensor(tensor);
}

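// Creates a DT_STRING constant fallback tensor from a shape attribute and an
// aggregate attribute of string values. If the aggregate holds a single
// element, that value is broadcast to every element of the tensor.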
static llvm::Expected<tensorflow::tfrt_stub::FallbackTensor> ConstStringTensor(
    tfrt::ArrayAttr shape, tfrt::AggregateAttr value,
    const ExecutionContext& context) {
  llvm::SmallVector<int64_t> dims;
  auto tfrt_tensor_shape = tfrt::TensorShape(shape.GetValue<int64_t>());
  tfrt_tensor_shape.GetDimensions(&dims);
  tensorflow::Tensor tensor(tensorflow::DT_STRING,
                            tensorflow::TensorShape(dims));
  auto len = tensor.NumElements();
  auto from = value;
  auto to = tensor.flat<tensorflow::tstring>();
  if (from.GetNumElements() == 1) {
    // All elements are the same, and only one element is saved in BEF.
    for (size_t i = 0; i < len; ++i) {
      to(i) =
          ToAbslStringView(from.GetAttributeOfType<StringAttr>(0).GetValue());
    }
  } else {
    assert(len == from.GetNumElements());
    for (size_t i = 0; i < len; ++i) {
      to(i) =
          ToAbslStringView(from.GetAttributeOfType<StringAttr>(i).GetValue());
    }
  }
  return tensorflow::tfrt_stub::FallbackTensor(tensor);
}

// The BEF kernel for tfrt::TensorHandle to tensorflow::Tensor conversion.
void CoreRTTensorHandleToFallbackTensor(
    RemainingArguments args, RemainingResults results, StringAttr device,
    const tfrt::ExecutionContext& exec_ctx) {
  tensorflow::profiler::TraceMe trace_me(
      "corert_tensorhandle_to_fallback_tensor");
  trace_me.AppendMetadata([request_id = exec_ctx.request_ctx()->id()]() {
    return tensorflow::profiler::TraceMeEncode({{"id", request_id}});
  });

  CoreRTTensorHandleToFallbackTensorInternal(args.values(), results.values(),
                                             device.GetValue(), exec_ctx);
}

// Convert the tensorflow::Tensor to tfrt::TensorHandle. `device` is the device
// for the input tensorflow::Tensor.
//
// TODO(tfrt-devs): Currently the input device can only be CPU. We need to add
// support for more devices.
static void FallbackTensorToCoreRTTensorHandleInternal(
    llvm::ArrayRef<tfrt::AsyncValue*> tf_tensor_args,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>>
        tensorhandle_results,
    absl::string_view device, const tfrt::ExecutionContext& exec_ctx) {
  auto* host = exec_ctx.host();

  assert(tf_tensor_args.size() == tensorhandle_results.size());
  for (int i = 0; i < tf_tensor_args.size(); ++i) {
    auto* av = tf_tensor_args[i];
    auto& tf_tensor = av->get<tensorflow::tfrt_stub::FallbackTensor>().tensor();
    AsyncValueRef<tfrt::Tensor> kernel_fallback_tensor =
        tfrt::MakeAvailableAsyncValueRef<KernelFallbackTensor>(tf_tensor);
    auto metadata = kernel_fallback_tensor.get().metadata();

    tensorhandle_results[i] =
        tfrt::MakeAvailableAsyncValueRef<tfrt::TensorHandle>(
            host,
            host->GetDeviceManager()->GetDeviceRef<tfrt::Device>(
                {device.data(), device.size()}),
            metadata, std::move(kernel_fallback_tensor));
  }
}

// The BEF kernel for tensorflow::Tensor to tfrt::TensorHandle conversion.
void FallbackTensorToCoreRTTensorHandle(
    RemainingArguments args, RemainingResults results, StringAttr device,
    const tfrt::ExecutionContext& exec_ctx) {
  tensorflow::profiler::TraceMe trace_me(
      "fallback_tensor_to_corert_tensorhandle");
  trace_me.AppendMetadata([request_id = exec_ctx.request_ctx()->id()]() {
    return tensorflow::profiler::TraceMeEncode({{"id", request_id}});
  });

  FallbackTensorToCoreRTTensorHandleInternal(
      args.values(), results.values(), ToAbslStringView(device.GetValue()),
      exec_ctx);
}

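// Converts a fallback tensor to a boolean predicate: a scalar is true if it is
// non-zero (or a non-empty string); a non-scalar tensor is true if it has at
// least one element.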
static llvm::Expected<bool> Predicate(
    const tensorflow::tfrt_stub::FallbackTensor& input,
    const tfrt::ExecutionContext& exec_ctx) {
  const auto& tensor = input.tensor();
  if (TensorShapeUtils::IsScalar(tensor.shape())) {
    switch (tensor.dtype()) {
#define CASE(T)                  \
  case DataTypeToEnum<T>::value: \
    return tensor.scalar<T>()() != 0;

      CASE(float);
      CASE(double);
      CASE(uint8);
      CASE(int8);
      CASE(int16);
      CASE(int32);
      CASE(int64_t);
      CASE(bool);
#undef CASE
      case DT_STRING:
        return !tensor.scalar<tstring>()().empty();
      default:
        return tfrt::MakeStringError(DataTypeString(tensor.dtype()),
                                     " cannot be converted to a boolean");
    }
  }

  return tensor.NumElements() > 0;
}

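// Prints the DebugString of a fallback tensor to stdout and returns a chain
// that can be used to order side effects.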
PrintFallbackTensor(const tensorflow::tfrt_stub::FallbackTensor & arg,const tfrt::Chain & ch)1065 tfrt::Chain PrintFallbackTensor(
1066     const tensorflow::tfrt_stub::FallbackTensor& arg, const tfrt::Chain& ch) {
1067   std::string message;
1068   llvm::raw_string_ostream(message) << arg.tensor().DebugString() << "\n";
1069   printf("%s", message.c_str());
1070   fflush(stdout);
1071   return tfrt::Chain();
1072 }
1073 
1074 // The legacy kernel implementation that dispatches runtime fallback operations.
1075 // Since the arguments and results are tensorflow::Tensors, internally it
1076 // does conversions between RuntimeFallbackTensor and tensorflow::Tensor.
RuntimeFallbackExecuteOp(RemainingArguments args,RemainingResults results,StringAttr device_attr,AggregateAttr op_attr_array,AggregateAttr op_func_attr_array,StringAttr op_name_attr,tfrt::AsyncValueRef<tfrt::Chain> * op_chain,const ExecutionContext & exec_ctx)1077 static void RuntimeFallbackExecuteOp(
1078     RemainingArguments args, RemainingResults results, StringAttr device_attr,
1079     AggregateAttr op_attr_array, AggregateAttr op_func_attr_array,
1080     StringAttr op_name_attr, tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
1081     const ExecutionContext& exec_ctx) {
1082   auto set_error = [&exec_ctx, results](tfrt::string_view msg) {
1083     auto error_av = EmitErrorAsync(exec_ctx, msg, tfrt::ErrorCode::kUnknown);
1084     // Set all results to error.
1085     for (int i = 0, n = results.size(); i < n; ++i) results[i] = error_av;
1086   };
1087 
1088   auto op_name = op_name_attr.GetValue();
1089   op_name.consume_front("tf.");
1090 
1091   // The device name might not be in the format expected by tensorflow. In that
1092   // case we change it to the correct format. Currently we only support CPU.
1093   //
1094   // TODO(tfrt-devs): Make sure device names passed to fallback kernels are in
1095   // the tensorflow format.
1096   std::string device_name = device_attr.GetValue().str();
1097   if (!absl::StartsWith(device_name, "/")) device_name = kDefaultCpuDevice;

  auto* host = exec_ctx.host();

  // Set up OpAttrs.
  tfrt::OpAttrs op_attrs;
  tfrt::SetUpOpAttrs(op_attr_array, &op_attrs);

  // Set up OpAttrs specifically for function attributes.
  tfrt::SetUpOpFuncAttrs(op_func_attr_array, &op_attrs);

  // Get EagerContext.
  auto eager_ctx_expected = GetEagerContext(exec_ctx);
  if (!eager_ctx_expected) {
    set_error(tfrt::StrCat(eager_ctx_expected.takeError()));
    return;
  }
  EagerContext* eager_ctx = eager_ctx_expected.get();

  // Get device.
  Device* device = nullptr;
  Status s = eager_ctx->local_device_mgr()->LookupDevice(device_name, &device);
  if (!s.ok()) {
    // The device name can be invalid in certain cases. Use default CPU device.
    VLOG(1) << s.error_message() << " using default CPU device.";
  }

  // First we convert tensorflow::Tensor to RuntimeFallbackTensors.
  llvm::SmallVector<RuntimeFallbackTensor, 4> tfrt_tensor_args;
  tfrt_tensor_args.reserve(args.size());
  for (int i = 0; i < args.size(); ++i) {
    auto* av = args[i];
    auto tf_tensor = av->get<tensorflow::Tensor>();

    tfrt::TensorMetadata md = tfd::GetTensorMetadata(tf_tensor);
    OwnedTensorHandle tensor_handle{tensorflow::TensorHandle::CreateLocalHandle(
        std::move(tf_tensor), /*d=*/device, /*op_device=*/device, eager_ctx)};

    tfrt_tensor_args.push_back(
        RuntimeFallbackTensor(md.shape, md.dtype, std::move(tensor_handle)));
  }

  llvm::SmallVector<tfrt::Tensor*, 4> tfrt_tensor_arg_ptrs;
  tfrt_tensor_arg_ptrs.reserve(args.size());
  for (auto& tensor : tfrt_tensor_args) tfrt_tensor_arg_ptrs.push_back(&tensor);

  llvm::SmallVector<RCReference<tfrt::AsyncValue>, 4> tfrt_tensor_results;
  tfrt_tensor_results.resize(results.size());

  auto chain = RuntimeFallbackExecute(
      exec_ctx, op_name.str().c_str(), device_name.c_str(),
      tfrt_tensor_arg_ptrs, tfrt::OpAttrsRef(op_attrs), tfrt_tensor_results);

  if (op_chain) *op_chain = chain.CopyRef();

  // After coreruntime returns, we check if there is any error. Currently we
  // assume runtime fallback execution is always synchronous.
  DCHECK(chain.IsAvailable());
  if (chain.IsError()) {
    EmitErrorAndSetInResults(exec_ctx, chain.GetError(), results.values());
    return;
  }

  // Finally we convert the runtime fallback results, which are
  // RuntimeFallbackTensors, back to the tensorflow::Tensors expected by the
  // BEF kernel.
  for (int i = 0; i < results.size(); ++i) {
    auto& runtime_fallback_tensor =
        tfrt_tensor_results[i]->get<RuntimeFallbackTensor>();
    const tensorflow::Tensor* tf_tensor = nullptr;
    tensorflow::Status s =
        runtime_fallback_tensor.GetTensorHandle()->Tensor(&tf_tensor);
    DCHECK(s.ok()) << s.ToString();
    results[i] =
        tfrt::MakeAvailableAsyncValueRef<tensorflow::Tensor>(host, *tf_tensor);
  }
}

Chain AddRuntimeFallbackImplicitConversionKernel(
    Argument<tfrt::OpHandler*> op_handler, const ExecutionContext& exec_ctx) {
  assert(op_handler.get()->GetName() == tfrt::CpuOpHandler::kName);
  tfrt::CpuOpHandler* cpu_op_handler =
      reinterpret_cast<tfrt::CpuOpHandler*>(op_handler.get());
  cpu_op_handler->AddImplicitConversion(RuntimeFallbackTensor::kTensorType,
                                        DenseHostTensor::kTensorType);
  cpu_op_handler->AddImplicitConversion(RuntimeFallbackTensor::kTensorType,
                                        tfrt::AnyScalarHostTensor::kTensorType);
  cpu_op_handler->AddImplicitConversion(RuntimeFallbackTensor::kTensorType,
                                        tfrt::StringHostTensor::kTensorType);
  return {};
}
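// Note on the calls above: each AddImplicitConversion registers a conversion
// from RuntimeFallbackTensor to the listed host tensor type (dense, any-scalar
// or string), so ops placed on the CPU op handler can consume fallback tensors
// without an explicit conversion kernel. (Descriptive note; the actual
// conversion logic lives in tfrt::CpuOpHandler.)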

void CreateRuntimeFallbackOpHandlerKernel(Result<tfrt::OpHandler*> op_handler,
                                          StringAttribute tf_device_name,
                                          const ExecutionContext& exec_ctx) {
  auto* runtime = tfrt::CoreRuntime::GetFromHostContext(exec_ctx.host());
  assert(runtime);
  auto op_handler_ptr =
      CreateRuntimeFallbackOpHandler(runtime, tf_device_name.get());
  assert(op_handler_ptr);
  op_handler.Emplace(op_handler_ptr.get());
}

static OwnedTensorHandle ConvertTFRTTensorToTFTensorHandle(
    tfrt::Tensor* tensor) {
  if (auto* dht = llvm::dyn_cast<tfrt::DenseHostTensor>(tensor)) {
    tensorflow::Tensor tensor =
        MoveHostBufferToTfTensor(dht->buffer(), dht->dtype(), dht->shape());

    return OwnedTensorHandle{
        tensorflow::TensorHandle::CreateLocalHandle(tensor)};
  }

  if (auto* sht = llvm::dyn_cast<tfrt::StringHostTensor>(tensor)) {
    tensorflow::Tensor tensor = CopyShtToTfTensor(*sht);
    return OwnedTensorHandle{
        tensorflow::TensorHandle::CreateLocalHandle(tensor)};
  }

  llvm_unreachable("unsupported tensor type");
}

static llvm::Expected<tfrt::Value> ConvertTFTensorHandleToTFRTTensor(
    OwnedTensorHandle tensor_handle, HostContext* host) {
  tensorflow::Status status;
  // Resolve ensures Tensor is on host CPU.
  OwnedAbstractTensorInterface tensor_interface{
      tensor_handle->Resolve(&status)};
  if (!status.ok()) {
    return tfrt::MakeStringError("error resolving TensorHandle: ",
                                 status.error_message());
  }
  auto tf_dtype = tensor_interface->Type();
  if (tf_dtype == DT_STRING) {
    // TODO(tfrt-devs): Consider a more efficient way to pass string
    // tensors between TFRT and TF.
    auto string_host_tensor =
        CopyTfStringTensorToStringHostTensor(tensor_interface.get(), host);
    if (!string_host_tensor)
      return tfrt::MakeStringError(
          "error converting TF string tensor to tfrt::StringHostTensor: ",
          string_host_tensor.takeError());
    return tfrt::Value(std::move(*string_host_tensor));
  }

  tfrt::TensorMetadata metadata(GetTfrtDtype(tf_dtype),
                                GetShape(tensor_interface.get()));

  CheckBoolCompatibility();
  void* data = tensor_interface->Data();
  size_t size = tensor_interface->ByteSize();
  // `tensor_interface` holds a reference on the underlying TensorFlow buffer.
  // It is kept alive by the HostBuffer deallocator's lambda capture (an
  // llvm::unique_function) and is released when the deallocator is invoked and
  // destroyed.
  auto host_buffer = HostBuffer::CreateFromExternal(
      data, size,
      [tensor_interface = std::move(tensor_interface)](void*, size_t) {});

  tfrt::Value value;
  value.emplace<DenseHostTensor>(metadata, std::move(host_buffer));
  return std::move(value);
}
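// Illustrative sketch of the external-buffer ownership pattern used above,
// with hypothetical placeholder names (`data_ptr`, `num_bytes`, `owner` are
// not part of this file):
//
//   auto buffer = tfrt::HostBuffer::CreateFromExternal(
//       data_ptr, num_bytes,
//       [owner = std::move(owner)](void* /*data*/, size_t /*size*/) {
//         // `owner` is destroyed along with the deallocator, releasing the
//         // external storage.
//       });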

void RegisterTfdDelegateKernels(tfrt::KernelRegistry* registry) {
  registry->AddKernel("tfd.init_eager_context",
                      TFRT_KERNEL(TfdInitEagerContext));
  registry->AddKernel("tfd.delegate_kernel",
                      TFRT_KERNEL(RuntimeFallbackKernel));
  registry->AddKernel("tfd.move_dht_to_tft", TFRT_KERNEL(TfdMoveDHTToTFT));
  registry->AddKernel("tfd.convert_tft_to_dht",
                      TFRT_KERNEL(TfdConvertTFTToDHT));
  registry->AddKernel("tfd.print_tft", TFRT_KERNEL(TfdPrintTFT));
  registry->AddKernel("tfrt_fallback_async.const_dense_tensor",
                      TFRT_KERNEL(ConstDenseTensor));
  registry->AddKernel("tfrt_fallback_async.const_string_tensor",
                      TFRT_KERNEL(ConstStringTensor));
  registry->AddKernel(
      "tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor",
      TFRT_KERNEL(CoreRTTensorHandleToFallbackTensor));
  registry->AddKernel(
      "tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle",
      TFRT_KERNEL(FallbackTensorToCoreRTTensorHandle));

  // TODO(b/187106271): Move fallback kernels to fallback only libraries so that
  // we don't have to depend on or link in corert kernels.
  registry->AddKernel("tfrt_fallback_async.predicate", TFRT_KERNEL(Predicate));
  registry->AddKernel("tfrt_fallback_async.print_tensor",
                      TFRT_KERNEL(PrintFallbackTensor));
  registry->AddKernel("corert.create_runtime_fallback_op_handler",
                      TFRT_KERNEL(CreateRuntimeFallbackOpHandlerKernel));
  registry->AddKernel("corert.add_runtime_fallback_implicit_conversions",
                      TFRT_KERNEL(AddRuntimeFallbackImplicitConversionKernel));
}
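// Example usage (illustrative sketch only; this function is normally invoked
// by the runtime's kernel registration setup rather than called directly):
//
//   tfrt::KernelRegistry registry;
//   tensorflow::tfd::RegisterTfdDelegateKernels(&registry);
//   // Kernels such as "tfd.delegate_kernel" and
//   // "tfrt_fallback_async.print_tensor" can now be resolved by name.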

}  // namespace tfd
}  // namespace tensorflow