/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements kernels for running TFRT ops/kernels via TF eager
// execution.

#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.h"

#include "absl/strings/str_split.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_tensor.h"
#include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h"
#include "tensorflow/core/runtime_fallback/util/attr_util.h"
#include "tensorflow/core/runtime_fallback/util/tensor_util.h"
#include "tensorflow/core/runtime_fallback/util/type_util.h"
#include "tensorflow/core/tfrt/utils/error_util.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tensorflow/core/tfrt/utils/tensor_util.h"
#include "tfrt/cpu/core_runtime/cpu_op_handler.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
#include "tfrt/core_runtime/execute_op_impl.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attr_type.h"  // from @tf_runtime
#include "tfrt/core_runtime/tensor_handle.h"  // from @tf_runtime
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/attribute_utils.h"  // from @tf_runtime
#include "tfrt/host_context/device.h"  // from @tf_runtime
#include "tfrt/host_context/diagnostic.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/host_buffer.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/kernel_frame.h"  // from @tf_runtime
#include "tfrt/host_context/kernel_utils.h"  // from @tf_runtime
#include "tfrt/host_context/resource_context.h"  // from @tf_runtime
#include "tfrt/host_context/sync_kernel_frame.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
#include "tfrt/tensor/conversion_registry.h"  // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/scalar_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/string_host_tensor.h"  // from @tf_runtime
#include "tfrt/tensor/tensor_serialize_utils.h"  // from @tf_runtime
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/runtime_fallback/runtime/runtime_fallback_gpu_allocator.h"
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {
namespace tfd {
namespace {
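// Name of the attribute used to pass the TFRT HostContext pointer to the few
// TF ops that need it (see ShouldAddHostContextAttr below).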
constexpr char kHostContextPtrAttrName[] = "host_ptr";
constexpr char kDefaultCpuDevice[] =
    "/job:localhost/replica:0/task:0/device:CPU:0";

}  // namespace

using tfrt::AggregateAttr;
using tfrt::Argument;
using tfrt::AsyncValue;
using tfrt::AsyncValueRef;
using tfrt::BEFAttributeType;
using tfrt::Chain;
using tfrt::DenseAttr;
using tfrt::DenseHostTensor;
using tfrt::ExecutionContext;
using tfrt::Expected;
using tfrt::FuncAttr;
using tfrt::HostBuffer;
using tfrt::HostContext;
using tfrt::KernelErrorHandler;
using tfrt::OpAttrs;
using tfrt::OpAttrsRawEntry;
using tfrt::OpAttrsRef;
using tfrt::OpAttrType;
using tfrt::raw_ostream;
using tfrt::RCReference;
using tfrt::RemainingArguments;
using tfrt::RemainingAttributes;
using tfrt::RemainingResults;
using tfrt::Result;
using tfrt::ShapeAttr;
using tfrt::string_view;
using tfrt::StringAttr;
using tfrt::StringAttribute;
using tfrt::Tensor;
using tfrt::TensorShape;

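// Reports `status`'s error message on `handler` and returns from the calling
// kernel if `status` is not OK.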
#define TFD_REPORT_AND_RETURN_IF_ERROR(handler, status) \
  if (!status.ok()) {                                   \
    handler.ReportError(status.error_message());        \
    return;                                             \
  }

// Create RuntimeFallbackTensor from tensorflow::TensorHandle.
// Takes ownership of TensorHandle.
static AsyncValueRef<RuntimeFallbackTensor> CreateRuntimeFallbackTensor(
    TensorHandle* handle, HostContext* host) {
  OwnedTensorHandle th(handle);
  int rank;
  tensorflow::Status status = th->NumDims(&rank);
  if (!status.ok())
    return tfrt::MakeErrorAsyncValueRef(
        host, tfrt::StrCat("error getting rank from TF tensor handle: ",
                           status.error_message()));

  llvm::SmallVector<tfrt::Index, 4> dims;
  for (auto i = 0; i < rank; ++i) {
    int64_t dim;
    status = th->Dim(i, &dim);
    if (!status.ok())
      return tfrt::MakeErrorAsyncValueRef(
          host, tfrt::StrCat("error getting dimension from TFE tensor handle: ",
                             status.error_message()));
    dims.push_back(dim);
  }

  TensorShape shape{dims};
  DataType dtype = th->DataType();
  return tfrt::MakeAvailableAsyncValueRef<RuntimeFallbackTensor>(
      host, shape, GetTfrtDtype(dtype), std::move(th));
}

// Kernel for moving a DHT to a RuntimeFallbackTensor. Note that the buffer of
// the argument dht is moved into the returned RuntimeFallbackTensor.
//
// Example usage in MLIR:
//
// %tft, %c2 = "tfd.move_dht_to_tft"(%dht, %c1) :
// (!dht.dense_host_tensor.i32.2, !hex.chain) -> (!tfd.tf_tensor, !hex.chain)
static std::pair<RuntimeFallbackTensor, Chain> TfdMoveDHTToTFT(
    Argument<DenseHostTensor> dht, Argument<Chain> in_chain,
    const ExecutionContext& exec_ctx) {
  return std::make_pair(
      MoveDHTToRuntimeFallbackTensor(std::move(dht.get()), exec_ctx.host()),
      in_chain.get());
}

// Kernel for converting a RuntimeFallbackTensor to a DHT.
//
// Example usage in MLIR:
//
// %dht, %c2 = "tfd.convert_tft_to_dht"(%tft, %c1) :
// (!tfd.tf_tensor, !hex.chain) -> (!dht.dense_host_tensor.i32.2, !hex.chain)
static void TfdConvertTFTToDHT(Argument<RuntimeFallbackTensor> tft,
                               Argument<Chain> in_chain,
                               Result<DenseHostTensor> dht,
                               Result<Chain> out_chain,
                               KernelErrorHandler handler,
                               const ExecutionContext& exec_ctx) {
  dht.Set(tfrt::ConvertTensorOnHost(exec_ctx, tft.get(),
                                    DenseHostTensor::kTensorType)
              .ReleaseRCRef());
  out_chain.Set(in_chain);
}
// Kernel for printing RuntimeFallbackTensor.
//
// Example usage in MLIR:
//
// %c2 = "tfd.print_tft"(%tft, %c1) : (!tfd.tf_tensor, !hex.chain) -> !hex.chain
// TODO(fishx): Remove this kernel and reuse dht.print_tensor.
static void TfdPrintTFT(Argument<RuntimeFallbackTensor> tft,
                        Argument<Chain> in_chain, Result<Chain> out_chain) {
  llvm::outs() << tft.get() << "\n";
  llvm::outs().flush();
  out_chain.Set(in_chain);
}

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

static tensorflow::Status InjectTfGpuResourcesHelper(
    tensorflow::EagerContext* ctx) {
  // Inject TF's GPU resources to TFRT GpuOpHandler.
  // Note that this requires RuntimeFallbackOpHandler to be created and
  // initialized before tfrt::GpuOpHandler to work.

  auto tf_gpu_process_state = tensorflow::GPUProcessState::singleton();
  if (tf_gpu_process_state && tf_gpu_process_state->HasGPUDevice()) {
    constexpr char gpu_device_type[] = "GPU";
    int num_gpu = ctx->local_device_mgr()->NumDeviceType(gpu_device_type);
    for (int gpu_ordinal = 0; gpu_ordinal < num_gpu; gpu_ordinal++) {
      auto gpu_device_name = absl::StrCat(gpu_device_type, ":", gpu_ordinal);
      Device* device;
      TF_RETURN_IF_ERROR(
          ctx->local_device_mgr()->LookupDevice(gpu_device_name, &device));
      auto gpu_device = static_cast<tensorflow::BaseGPUDevice*>(device);
      if (!gpu_device)
        return tensorflow::errors::NotFound("TF BaseGPUDevice not found");
#if TENSORFLOW_USE_ROCM
      static_assert(
          false,
          "static_cast to GpuContext and CUstream are invalid for ROCm.");
#endif
      CUcontext gpu_context =
          static_cast<stream_executor::gpu::GpuContext*>(
              gpu_device->executor()->implementation()->GpuContextHack())
              ->context();

      // TF GPU allocator is already created in
      // tensorflow::DeviceFactory::AddDevices above, so this GetGPUAllocator
      // ignores options and total_bytes passed in and retrieves allocator based
      // on `tf_device_id`.
      TfDeviceId tf_device_id{gpu_ordinal};
      GPUOptions dummy_options;
      tensorflow::Allocator* tf_allocator =
          tf_gpu_process_state->GetGPUAllocator(dummy_options, tf_device_id,
                                                /*total_bytes=*/0,
                                                /*peer_gpu_ids=*/{});
      if (!tf_allocator)
        return tensorflow::errors::NotFound("TF allocator not found");
      auto accelerator_device_info =
          gpu_device->tensorflow_accelerator_device_info();
      if (!accelerator_device_info)
        return tensorflow::errors::NotFound(
            "accelerator_device_info not found");

      tfrt::gpu::GpuResources gpu_resources;
      gpu_resources.gpu_context = tfrt::gpu::wrapper::Context(gpu_context);
      gpu_resources.allocator_factory =
          CreateRuntimeFallbackGpuAllocatorFactory(tf_allocator);
      gpu_resources.stream = tfrt::gpu::wrapper::Stream(static_cast<CUstream>(
          accelerator_device_info->stream->implementation()->GpuStreamHack()));
      auto platform = tfrt::gpu::wrapper::Platform::CUDA;
      tfrt::gpu::SetTfrtGpuResources(
          tfrt::gpu::wrapper::Device(gpu_ordinal, platform), gpu_resources);
    }
  }
  return OkStatus();
}

tensorflow::Status InjectTfGpuResources() {
  // TODO(zhangqiaorjc): Use more direct and low-level APIs to initialize GPU
  // resources than using EagerContext. Note that this EagerContext is strictly
  // locally scoped and an implementation detail of injecting GPU resources; it
  // is not the same EagerContext set in RequestContext.
  static bool already_injected_gpu_devices = false;
  static absl::Mutex* mutex = new absl::Mutex();

  absl::MutexLock lock(mutex);
  if (!already_injected_gpu_devices) {
    tfrt::Expected<OwnedEagerContext> ctx = InitEagerContext();
    if (!ctx) {
      return tensorflow::errors::Internal(
          tfrt::StrCat("error initializing eager context: ", ctx.takeError()));
    }

    // GPU resources should be injected once per GPU ordinal.
    TF_RETURN_IF_ERROR(InjectTfGpuResourcesHelper(ctx->get()));
    already_injected_gpu_devices = true;
  }

  return OkStatus();
}

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Kernel for initializing TF EagerContext.
//
// This kernel should be invoked at least once before any TF delegation kernels
// are invoked. Redundant calls to initialize the eager context are skipped.
//
// Example usage in MLIR:
//
// %c2 = "tfd.init_eager_context"(%c1): (!hex.chain) -> !hex.chain
//
static void TfdInitEagerContext(Argument<Chain> in_chain,
                                Result<Chain> out_chain,
                                KernelErrorHandler handler,
                                const ExecutionContext& exec_ctx) {
  tfrt::ResourceContext* resource_context = exec_ctx.resource_context();
  tensorflow::tfd::EagerContextResource* eager_context_resource =
      resource_context
          ->GetOrCreateResource<tensorflow::tfd::EagerContextResource>(
              tensorflow::tfd::kEagerContextResourceName);
  (void)eager_context_resource;

  // TODO(zhangqiaorjc): Inject GPU resources to GPU kernels.
  out_chain.Set(in_chain);
}

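// Moves a DenseHostTensor's buffer into a TF_Tensor without copying; the
// TF_Tensor's deallocator drops the reference on the underlying HostBuffer.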
OwnedTFTensor MoveDHTToTFTensor(DenseHostTensor&& dht, HostContext* host) {
  llvm::SmallVector<tfrt::Index, 4> dims;
  dht.shape().GetDimensions(&dims);

  HostBuffer* host_buffer = dht.ReleaseBuffer().release();
  auto deallocator = [](void* data, size_t len, void* arg) {
    auto* host_buffer = reinterpret_cast<HostBuffer*>(arg);
    host_buffer->DropRef();
  };

  CheckBoolCompatibility();
  OwnedTFTensor tf_tensor{
      TF_NewTensor(static_cast<TF_DataType>(GetTfDataType(dht.dtype())),
                   dims.data(), dims.size(), host_buffer->data(),
                   host_buffer->size(), deallocator, host_buffer)};
  return tf_tensor;
}

static tensorflow::Status DecodeDenseAttrToTensorInterface(
    const DenseAttr& dense_attr, HostContext* host,
    tensorflow::TensorInterface* result) {
  Expected<DenseHostTensor> dht =
      tfrt::DeserializeDenseHostTensorFromDenseAttr(dense_attr, host);
  if (!dht)
    return tensorflow::errors::Internal(tfrt::StrCat(
        "cannot create DenseHostTensor in DecodeDenseAttrToTensorInterface:",
        dht.takeError()));
  OwnedTFTensor tf_tensor = MoveDHTToTFTensor(std::move(*dht), host);
  tensorflow::Tensor t;
  TF_RETURN_IF_ERROR(TF_TensorToTensor(tf_tensor.get(), &t));
  *result = tensorflow::TensorInterface(std::move(t));
  return OkStatus();
}

// Handle attributes.
//
// Refer to tensorflow/core/framework/attr_value.proto and
// tensorflow/c/eager/c_api.h.
//
// Note we currently do not support the following attribute value types:
//   TFE_OpSetAttrFunction
//   TFE_OpSetAttrFunctionName
static tensorflow::Status PrepareAttributes(EagerOperation* eager_op,
                                            const OpAttrsRef& attrs,
                                            HostContext* host,
                                            EagerContext* eager_ctx) {
  tensorflow::Status status;
  attrs.IterateEntries([eager_op, eager_ctx, status_ptr = &status, host,
                        &attrs](const OpAttrsRawEntry& entry) {
    // TFE does not expect a device attribute.
    assert(strcmp(entry.name, "device") != 0);
    if (IsUnusedAttribute(entry.name)) {
      return;
    } else if (entry.IsArray()) {
      if (entry.element_count == 0) {
        if (entry.type == OpAttrType::CHAR) {
          // Empty string.
          std::string empty_str;
          *status_ptr = eager_op->SetAttrString(entry.name, empty_str.data(),
                                                empty_str.size());
        } else {
          // Empty array of other types.
          AttrValue empty_attr_value;
          eager_op->MutableAttrs()->Set(entry.name, empty_attr_value);
        }
      } else if (entry.type == OpAttrType::CHAR) {
        string_view attr_value = attrs.GetStringAsserting(entry.name);
        *status_ptr = eager_op->SetAttrString(entry.name, attr_value.data(),
                                              attr_value.size());
      } else if (entry.type == OpAttrType::FUNC) {
        string_view attr_value = attrs.GetFuncNameAsserting(entry.name);
        *status_ptr = eager_op->SetAttrFunctionName(
            entry.name, attr_value.data(), attr_value.size());
      } else if (entry.type == OpAttrType::I64) {
        llvm::ArrayRef<int64_t> int_array =
            attrs.GetArrayAsserting<int64_t>(entry.name);
        *status_ptr = eager_op->SetAttrIntList(entry.name, int_array.data(),
                                               int_array.size());
      } else if (entry.type == OpAttrType::F32) {
        llvm::ArrayRef<float> float_array =
            attrs.GetArrayAsserting<float>(entry.name);
        *status_ptr = eager_op->SetAttrFloatList(entry.name, float_array.data(),
                                                 float_array.size());
      } else if (entry.type == OpAttrType::BOOL) {
        llvm::ArrayRef<bool> bool_array =
            attrs.GetArrayAsserting<bool>(entry.name);
        // SetAttrBoolList expects const unsigned char*, not const bool*.
        std::vector<unsigned char> bool_char_array(bool_array.begin(),
                                                   bool_array.end());
        *status_ptr = eager_op->SetAttrBoolList(
            entry.name, bool_char_array.data(), bool_char_array.size());
      } else if (entry.type == OpAttrType::DTYPE) {
        const auto& op_attr = attrs.GetRawAsserting(entry.name);
        assert(op_attr.IsArray());

        // DTypes in BEF attributes are tfrt::DType enums, so we need to
        // convert them to tensorflow data types first.
        auto bef_dtypes = llvm::makeArrayRef(
            static_cast<const tfrt::DType*>(op_attr.GetData()),
            op_attr.element_count);

        llvm::SmallVector<tensorflow::DataType, 4> tf_dtypes;
        tf_dtypes.reserve(bef_dtypes.size());
        for (auto bef_dtype : bef_dtypes) {
          tf_dtypes.push_back(ConvertBefAttrTypeToTfDataType(bef_dtype));
        }

        *status_ptr = eager_op->SetAttrTypeList(entry.name, tf_dtypes.data(),
                                                tf_dtypes.size());
      } else {
        *status_ptr =
            tensorflow::errors::Internal("unsupported array attribute type");
      }
    } else {
      if (entry.type == OpAttrType::I64) {
        int64_t attr_value = attrs.GetAsserting<int64_t>(entry.name);
        *status_ptr = eager_op->SetAttrInt(entry.name, attr_value);
      } else if (entry.type == OpAttrType::F32) {
        float attr_value = attrs.GetAsserting<float>(entry.name);
        *status_ptr = eager_op->SetAttrFloat(entry.name, attr_value);
      } else if (entry.type == OpAttrType::BOOL) {
        bool attr_value = attrs.GetAsserting<bool>(entry.name);
        *status_ptr = eager_op->SetAttrBool(entry.name, attr_value);
      } else if (entry.type == OpAttrType::DTYPE) {
        OpAttrType op_attr_type = attrs.GetAsserting<OpAttrType>(entry.name);
        DataType tf_dtype = ConvertToTfDataType(op_attr_type);
        *status_ptr = eager_op->SetAttrType(entry.name, tf_dtype);
      } else if (entry.type == OpAttrType::SHAPE) {
        tfrt::ShapeAttr shape_attr =
            attrs.GetAsserting<tfrt::ShapeAttr>(entry.name);
        if (shape_attr.HasRank()) {
          *status_ptr = eager_op->SetAttrShape(
              entry.name, shape_attr.GetShape().data(), shape_attr.GetRank());
        } else {
          *status_ptr = eager_op->SetAttrShape(entry.name, /*dims=*/nullptr,
                                               /*num_dims=*/-1);
        }
      } else if (entry.type == OpAttrType::DENSE) {
        DenseAttr dense_attr = attrs.GetAsserting<DenseAttr>(entry.name);
        tensorflow::TensorInterface interface;
        *status_ptr =
            DecodeDenseAttrToTensorInterface(dense_attr, host, &interface);
        if (!status_ptr->ok()) return;
        *status_ptr = eager_op->SetAttrTensor(entry.name, &interface);
      } else if (entry.type == OpAttrType::AGGREGATE) {
        AggregateAttr list_attr = attrs.GetAsserting<AggregateAttr>(entry.name);
        int num_values = list_attr.GetNumElements();

        // Insert a dummy list attribute into the NodeDef if the aggregate attr
        // is empty. This is needed because the ValidateNodeDef method checks
        // the encoded_attr_ map for expected attributes, specified in the
        // OpDef.
        if (num_values == 0) {
          // The int type is just a placeholder and doesn't matter.
          std::vector<int> dummy_attr;
          eager_op->MutableAttrs()->Set(
              entry.name, gtl::ArraySlice<const int>(dummy_attr.data(), 0));
          return;
        }

        // It is guaranteed that items in one list attribute have the same
        // type, though their sizes can be different. In particular,
        // list(TensorShape) and list(Tensor) attribute types have to be
        // encoded as AggregateAttr.
        auto attr_base = list_attr.GetAttribute(0);
        if (IsDataTypeAttribute(attr_base.type()) &&
            GetDataType(attr_base.type()) == tfrt::DType::String) {
          // Handle list(string).
          llvm::SmallVector<const void*, 8> values;
          llvm::SmallVector<size_t, 8> lengths;
          values.reserve(num_values);
          lengths.reserve(num_values);
          for (int i = 0; i < num_values; ++i) {
            auto string_attr = list_attr.GetAttributeOfType<StringAttr>(i);
            values.push_back(string_attr.GetValue().data());
            lengths.push_back(string_attr.GetValue().size());
          }
          *status_ptr = eager_op->SetAttrStringList(entry.name, values.data(),
                                                    lengths.data(), num_values);
        } else if (IsFuncAttribute(attr_base.type())) {
          std::vector<const AbstractOperation*> funcs(num_values);
          for (int i = 0; i < num_values; ++i) {
            auto func_attr = list_attr.GetAttributeOfType<FuncAttr>(i);
            // TODO(chuanhao): Creating an EagerOperation here is expensive.
            // Consider using AttrBuilder to set the attribute directly.
            ImmediateExecutionOperation* new_op = eager_ctx->CreateOperation();
            auto func_name = func_attr.GetFunctionName();
            *status_ptr = new_op->Reset(func_name.str().c_str(),
                                        /*raw_device_name=*/nullptr);
            funcs[i] = new_op;
          }
          *status_ptr =
              eager_op->SetAttrFunctionList(entry.name, absl::MakeSpan(funcs));
        } else if (attr_base.type() == BEFAttributeType::kShape) {
          // Handle list(TensorShape).
          llvm::SmallVector<int, 8> ranks;
          llvm::SmallVector<const int64_t*, 8> dims;
          ranks.reserve(num_values);
          dims.reserve(num_values);
          for (int i = 0; i < num_values; ++i) {
            auto shape_attr = list_attr.GetAttributeOfType<ShapeAttr>(i);
            if (shape_attr.HasRank()) {
              ranks.push_back(shape_attr.GetRank());
              dims.push_back(shape_attr.GetShape().data());
            } else {
              ranks.push_back(-1);
              dims.push_back(nullptr);
            }
          }
          *status_ptr = eager_op->SetAttrShapeList(entry.name, dims.data(),
                                                   ranks.data(), num_values);
        } else {
          *status_ptr =
              tensorflow::errors::Internal("unsupported list attribute type");
        }
      } else {
        *status_ptr =
            tensorflow::errors::Internal("unsupported scalar attribute type");
      }
    }
  });
  return status;
}

Status CallEagerExecute(const tfrt::ExecutionContext& exec_ctx,
                        EagerContext* eager_ctx, const char* op_name,
                        const char* device_name,
                        llvm::ArrayRef<TensorHandle*> input_tensor_handles,
                        const OpAttrsRef& attrs,
                        llvm::MutableArrayRef<tensorflow::AbstractTensorHandle*>
                            result_tensor_handles) {
  assert(eager_ctx != nullptr && "EagerContext is NULL");

  // Create TF EagerOperation.
  OwnedEagerOperation eager_op{new EagerOperation(eager_ctx)};
  TF_RETURN_IF_ERROR(eager_op->Reset(op_name, device_name));

  // Handle inputs.
  for (TensorHandle* input_tensor : input_tensor_handles) {
    TF_RETURN_IF_ERROR(eager_op->AddInput(input_tensor));
  }

  // Handle attributes.
  auto* host = exec_ctx.host();
  TF_RETURN_IF_ERROR(PrepareAttributes(eager_op.get(), attrs, host, eager_ctx));

  int num_retvals = result_tensor_handles.size();
  TF_RETURN_IF_ERROR(eager_op->Execute(
      absl::MakeSpan(result_tensor_handles.data(), num_retvals), &num_retvals));

  return OkStatus();
}

static bool ShouldAddHostContextAttr(const char* op_name) {
  // NOTE(rachelim): In the future, if more ops require this, instead of
  // checking against a whitelist of op names, we could check whether the op
  // contains an attribute called `host_ptr`.
  return strcmp(op_name, "TFRTMakeIterator") == 0;
}

// TODO(zhangqiaorjc): Unify implementation with RuntimeFallbackKernel.
AsyncValueRef<Chain> RuntimeFallbackExecute(
    const tfrt::ExecutionContext& exec_ctx, EagerContext* eager_ctx,
    const char* op_name, const char* device_name,
    llvm::ArrayRef<Tensor*> arguments, const OpAttrsRef& attrs,
    llvm::MutableArrayRef<RCReference<AsyncValue>> results) {
  auto emit_error = [&exec_ctx, results](const tensorflow::Status& status) {
    // Set the correct TFRT error code according to the error propagated from
    // runtime fallback execution.
    auto error =
        EmitErrorAsync(exec_ctx, status.error_message(),
                       tfrt::ConvertTfErrorCodeToTfrtErrorCode(status));
    // Set all results to error.
    std::fill(results.begin(), results.end(), error);
    return error;
  };

  llvm::SmallVector<TensorHandle*, 4> input_tensor_handles;
  input_tensor_handles.reserve(arguments.size());
  for (Tensor* input_tensor : arguments) {
    input_tensor_handles.push_back(
        llvm::cast<RuntimeFallbackTensor>(input_tensor)->GetTensorHandle());
  }

  int num_retvals = results.size();
  llvm::SmallVector<tensorflow::AbstractTensorHandle*, 4> result_tensor_handles(
      num_retvals);
  Status status;
  if (!ShouldAddHostContextAttr(op_name)) {
    status =
        CallEagerExecute(exec_ctx, eager_ctx, op_name, device_name,
                         input_tensor_handles, attrs, result_tensor_handles);
  } else {
    // Wrap the HostContext pointer in an attribute. This is necessary for
    // TF ops that require the TFRT HostContext to function. These kernels
    // should not create their own HostContexts.
    // TODO(rachelim): Support copying over non-host_ptr attrs, if there are
    // any.
    assert(attrs.GetNumEntries() == 1);
    OpAttrs updated;

    updated.Set(kHostContextPtrAttrName,
                reinterpret_cast<int64_t>(exec_ctx.host()));
    status = CallEagerExecute(
        exec_ctx, eager_ctx, op_name, device_name, input_tensor_handles,
        OpAttrsRef(std::move(updated)), result_tensor_handles);
  }
  if (!status.ok()) return emit_error(status);

  auto host = exec_ctx.host();
  for (int i = 0; i < num_retvals; ++i) {
    auto expected_fallback_tensor =
        CreateRuntimeFallbackTensorFromTfTensorHandle(
            OwnedTensorHandle{
                TensorHandleFromInterface(result_tensor_handles[i])},
            host);
    if (!expected_fallback_tensor)
      results[i] = EmitErrorAsync(
          exec_ctx, tfrt::StrCat(expected_fallback_tensor.takeError()));
    else
      results[i] = tfrt::MakeAvailableAsyncValueRef<RuntimeFallbackTensor>(
          host, std::move(*expected_fallback_tensor));
  }

  return tfrt::GetReadyChain();
}

AsyncValueRef<Chain> RuntimeFallbackExecute(
    const tfrt::ExecutionContext& exec_ctx, const char* op_name,
    const char* device_name, llvm::ArrayRef<Tensor*> arguments,
    const OpAttrsRef& attrs,
    llvm::MutableArrayRef<RCReference<AsyncValue>> results) {
  // Get EagerContext.
  auto eager_ctx_expected = GetEagerContext(exec_ctx);
  if (!eager_ctx_expected) {
    auto error = EmitErrorAsync(exec_ctx, eager_ctx_expected.takeError(),
                                tfrt::ErrorCode::kUnknown);
    // Set all results to error.
    std::fill(results.begin(), results.end(), error);
    return std::move(error);
  }
  EagerContext* eager_ctx = eager_ctx_expected.get();

  return RuntimeFallbackExecute(exec_ctx, eager_ctx, op_name, device_name,
                                arguments, attrs, results);
}

// Kernel to delegate to the current TF runtime kernel.
//
// Example usage in MLIR:
//
// %c2, %tft_c = "tfd.delegate_kernel"(%c1, %tft_a, %tft_b) {op_name = "MatMul"}
// : (!hex.chain, !tfd.tf_tensor, !tfd.tf_tensor) -> (!hex.chain,
// !tfd.tf_tensor)
// TODO(jingdong): Enqueue the TFE kernel execution as blocking task to the
// ConcurrentWorkQueue.
static void RuntimeFallbackKernel(
    Argument<Chain> in_chain, RemainingArguments input_tensors,
    Result<Chain> out_chain, RemainingResults output_tensors,
    StringAttribute op_name, RemainingAttributes remaining_attributes,
    KernelErrorHandler handler, const ExecutionContext& exec_ctx) {
  HostContext* host = exec_ctx.host();
  tfrt::ResourceContext* resource_context = exec_ctx.resource_context();
  EagerContextResource* eager_context_resource =
      resource_context->GetOrCreateResource<EagerContextResource>(
          tensorflow::tfd::kEagerContextResourceName);
  tfrt::Expected<EagerContext*> eager_ctx_expected =
      eager_context_resource->GetTFEagerContext();
  if (!eager_ctx_expected) {
    handler.ReportError(tfrt::StrCat(eager_ctx_expected.takeError()));
    return;
  }
  EagerContext* eager_ctx = eager_ctx_expected.get();

  // Construct TF EagerOperation.
  // Need to copy op_name to a std::string to ensure the string is
  // null-terminated.
  std::string op_name_str = [&] {
    auto view = op_name.get();
    view.consume_front("tf.");
    return view.str();
  }();

  OwnedEagerOperation eager_op{new EagerOperation(eager_ctx)};
  TFD_REPORT_AND_RETURN_IF_ERROR(
      handler,
      eager_op->Reset(op_name_str.c_str(), /*raw_device_name=*/nullptr));

  // Handle inputs.
  for (AsyncValue* input_tensor_av : input_tensors.values()) {
    auto input_tensor_handle =
        input_tensor_av->get<RuntimeFallbackTensor>().GetTensorHandle();
    TFD_REPORT_AND_RETURN_IF_ERROR(handler,
                                   eager_op->AddInput(input_tensor_handle));
  }

  // Handle TF op attributes.
  // TODO(zhangqiaorjc): Encode TF attributes using native MLIR attribute types.
  assert(remaining_attributes.size() % 2 == 0);
  int num_tf_attrs = remaining_attributes.size() / 2;
  for (int i = 0; i < num_tf_attrs; ++i) {
    // Each TF attribute is represented as a pair of name and value strings.
    // Make a copy for `attr_name` to ensure null-termination.
    std::string attr_name =
        remaining_attributes.GetStringAttribute(i * 2).str();
    absl::string_view attr_value = ToAbslStringView(
        remaining_attributes.GetStringAttribute(i * 2 + 1).get());
    std::vector<absl::string_view> value_split =
        tfd::AttrValueSplit(attr_value);

    // Handle different TF attribute types.
    if (value_split[0] == "string") {
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler,
          eager_op->SetAttrString(attr_name.c_str(), value_split[1].data(),
                                  value_split[1].size()));
    } else if (value_split[0] == "bool") {
      bool bool_val;
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler, ParseBoolAttrValue(value_split[1], &bool_val));
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler, eager_op->SetAttrBool(attr_name.c_str(), bool_val));
    } else if (value_split[0] == "int") {
      int64_t int_val;
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler, ParseIntAttrValue(value_split[1], &int_val));
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler, eager_op->SetAttrInt(attr_name.c_str(), int_val));
    } else if (value_split[0] == "tftensor") {
      tensorflow::Tensor t;
      TFD_REPORT_AND_RETURN_IF_ERROR(handler,
                                     ParseTensorAttrValue(value_split[1], &t));
      tensorflow::TensorInterface interface(t);
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler, eager_op->SetAttrTensor(attr_name.c_str(), &interface));
    } else if (value_split[0] == "tfdtype") {
      DataType dtype;
      TFD_REPORT_AND_RETURN_IF_ERROR(handler,
                                     ParseTfDataType(value_split[1], &dtype));
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler, eager_op->SetAttrType(attr_name.c_str(), dtype));
    } else if (value_split[0] == "tfshape") {
      std::vector<int64_t> dims;
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler, ParseTensorShapeAttrValue(value_split[1], &dims));
      TFD_REPORT_AND_RETURN_IF_ERROR(
          handler,
          eager_op->SetAttrShape(attr_name.c_str(), dims.data(), dims.size()));
    } else {
      handler.ReportError("attribute type not yet supported");
      return;
    }
  }

  // Invoke the TF EagerOperation.
  int num_retvals = output_tensors.size();
  llvm::SmallVector<tensorflow::AbstractTensorHandle*, 4> retvals(num_retvals);

  tensorflow::Status status = eager_op->Execute(
      absl::MakeSpan(retvals.data(), num_retvals), &num_retvals);
  TFD_REPORT_AND_RETURN_IF_ERROR(handler, status);

  // Handle outputs.
  if (num_retvals != output_tensors.size()) {
    handler.ReportError("Incorrect number of output values");
    return;
  }
  for (int i = 0; i < num_retvals; ++i) {
    OwnedTensorHandle owned_th{TensorHandleFromInterface(retvals[i])};
    if (!owned_th) handler.ReportError("TensorHandleFromInterface failed");
    auto fallback_tensor = CreateRuntimeFallbackTensorFromTfTensorHandle(
        std::move(owned_th), host);
    if (!fallback_tensor) {
      output_tensors[i] = tfrt::MakeErrorAsyncValueRef(
          host, tfrt::StrCat(fallback_tensor.takeError()));
    } else {
      output_tensors[i] =
          tfrt::MakeAvailableAsyncValueRef<RuntimeFallbackTensor>(
              host, std::move(*fallback_tensor));
    }
  }
  out_chain.Set(in_chain);
}

static void EmitErrorAndSetInResults(
    const tfrt::ExecutionContext& exec_ctx,
    const tfrt::DecodedDiagnostic& error,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results) {
  auto error_av = tfrt::EmitErrorAsync(exec_ctx, error.message, error.code);
  std::fill(results.begin(), results.end(), error_av);
}

// Convert the tfrt::TensorHandle to tensorflow::Tensor. `device` is the target
// device for the converted tensorflow::Tensor.
//
// TODO(tfrt-devs): Currently the target device can only be CPU. We need to add
// support for more devices.
void CoreRTTensorHandleToFallbackTensorInternal(
    llvm::ArrayRef<tfrt::AsyncValue*> tensorhandle_args,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>>
        tf_tensor_results,
    tfrt::string_view device, const tfrt::ExecutionContext& exec_ctx) {
  assert(tensorhandle_args.size() == tf_tensor_results.size());

  auto set_result = [&](tfrt::RCReference<tfrt::AsyncValue>& result,
                        llvm::Expected<tensorflow::Tensor> tf_tensor) {
    auto result_ref = tfrt::MakeUnconstructedAsyncValueRef<
        tensorflow::tfrt_stub::FallbackTensor>();
    if (!tf_tensor) {
      result_ref.SetError(tfrt::StrCat(tf_tensor.takeError()));
    } else {
      result_ref.emplace(std::move(tf_tensor.get()));
    }
    result = std::move(result_ref);
  };

  auto maybe_convert_runtime_fallback_tensor =
      [&exec_ctx](
          tfrt::AsyncValueRef<Tensor> tensor_avref,
          const tfrt::Device& src_device,
          const tfrt::Device& dst_device) -> tfrt::AsyncValueRef<tfrt::Tensor> {
    // TODO(tfrt-devs): Remove the implicit conversion in this kernel since it
    // adds extra overhead.
    // Convert RuntimeFallbackTensor to KernelFallbackTensor before calling
    // into kernel fallback. This is needed because kernel fallback currently
    // cannot read a tensor from a runtime fallback tensor: we don't want
    // kernel fallback to depend on runtime fallback.
    assert(tensor_avref.IsAvailable());
    assert(!tensor_avref.IsError());
    auto& tensor = tensor_avref.get();
    if (!tensor.IsTensorType(DenseHostTensor::kTensorType) ||
        !src_device.IsDeviceType(tfrt::CpuDevice::kDeviceType) ||
        !dst_device.IsDeviceType(tfrt::CpuDevice::kDeviceType)) {
      return tfrt::ConvertTensor(exec_ctx, tensor,
                                 /*src=*/src_device,
                                 /*dst=*/dst_device,
                                 KernelFallbackTensor::kTensorType);
    }
    return tensor_avref;
  };

  auto dst_device =
      exec_ctx.host()->GetDeviceManager()->GetDeviceRef<tfrt::Device>(device);

  // Retrieve the underlying pointer of tfrt::Tensor. We don't need to do
  // extra ownership management here because KernelFallbackExecuteCompat()
  // will always convert it to tensorflow::Tensor which is itself refcounted.
  for (int i = 0; i < tensorhandle_args.size(); ++i) {
    if (!dst_device) {
      tf_tensor_results[i] = tfrt::MakeErrorAsyncValueRef(
          tfrt::StrCat("Failed to find device with name ", device));
      continue;
    }
    auto& tensor_handle = tensorhandle_args[i]->get<tfrt::TensorHandle>();
    assert(tensor_handle.IsDeviceAvailable());
    assert(!tensor_handle.IsDeviceError());

    auto* tensor_av = tensor_handle.GetAsyncTensor();
    auto tensor_avref = tfrt::AsyncValueRef<Tensor>(FormRef(tensor_av));

    auto& src_device = *tensor_handle.GetAvailableDevice();
    AsyncValueRef<Tensor> knfb_tensor;
    if (!tensor_av->IsAvailable()) {
      auto ind_av = tfrt::MakeIndirectAsyncValue(exec_ctx.host());
      knfb_tensor = AsyncValueRef<Tensor>(ind_av.CopyRef());
      tensor_av->AndThen(
          [tensor_avref = std::move(tensor_avref), ind_av = std::move(ind_av),
           &src_device, dst_device = dst_device.CopyRef(),
           maybe_convert_runtime_fallback_tensor, exec_ctx]() mutable {
            ind_av->ForwardTo(maybe_convert_runtime_fallback_tensor(
                std::move(tensor_avref), src_device, *dst_device));
          });
    } else {
      knfb_tensor = maybe_convert_runtime_fallback_tensor(
          std::move(tensor_avref), src_device, *dst_device);
    }

    if (!knfb_tensor.IsAvailable()) {
      auto result_ref = tfrt::MakeIndirectAsyncValue(exec_ctx.host());
      tf_tensor_results[i] = result_ref;
      auto knfb_tensor_av = knfb_tensor.GetAsyncValue();
      knfb_tensor_av->AndThen([knfb_tensor = std::move(knfb_tensor),
                               result_ref = std::move(result_ref),
                               dst_device = dst_device.CopyRef(),
                               exec_ctx]() mutable {
        if (knfb_tensor.IsError()) {
          result_ref->ForwardTo(std::move(knfb_tensor));
          return;
        }
        auto expected_tf_tensor =
            TFRTTensorToTFTensor(knfb_tensor.get(), exec_ctx.host());
        if (!expected_tf_tensor) {
          auto error =
              tfrt::EmitErrorAsync(exec_ctx, expected_tf_tensor.takeError());
          result_ref->ForwardTo(std::move(error));
        } else {
          auto tf_tensor_ref = tfrt::MakeAvailableAsyncValueRef<
              tensorflow::tfrt_stub::FallbackTensor>(
              std::move(expected_tf_tensor.get()));
          result_ref->ForwardTo(std::move(tf_tensor_ref));
        }
      });
    } else {
      set_result(tf_tensor_results[i],
                 TFRTTensorToTFTensor(knfb_tensor.get(), exec_ctx.host()));
    }
  }
}

// Returns true if the tensorflow::DataType is trivially copyable.
static bool IsTriviallyCopyableTensorflowDataType(tensorflow::DataType dtype) {
  static const auto* const non_trivially_copyable_dtypes =
      new absl::flat_hash_set<tensorflow::DataType>{
          tensorflow::DataType::DT_STRING, tensorflow::DataType::DT_RESOURCE,
          tensorflow::DataType::DT_VARIANT};
  return !non_trivially_copyable_dtypes->contains(dtype);
}

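// Kernel for creating a constant dense tensor from a BEF dense attribute.
// Registered below as tfrt_fallback_async.const_dense_tensor.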
static llvm::Expected<tensorflow::tfrt_stub::FallbackTensor> ConstDenseTensor(
    tfrt::DenseAttr value, const tfrt::ExecutionContext& context) {
  auto dtype = GetTfDataType(tfrt::DType(value.dtype()));
  // The data type must be trivially copyable so that we can use memcpy.
  DCHECK(IsTriviallyCopyableTensorflowDataType(dtype));
  tensorflow::Tensor tensor(dtype, tensorflow::TensorShape(value.shape()));
  std::memcpy(tensor.data(), value.GetElements(), tensor.TotalBytes());
  return tensorflow::tfrt_stub::FallbackTensor(tensor);
}

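// Kernel for creating a constant string tensor of the given shape from an
// aggregate attribute of strings. Registered below as
// tfrt_fallback_async.const_string_tensor.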
static llvm::Expected<tensorflow::tfrt_stub::FallbackTensor> ConstStringTensor(
    tfrt::ArrayAttr shape, tfrt::AggregateAttr value,
    const ExecutionContext& context) {
  llvm::SmallVector<int64_t> dims;
  auto tfrt_tensor_shape = tfrt::TensorShape(shape.GetValue<int64_t>());
  tfrt_tensor_shape.GetDimensions(&dims);
  tensorflow::Tensor tensor(tensorflow::DT_STRING,
                            tensorflow::TensorShape(dims));
  auto len = tensor.NumElements();
  auto from = value;
  auto to = tensor.flat<tensorflow::tstring>();
  if (from.GetNumElements() == 1) {
    // All elements are the same, and only one element is saved in BEF.
    for (size_t i = 0; i < len; ++i) {
      to(i) =
          ToAbslStringView(from.GetAttributeOfType<StringAttr>(0).GetValue());
    }
  } else {
    assert(len == from.GetNumElements());
    for (size_t i = 0; i < len; ++i) {
      to(i) =
          ToAbslStringView(from.GetAttributeOfType<StringAttr>(i).GetValue());
    }
  }
  return tensorflow::tfrt_stub::FallbackTensor(tensor);
}

// The BEF kernel for tfrt::TensorHandle to tensorflow::Tensor conversion.
void CoreRTTensorHandleToFallbackTensor(
    RemainingArguments args, RemainingResults results, StringAttr device,
    const tfrt::ExecutionContext& exec_ctx) {
  tensorflow::profiler::TraceMe trace_me(
      "corert_tensorhandle_to_fallback_tensor");
  trace_me.AppendMetadata([request_id = exec_ctx.request_ctx()->id()]() {
    return tensorflow::profiler::TraceMeEncode({{"id", request_id}});
  });

  CoreRTTensorHandleToFallbackTensorInternal(args.values(), results.values(),
                                             device.GetValue(), exec_ctx);
}

// Convert the tensorflow::Tensor to tfrt::TensorHandle. `device` is the device
// for the input tensorflow::Tensor.
//
// TODO(tfrt-devs): Currently the input device can only be CPU. We need to add
// support for more devices.
static void FallbackTensorToCoreRTTensorHandleInternal(
    llvm::ArrayRef<tfrt::AsyncValue*> tf_tensor_args,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>>
        tensorhandle_results,
    absl::string_view device, const tfrt::ExecutionContext& exec_ctx) {
  auto* host = exec_ctx.host();

  assert(tf_tensor_args.size() == tensorhandle_results.size());
  for (int i = 0; i < tf_tensor_args.size(); ++i) {
    auto* av = tf_tensor_args[i];
    auto& tf_tensor = av->get<tensorflow::tfrt_stub::FallbackTensor>().tensor();
    AsyncValueRef<tfrt::Tensor> kernel_fallback_tensor =
        tfrt::MakeAvailableAsyncValueRef<KernelFallbackTensor>(tf_tensor);
    auto metadata = kernel_fallback_tensor.get().metadata();

    tensorhandle_results[i] =
        tfrt::MakeAvailableAsyncValueRef<tfrt::TensorHandle>(
            host,
            host->GetDeviceManager()->GetDeviceRef<tfrt::Device>(
                {device.data(), device.size()}),
            metadata, std::move(kernel_fallback_tensor));
  }
}

// The BEF kernel for tensorflow::Tensor to tfrt::TensorHandle conversion.
void FallbackTensorToCoreRTTensorHandle(
    RemainingArguments args, RemainingResults results, StringAttr device,
    const tfrt::ExecutionContext& exec_ctx) {
  tensorflow::profiler::TraceMe trace_me(
      "fallback_tensor_to_corert_tensorhandle");
  trace_me.AppendMetadata([request_id = exec_ctx.request_ctx()->id()]() {
    return tensorflow::profiler::TraceMeEncode({{"id", request_id}});
  });

  FallbackTensorToCoreRTTensorHandleInternal(
      args.values(), results.values(), ToAbslStringView(device.GetValue()),
      exec_ctx);
}

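// Kernel that converts a fallback tensor to a boolean predicate: a scalar is
// true iff its value is nonzero (non-empty for strings); a non-scalar tensor
// is true iff it has at least one element. Registered below as
// tfrt_fallback_async.predicate.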
static llvm::Expected<bool> Predicate(
    const tensorflow::tfrt_stub::FallbackTensor& input,
    const tfrt::ExecutionContext& exec_ctx) {
  const auto& tensor = input.tensor();
  if (TensorShapeUtils::IsScalar(tensor.shape())) {
    switch (tensor.dtype()) {
#define CASE(T)                  \
  case DataTypeToEnum<T>::value: \
    return tensor.scalar<T>()() != 0;

      CASE(float);
      CASE(double);
      CASE(uint8);
      CASE(int8);
      CASE(int16);
      CASE(int32);
      CASE(int64_t);
      CASE(bool);
#undef CASE
      case DT_STRING:
        return !tensor.scalar<tstring>()().empty();
      default:
        return tfrt::MakeStringError(DataTypeString(tensor.dtype()),
                                     " cannot be converted to a boolean");
    }
  }

  return tensor.NumElements() > 0;
}

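// Kernel that prints the given fallback tensor to stdout. Registered below as
// tfrt_fallback_async.print_tensor.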
tfrt::Chain PrintFallbackTensor(
    const tensorflow::tfrt_stub::FallbackTensor& arg, const tfrt::Chain& ch) {
  std::string message;
  llvm::raw_string_ostream(message) << arg.tensor().DebugString() << "\n";
  printf("%s", message.c_str());
  fflush(stdout);
  return tfrt::Chain();
}

// The legacy kernel implementation that dispatches runtime fallback operations.
// Since the arguments and results are tensorflow::Tensors, internally it
// does conversions between RuntimeFallbackTensor and tensorflow::Tensor.
static void RuntimeFallbackExecuteOp(
    RemainingArguments args, RemainingResults results, StringAttr device_attr,
    AggregateAttr op_attr_array, AggregateAttr op_func_attr_array,
    StringAttr op_name_attr, tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    const ExecutionContext& exec_ctx) {
  auto set_error = [&exec_ctx, results](tfrt::string_view msg) {
    auto error_av = EmitErrorAsync(exec_ctx, msg, tfrt::ErrorCode::kUnknown);
    // Set all results to error.
    for (int i = 0, n = results.size(); i < n; ++i) results[i] = error_av;
  };

  auto op_name = op_name_attr.GetValue();
  op_name.consume_front("tf.");

  // The device name might not be in the format expected by tensorflow. In that
  // case we change it to the correct format. Currently we only support CPU.
  //
  // TODO(tfrt-devs): Make sure device names passed to fallback kernels are in
  // the tensorflow format.
  std::string device_name = device_attr.GetValue().str();
  if (!absl::StartsWith(device_name, "/")) device_name = kDefaultCpuDevice;

  auto* host = exec_ctx.host();

  // Set up OpAttrs.
  tfrt::OpAttrs op_attrs;
  tfrt::SetUpOpAttrs(op_attr_array, &op_attrs);

  // Set up OpAttrs specifically for function attributes.
  tfrt::SetUpOpFuncAttrs(op_func_attr_array, &op_attrs);

  // Get EagerContext.
  auto eager_ctx_expected = GetEagerContext(exec_ctx);
  if (!eager_ctx_expected) {
    set_error(tfrt::StrCat(eager_ctx_expected.takeError()));
    return;
  }
  EagerContext* eager_ctx = eager_ctx_expected.get();

  // Get device.
  Device* device = nullptr;
  Status s = eager_ctx->local_device_mgr()->LookupDevice(device_name, &device);
  if (!s.ok()) {
    // The device name can be invalid in certain cases. Use default CPU device.
    VLOG(1) << s.error_message() << " using default CPU device.";
  }

  // First we convert tensorflow::Tensor to RuntimeFallbackTensors.
  llvm::SmallVector<RuntimeFallbackTensor, 4> tfrt_tensor_args;
  tfrt_tensor_args.reserve(args.size());
  for (int i = 0; i < args.size(); ++i) {
    auto* av = args[i];
    auto tf_tensor = av->get<tensorflow::Tensor>();

    tfrt::TensorMetadata md = tfd::GetTensorMetadata(tf_tensor);
    OwnedTensorHandle tensor_handle{tensorflow::TensorHandle::CreateLocalHandle(
        std::move(tf_tensor), /*d=*/device, /*op_device=*/device, eager_ctx)};

    tfrt_tensor_args.push_back(
        RuntimeFallbackTensor(md.shape, md.dtype, std::move(tensor_handle)));
  }

  llvm::SmallVector<tfrt::Tensor*, 4> tfrt_tensor_arg_ptrs;
  tfrt_tensor_arg_ptrs.reserve(args.size());
  for (auto& tensor : tfrt_tensor_args) tfrt_tensor_arg_ptrs.push_back(&tensor);

  llvm::SmallVector<RCReference<tfrt::AsyncValue>, 4> tfrt_tensor_results;
  tfrt_tensor_results.resize(results.size());

  auto chain = RuntimeFallbackExecute(
      exec_ctx, op_name.str().c_str(), device_name.c_str(),
      tfrt_tensor_arg_ptrs, tfrt::OpAttrsRef(op_attrs), tfrt_tensor_results);

  if (op_chain) *op_chain = chain.CopyRef();

  // After coreruntime returns, we check if there is any error. Currently we
  // assume runtime fallback execution is always synchronous.
  DCHECK(chain.IsAvailable());
  if (chain.IsError()) {
    EmitErrorAndSetInResults(exec_ctx, chain.GetError(), results.values());
    return;
  }

  // Finally we convert the runtime fallback results, which are
  // RuntimeFallbackTensor, back to tensorflow::Tensor that is expected by the
  // BEF kernel.
  for (int i = 0; i < results.size(); ++i) {
    auto& runtime_fallback_tensor =
        tfrt_tensor_results[i]->get<RuntimeFallbackTensor>();
    const tensorflow::Tensor* tf_tensor = nullptr;
    tensorflow::Status s =
        runtime_fallback_tensor.GetTensorHandle()->Tensor(&tf_tensor);
    DCHECK(s.ok()) << s.ToString();
    results[i] =
        tfrt::MakeAvailableAsyncValueRef<tensorflow::Tensor>(host, *tf_tensor);
  }
}

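// Kernel that registers implicit conversions from RuntimeFallbackTensor to the
// host tensor types (dense, scalar, string) on the given CPU op handler.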
Chain AddRuntimeFallbackImplicitConversionKernel(
    Argument<tfrt::OpHandler*> op_handler, const ExecutionContext& exec_ctx) {
  assert(op_handler.get()->GetName() == tfrt::CpuOpHandler::kName);
  tfrt::CpuOpHandler* cpu_op_handler =
      reinterpret_cast<tfrt::CpuOpHandler*>(op_handler.get());
  cpu_op_handler->AddImplicitConversion(RuntimeFallbackTensor::kTensorType,
                                        DenseHostTensor::kTensorType);
  cpu_op_handler->AddImplicitConversion(RuntimeFallbackTensor::kTensorType,
                                        tfrt::AnyScalarHostTensor::kTensorType);
  cpu_op_handler->AddImplicitConversion(RuntimeFallbackTensor::kTensorType,
                                        tfrt::StringHostTensor::kTensorType);
  return {};
}

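// Kernel that creates a RuntimeFallbackOpHandler for the given TF device name
// and returns it as the result.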
void CreateRuntimeFallbackOpHandlerKernel(Result<tfrt::OpHandler*> op_handler,
                                          StringAttribute tf_device_name,
                                          const ExecutionContext& exec_ctx) {
  auto* runtime = tfrt::CoreRuntime::GetFromHostContext(exec_ctx.host());
  assert(runtime);
  auto op_handler_ptr =
      CreateRuntimeFallbackOpHandler(runtime, tf_device_name.get());
  assert(op_handler_ptr);
  op_handler.Emplace(op_handler_ptr.get());
}

static OwnedTensorHandle ConvertTFRTTensorToTFTensorHandle(
    tfrt::Tensor* tensor) {
  if (auto* dht = llvm::dyn_cast<tfrt::DenseHostTensor>(tensor)) {
    tensorflow::Tensor tensor =
        MoveHostBufferToTfTensor(dht->buffer(), dht->dtype(), dht->shape());

    return OwnedTensorHandle{
        tensorflow::TensorHandle::CreateLocalHandle(tensor)};
  }

  if (auto* sht = llvm::dyn_cast<tfrt::StringHostTensor>(tensor)) {
    tensorflow::Tensor tensor = CopyShtToTfTensor(*sht);
    return OwnedTensorHandle{
        tensorflow::TensorHandle::CreateLocalHandle(tensor)};
  }

  llvm_unreachable("unsupported tensor type");
}

static llvm::Expected<tfrt::Value> ConvertTFTensorHandleToTFRTTensor(
    OwnedTensorHandle tensor_handle, HostContext* host) {
  tensorflow::Status status;
  // Resolve ensures Tensor is on host CPU.
  OwnedAbstractTensorInterface tensor_interface{
      tensor_handle->Resolve(&status)};
  if (!status.ok()) {
    return tfrt::MakeStringError("error resolving TensorHandle: ",
                                 status.error_message());
  }
  auto tf_dtype = tensor_interface->Type();
  if (tf_dtype == DT_STRING) {
    // TODO(tfrt-devs): Consider a more efficient way to pass string
    // tensors between TFRT and TF.
    auto string_host_tensor =
        CopyTfStringTensorToStringHostTensor(tensor_interface.get(), host);
    if (!string_host_tensor)
      return tfrt::MakeStringError(
          "error converting TF string tensor to tfrt::StringHostTensor: ",
          string_host_tensor.takeError());
    return tfrt::Value(std::move(*string_host_tensor));
  }

  tfrt::TensorMetadata metadata(GetTfrtDtype(tf_dtype),
                                GetShape(tensor_interface.get()));

  CheckBoolCompatibility();
  void* data = tensor_interface->Data();
  size_t size = tensor_interface->ByteSize();
  // `tensor_interface` holds a reference on the underlying TensorFlow buffer.
  // It is kept alive by the HostBuffer deallocator's lambda capture (an
  // llvm::unique_function) and is released when the HostBuffer deallocator is
  // invoked and destroyed.
  auto host_buffer = HostBuffer::CreateFromExternal(
      data, size,
      [tensor_interface = std::move(tensor_interface)](void*, size_t) {});

  tfrt::Value value;
  value.emplace<DenseHostTensor>(metadata, std::move(host_buffer));
  return std::move(value);
}

void RegisterTfdDelegateKernels(tfrt::KernelRegistry* registry) {
  registry->AddKernel("tfd.init_eager_context",
                      TFRT_KERNEL(TfdInitEagerContext));
  registry->AddKernel("tfd.delegate_kernel",
                      TFRT_KERNEL(RuntimeFallbackKernel));
  registry->AddKernel("tfd.move_dht_to_tft", TFRT_KERNEL(TfdMoveDHTToTFT));
  registry->AddKernel("tfd.convert_tft_to_dht",
                      TFRT_KERNEL(TfdConvertTFTToDHT));
  registry->AddKernel("tfd.print_tft", TFRT_KERNEL(TfdPrintTFT));
  registry->AddKernel("tfrt_fallback_async.const_dense_tensor",
                      TFRT_KERNEL(ConstDenseTensor));
  registry->AddKernel("tfrt_fallback_async.const_string_tensor",
                      TFRT_KERNEL(ConstStringTensor));
  registry->AddKernel(
      "tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor",
      TFRT_KERNEL(CoreRTTensorHandleToFallbackTensor));
  registry->AddKernel(
      "tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle",
      TFRT_KERNEL(FallbackTensorToCoreRTTensorHandle));

  // TODO(b/187106271): Move fallback kernels to fallback only libraries so that
  // we don't have to depend on or link in corert kernels.
  registry->AddKernel("tfrt_fallback_async.predicate", TFRT_KERNEL(Predicate));
  registry->AddKernel("tfrt_fallback_async.print_tensor",
                      TFRT_KERNEL(PrintFallbackTensor));
  registry->AddKernel("corert.create_runtime_fallback_op_handler",
                      TFRT_KERNEL(CreateRuntimeFallbackOpHandlerKernel));
  registry->AddKernel("corert.add_runtime_fallback_implicit_conversions",
                      TFRT_KERNEL(AddRuntimeFallbackImplicitConversionKernel));
}

}  // namespace tfd
}  // namespace tensorflow