/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"

#include <optional>
#include <string>

#include "llvm/ADT/StringRef.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/logging.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_tensor.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_utils.h"
#include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
#include "tensorflow/core/runtime_fallback/runtime/op_logger.h"
#include "tensorflow/core/runtime_fallback/util/attr_util.h"
#include "tensorflow/core/tfrt/fallback/cost_recorder.h"
#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"
#include "tensorflow/core/tfrt/utils/error_util.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tensorflow/core/tfrt/utils/tensor_util.h"
#include "tfrt/core_runtime/execute_op_impl.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/kernel_registry.h"  // from @tf_runtime
#include "tfrt/host_context/sync_kernel_frame.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/pointer_util.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime
#include "tfrt/tensor/tensor.h"  // from @tf_runtime

namespace tensorflow {
namespace tfd {
const char kOpKernelRunnerCacheResourceName[] =
    "OpKernelRunnerCacheResourceName";

namespace {

using ::tensorflow::tfrt_stub::OpKernelRunner;
using ::tensorflow::tfrt_stub::OpKernelRunnerTable;
using ::tensorflow::tfrt_stub::OpKernelRunState;
using ::tfrt::AsyncValue;
using ::tfrt::AsyncValueRef;
using ::tfrt::Chain;
using ::tfrt::OpAttrsRef;
using ::tfrt::RCReference;
using ::tfrt::string_view;

constexpr char kOpKernelRunnerTableResourceName[] =
    "OpKernelRunnerTableResourceName";

constexpr char kFallbackResourceArray[] = "FallbackResourceArray";

void KernelFallbackEmitError(
    const tfrt::ExecutionContext& exec_ctx,
    const KernelFallbackCompatRequestState* fallback_request_state,
    tfrt::string_view op_name, tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
    const tensorflow::Status& status) {
  // Set all results to error, with the correct TFRT error code according to
  // the error propagated from runtime fallback execution.
  auto model_info =
      fallback_request_state == nullptr
          ? "(missing model info) "
          : tfrt::StrCat(
                fallback_request_state->session_metadata().name(), " (",
                fallback_request_state->session_metadata().version(), ") ");
  auto error = EmitErrorAsync(
      exec_ctx,
      tfrt::StrCat(model_info, "error running kernel fallback kernel ",
                   op_name, ": ", status.error_message()),
      tfrt::ConvertTfErrorCodeToTfrtErrorCode(status));
  std::fill(results.begin(), results.end(), error);
  if (op_chain) *op_chain = std::move(error);
}

std::function<void(std::function<void()>)>* GetDefaultRunner() {
  static auto* const default_runner =
      new std::function<void(std::function<void()>)>(
          [](const std::function<void()>& f) { f(); });
  return default_runner;
}

}  // namespace

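// Sets up the fallback-compat request state on `builder` using an explicitly
// provided DeviceMgr and ProcessFunctionLibraryRuntime. An illustrative,
// non-authoritative call-site sketch; the builder construction and variable
// names below are assumptions, not taken from this file:
//
//   tfrt::RequestContextBuilder request_context_builder(host_context,
//                                                       resource_context);
//   TF_RETURN_IF_ERROR(SetUpKernelFallbackCompatRequestContext(
//       &request_context_builder, device_manager, pflr,
//       /*user_intra_op_threadpool=*/nullptr,
//       /*model_metadata=*/absl::nullopt, /*runner=*/nullptr));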
Status SetUpKernelFallbackCompatRequestContext(
    tfrt::RequestContextBuilder* builder,
    const tensorflow::DeviceMgr* device_manager,
    const tensorflow::ProcessFunctionLibraryRuntime* pflr,
    tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
    const absl::optional<SessionMetadata>& model_metadata,
    std::function<void(std::function<void()>)>* runner) {
  DCHECK(builder);
  DCHECK(device_manager);
  DCHECK(pflr);

  auto* runner_table =
      builder->resource_context()->GetOrCreateResource<OpKernelRunnerTable>(
          kOpKernelRunnerTableResourceName);

  auto* resource_array =
      builder->resource_context()->GetOrCreateResource<FallbackResourceArray>(
          kFallbackResourceArray);

  builder->context_data().emplace<KernelFallbackCompatRequestState>(
      runner ? runner : GetDefaultRunner(), device_manager, builder->id(),
      runner_table, resource_array, user_intra_op_threadpool, model_metadata,
      pflr);

  return OkStatus();
}

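// Overload that derives the device manager, step container, collective
// executor, rendezvous, and function library runtime from the given
// EagerContext instead of taking them individually.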
Status SetUpKernelFallbackCompatRequestContext(
    tfrt::RequestContextBuilder* builder, OpKernelRunnerTable* runner_table,
    tensorflow::EagerContext* eager_context,
    tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
    const absl::optional<SessionMetadata>& model_metadata) {
  auto* resource_array =
      builder->resource_context()->GetOrCreateResource<FallbackResourceArray>(
          kFallbackResourceArray);

  if (runner_table == nullptr)
    runner_table =
        builder->resource_context()->GetOrCreateResource<OpKernelRunnerTable>(
            kOpKernelRunnerTableResourceName);

  auto step_id = builder->id();

  auto& fallback_request_state =
      builder->context_data().emplace<KernelFallbackCompatRequestState>(
          GetDefaultRunner(), eager_context->local_device_mgr(), step_id,
          tfrt::OwnedOrUnownedPtr<ScopedStepContainer>{
              eager_context->StepContainer()},
          eager_context->GetCollectiveExecutorHandle(),
          tensorflow::core::RefCountPtr<tensorflow::Rendezvous>(
              eager_context->RendezvousCreator()(step_id)),
          runner_table, resource_array, user_intra_op_threadpool,
          model_metadata, eager_context->pflr());

  fallback_request_state.set_log_device_placement(
      eager_context->LogDevicePlacement());

  return OkStatus();
}

static llvm::Expected<gtl::InlinedVector<tensorflow::Tensor, 4>>
ConvertInputTensors(llvm::ArrayRef<tfrt::Tensor*> arguments,
                    const tfrt::ExecutionContext& exec_ctx) {
  gtl::InlinedVector<tensorflow::Tensor, 4> input_tf_tensors;
  input_tf_tensors.reserve(arguments.size());
  for (tfrt::Tensor* argument : arguments) {
    auto expected_tf_tensor = TFRTTensorToTFTensor(*argument, exec_ctx.host());
    if (!expected_tf_tensor) {
      return tfrt::MakeStringError(
          tfrt::StrCat(expected_tf_tensor.takeError()));
    }
    input_tf_tensors.push_back(std::move(expected_tf_tensor.get()));
  }

  return input_tf_tensors;
}

static Status ValidateInputTypes(
    tfrt::string_view op_name,
    const gtl::InlinedVector<tensorflow::Tensor, 4>& input_tf_tensors,
    const DataTypeVector& input_types) {
  const size_t n_inputs = input_tf_tensors.size();

  if (input_types.size() != n_inputs) {
    return tensorflow::errors::InvalidArgument("expected ", input_types.size(),
                                               " inputs, got ", n_inputs);
  }

  for (size_t i = 0; i < n_inputs; ++i) {
    if (input_tf_tensors[i].dtype() != input_types[i]) {
      return tensorflow::errors::InvalidArgument(
          "cannot compute ", op_name.str(), " as input #", i, "(zero-based)",
          " was expected to be a ", DataTypeString(input_types[i]),
          " tensor but is a ", DataTypeString(input_tf_tensors[i].dtype()),
          " tensor");
    }
  }

  return OkStatus();
}

namespace {

// Keep the state needed by kernel execution in thread-local storage to avoid
// repeatedly reallocating and destroying it.
OpKernelRunState& GetThreadLocalOpKernelRunState() {
  thread_local OpKernelRunState run_state;
  return run_state;
}

}  // namespace

// Executes a tensorflow::OpKernel asynchronously. `kernel_runner` and
// `input_tf_tensors` are expected to be alive during the call to this
// function. Sets result AsyncValues in `results` and returns a Chain that
// indicates the completion of execution, or an error otherwise.
template <typename TensorType>
static void KernelFallbackExecuteCompatAsyncInternal(
    const tfrt::ExecutionContext& exec_ctx, OpKernelRunState* run_state,
    const OpKernelRunner& kernel_runner,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results) {
  auto chain =
      tfrt::MakeUnconstructedAsyncValueRef<tfrt::Chain>(exec_ctx.host());
  if (op_chain) *op_chain = chain.CopyRef();

  // Allocate unconstructed result tensors and set them in the output
  // `results`.
  llvm::SmallVector<AsyncValueRef<TensorType>, 4> result_refs;
  result_refs.reserve(results.size());
  for (auto& result : results) {
    result_refs.emplace_back(
        tfrt::MakeUnconstructedAsyncValueRef<TensorType>(exec_ctx.host()));
    result = result_refs.back().CopyRef();
  }

  struct AsyncState {
    explicit AsyncState(const OpKernelRunState& rs, int num_outputs)
        : run_state(rs.input_tf_tensor_values, rs.params),
          context(&run_state.params, num_outputs) {}

    OpKernelRunState run_state;
    OpKernelContext context;

    tfrt::AsyncValueRef<tfrt::Chain> chain;
    llvm::SmallVector<tfrt::AsyncValueRef<TensorType>, 4> result_refs;
  };

  DCHECK_EQ(results.size(), kernel_runner.op_kernel()->num_outputs());
  auto async_state = std::make_shared<AsyncState>(*run_state, results.size());
  async_state->chain = chain.CopyRef();
  async_state->result_refs = std::move(result_refs);

  auto* context_ptr = &async_state->context;

  auto done_callback = [async_state = std::move(async_state), exec_ctx]() {
    auto& context = async_state->context;

    if (!context.status().ok()) {
      auto diag = tfrt::EmitError(
          exec_ctx,
          {tfrt::StrCat("error running kernel fallback kernel ",
                        context.op_kernel().name(), ": ",
                        context.status().error_message())},
          tfrt::ConvertTfErrorCodeToTfrtErrorCode(context.status()));
      for (auto& result : async_state->result_refs) result.SetError(diag);
      async_state->chain.SetError(diag);
      return;
    }

    // Set payload and mark async values available in TFRT's thread.
    tfrt::EnqueueWork(exec_ctx, [async_state = std::move(async_state)]() {
      auto& context = async_state->context;
      for (int i = 0; i < context.num_outputs(); ++i) {
        async_state->result_refs[i].emplace(
            std::move(*context.mutable_output(i)));
      }
      async_state->chain.emplace();
    });
  };

  kernel_runner.RunAsync(context_ptr, std::move(done_callback));
}

// Executes a tensorflow::OpKernel synchronously. `kernel_runner` and
// `input_tf_tensors` are expected to be alive during the call to this
// function. Sets result AsyncValues in `results` when the execution finishes
// successfully, and emits an error otherwise. TensorType is expected to be
// constructible from tensorflow::Tensor.
template <typename TensorType>
static void KernelFallbackExecuteCompatSyncInternal(
    const tfrt::ExecutionContext& exec_ctx,
    const KernelFallbackCompatRequestState* fallback_request_state,
    OpKernelRunState* run_state, const OpKernelRunner& kernel_runner,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results) {
  DCHECK_EQ(results.size(), kernel_runner.op_kernel()->num_outputs());
  OpKernelContext context(&run_state->params, results.size());
  kernel_runner.Run(&context);

  if (!context.status().ok()) {
    KernelFallbackEmitError(exec_ctx, fallback_request_state,
                            kernel_runner.op_kernel()->name(), op_chain,
                            results, context.status());
    return;
  }

  for (int i = 0; i < context.num_outputs(); ++i) {
    results[i] = tfrt::MakeAvailableAsyncValueRef<TensorType>(
        std::move(*context.mutable_output(i)));
  }

  if (op_chain) *op_chain = tfrt::MakeAvailableAsyncValueRef<tfrt::Chain>();
}

static std::string PrintTfrtOpAttrsToString(const tfrt::OpAttrsRef& attrs) {
  std::string str;
  llvm::raw_string_ostream ss(str);
  attrs.Print(ss);
  ss.flush();
  return str;
}

tfrt::AsyncValueRef<tfrt::Chain> KernelFallbackExecuteCompatCoreRuntimeDispatch(
    const tfrt::ExecutionContext& exec_ctx, tfrt::string_view op_name,
    tfrt::string_view device_name, llvm::ArrayRef<tfrt::Tensor*> arguments,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
    const KernelFallbackCompatRequestState& fallback_request_state,
    const OpKernelRunner& op_kernel_runner) {
  auto op_chain = tfrt::GetReadyChain();
  tensorflow::Status status;

  auto expected_input_tf_tensors = ConvertInputTensors(arguments, exec_ctx);
  if (!expected_input_tf_tensors) {
    status = tensorflow::errors::Internal(
        tfrt::StrCat(expected_input_tf_tensors.takeError()));
    KernelFallbackEmitError(exec_ctx, &fallback_request_state, op_name,
                            &op_chain, results, status);
    return op_chain;
  }

  auto& run_state = GetThreadLocalOpKernelRunState();
  auto clean_up_inputs =
      gtl::MakeCleanup([&]() { run_state.input_tf_tensors.clear(); });

  auto& input_tf_tensors = run_state.input_tf_tensors;
  input_tf_tensors = std::move(expected_input_tf_tensors.get());

  // Check if input tensor dtypes are valid.
  status = ValidateInputTypes(op_name, input_tf_tensors,
                              op_kernel_runner.op_kernel()->input_types());

  // TODO(b/176997538): Skip checking dtypes for tf._BatchFunctionFallback op
  // due to b/176997538. Remove the skipping once the SavedModel lowering
  // problem is fixed.
  if (!status.ok() && !op_name.equals("_BatchFunctionFallback")) {
    KernelFallbackEmitError(exec_ctx, &fallback_request_state, op_name,
                            &op_chain, results, status);
    return op_chain;
  }

  auto& input_tf_tensor_values = run_state.input_tf_tensor_values;
  input_tf_tensor_values.resize(input_tf_tensors.size());
  for (int i = 0; i < input_tf_tensors.size(); ++i) {
    input_tf_tensor_values[i].tensor = &input_tf_tensors[i];
  }

  auto* device =
      GetDeviceFromFallbackState(fallback_request_state, op_kernel_runner);

  SetUpParams(op_kernel_runner, fallback_request_state, device, run_state);

  if (op_kernel_runner.IsAsync()) {
    KernelFallbackExecuteCompatAsyncInternal<KernelFallbackTensor>(
        exec_ctx, &run_state, op_kernel_runner, &op_chain, results);
  } else {
    KernelFallbackExecuteCompatSyncInternal<KernelFallbackTensor>(
        exec_ctx, &fallback_request_state, &run_state, op_kernel_runner,
        &op_chain, results);
  }

  return op_chain;
}

static absl::string_view StripTfPrefix(tfrt::string_view op_name) {
  return absl::StripPrefix(ToAbslStringView(op_name), "tf.");
}

// Generates metadata for an op execution event.
std::string GetTracingMetadata(llvm::ArrayRef<tfrt::AsyncValue*> args,
                               const tfrt::ExecutionContext& exec_ctx,
                               const OpKernelRunner& kernel_runner) {
  auto request_id = exec_ctx.request_ctx()->id();
  // Get Long Name
  auto debug_info = exec_ctx.location().GetDebugInfo();
  auto long_name = debug_info.has_value() ? debug_info.getValue().info : "";

  if (!profiler::TfOpDetailsEnabled()) {
    return profiler::TraceMeEncode(
        {{"id", request_id}, {"long_name", ToAbslStringView(long_name)}});
  }

  // Get Input Tensors
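  // Each input is encoded as its dtype string followed by its shape, so a
  // 2x3 float input would contribute something like "float[2,3];" (format
  // shown here for illustration only).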
  std::string input_string;
  llvm::raw_string_ostream input_string_stream(input_string);

  for (size_t i = 0; i < args.size(); ++i) {
    const auto& tensor = args[i]->get<Tensor>();
    input_string_stream << DataTypeString(tensor.dtype())
                        << tensor.shape().DebugString();
    input_string_stream << ";";
  }

  // Get Attributes
  std::string attr_string;
  llvm::raw_string_ostream attr_string_stream(attr_string);

  for (const auto& entry : kernel_runner.op_kernel()->def().attr()) {
    attr_string_stream << entry.first << ": {" << entry.second.DebugString();
    if (!attr_string.empty() && attr_string[attr_string.size() - 1] == '\n') {
      attr_string[attr_string.size() - 1] = '}';
    }
    attr_string_stream << ";\n";
  }

  return profiler::TraceMeEncode({
      {"id", request_id},
      {"long_name", ToAbslStringView(long_name)},
      {"inputs", input_string},
      {"attributes", attr_string},
  });
}

namespace {

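// Wraps the attributes of a fallback kernel invocation. The fallback execute
// kernels below expect their attributes in the following order: device, op
// attributes, op function attributes, op key, and op name (see the
// kAttrPosition constants at the end of the class).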
class FallbackKernelAttributeFrame {
 public:
  explicit FallbackKernelAttributeFrame(tfrt::AsyncKernelFrame* frame)
      : frame_(frame) {
    DCHECK(frame_);
  }

  tfrt::StringAttr device() const {
    return tfrt::StringAttr(frame_->GetAttribute(kDeviceAttrPosition));
  }

  tfrt::AggregateAttr op_attr() const {
    return tfrt::AggregateAttr(frame_->GetAttribute(kOpAttrPosition));
  }

  tfrt::AggregateAttr op_func_attr() const {
    return tfrt::AggregateAttr(frame_->GetAttribute(kOpFuncAttrPosition));
  }

  tfrt::I64Attr op_key() const {
    return tfrt::I64Attr(frame_->GetAttribute(kOpKeyAttrPosition));
  }

  tfrt::StringAttr op_name() const {
    return tfrt::StringAttr(frame_->GetAttribute(kOpNameAttrPosition));
  }

 private:
  static constexpr int kDeviceAttrPosition = 0;
  static constexpr int kOpAttrPosition = 1;
  static constexpr int kOpFuncAttrPosition = 2;
  static constexpr int kOpKeyAttrPosition = 3;
  static constexpr int kOpNameAttrPosition = 4;

  tfrt::AsyncKernelFrame* frame_ = nullptr;
};

// The BEF kernel for kernel fallback compat mode. The arguments and results
// are expected to be tensorflow::tfrt_stub::FallbackTensor.
TF_ATTRIBUTE_ALWAYS_INLINE static void KernelFallbackExecuteOpInternal(
    llvm::ArrayRef<tfrt::AsyncValue*> args,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    const FallbackKernelAttributeFrame& frame,
    const tfrt::ExecutionContext& exec_ctx,
    const KernelFallbackCompatRequestState& fallback_request_state,
    const OpKernelRunner& kernel_runner, bool is_async,
    tensorflow::Device* device) {
  tensorflow::profiler::TraceMe trace_me([&]() -> std::string {
    if (kernel_runner.op_kernel()) {
      return tensorflow::profiler::TraceMeOp(
          kernel_runner.op_kernel()->name_view(),
          kernel_runner.op_kernel()->type_string_view());
    }
    return std::string(ToAbslStringView(frame.op_name().GetValue()));
  });

  trace_me.AppendMetadata(
      [&]() { return GetTracingMetadata(args, exec_ctx, kernel_runner); });

  if (fallback_request_state.log_device_placement() || VLOG_IS_ON(1)) {
    string msg =
        strings::StrCat("Executing op ", frame.op_name().GetValue().str(),
                        " in device ", frame.device().GetValue().str());
    if (!logging::LogToListeners(msg)) {
      LOG(INFO) << msg;
    }
  }

  auto& run_state = GetThreadLocalOpKernelRunState();
  auto clean_up_inputs =
      gtl::MakeCleanup([&]() { run_state.input_tf_tensors.clear(); });

  // Prepare the input tensors.
  auto& input_tf_tensors = run_state.input_tf_tensors;
  auto& input_tf_tensor_values = run_state.input_tf_tensor_values;
  DCHECK(input_tf_tensors.empty());
  input_tf_tensor_values.resize(args.size());
  for (int i = 0; i < args.size(); ++i) {
    auto* arg = args[i];
    auto& fallback_tensor = arg->get<tensorflow::tfrt_stub::FallbackTensor>();
    // If the argument is immutable or unique, we can just keep the reference
    // without copying, which involves expensive atomic reference counting. And
    // if the argument is unique but mutable, then tensorflow optimizations
    // like buffer forwarding can be utilized. Otherwise, we conservatively
    // copy the tensor.
    if (!fallback_tensor.is_immutable() && !arg->IsUnique()) {
      input_tf_tensors.push_back(fallback_tensor.tensor());
    }
    input_tf_tensor_values[i].tensor = &fallback_tensor.tensor();
  }

  SetUpParams(kernel_runner, fallback_request_state, device, run_state);

  bool is_cost_measurement_enabled =
      exec_ctx.request_ctx()->IsCostMeasurementEnabled();
  auto run_start_time =
      is_cost_measurement_enabled ? Env::Default()->NowMicros() : 0;
  if (is_async) {
    KernelFallbackExecuteCompatAsyncInternal<
        tensorflow::tfrt_stub::FallbackTensor>(
        exec_ctx, &run_state, kernel_runner, op_chain, results);
  } else {
    KernelFallbackExecuteCompatSyncInternal<
        tensorflow::tfrt_stub::FallbackTensor>(
        exec_ctx, &fallback_request_state, &run_state, kernel_runner, op_chain,
        results);
  }
  if (is_cost_measurement_enabled) {
    op_chain->AndThen([run_start_time, exec_ctx, frame] {
      // Adds 1 to make sure it's a positive integer.
      auto execution_time = Env::Default()->NowMicros() - run_start_time + 1;
      // Adds op_key as a suffix to distinguish the same operation with
      // different shapes.
      exec_ctx.host()
          ->GetOrCreateSharedContext<tensorflow::tfrt_stub::CostRecorder>()
          .RecordCost(frame.op_key().GetValue(), execution_time);
    });
  }
}

TF_ATTRIBUTE_ALWAYS_INLINE static void KernelFallbackExecuteOp(
    llvm::ArrayRef<tfrt::AsyncValue*> args,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    const FallbackKernelAttributeFrame& frame,
    const tfrt::ExecutionContext& exec_ctx) {
  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    KernelFallbackEmitError(
        exec_ctx, /*fallback_request_state=*/nullptr,
        frame.op_name().GetValue(), op_chain, results,
        tensorflow::errors::NotFound(
            "KernelFallbackCompatRequestState not found in RequestContext."));
    return;
  }

  auto* runner_table = fallback_request_state->runner_table();
  DCHECK(runner_table);

  auto* kernel_runner = runner_table->Get(frame.op_key().GetValue());
  DCHECK(kernel_runner);
  DCHECK_EQ(kernel_runner->op_kernel()->name(),
            StripTfPrefix(frame.op_name().GetValue()));

  auto* device =
      GetDeviceFromFallbackState(*fallback_request_state, *kernel_runner);

  KernelFallbackExecuteOpInternal(args, results, op_chain, frame, exec_ctx,
                                  *fallback_request_state, *kernel_runner,
                                  kernel_runner->IsAsync(), device);
}

// The BEF kernel for creating tensorflow::OpKernel to be used in kernel
// fallback compat mode.
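// The created OpKernelRunner is stored in the per-request OpKernelRunnerTable
// under `op_key`, and is later looked up by the executeop kernels using the
// same key.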
tfrt::AsyncValueRef<tfrt::Chain> KernelFallbackCreateOp(
    const tfrt::Chain& in_ch, tfrt::StringAttr device, tfrt::I64Attr num_args,
    tfrt::AggregateAttr op_attr_array, tfrt::AggregateAttr op_func_attr_array,
    tfrt::I64Attr op_key, tfrt::StringAttr op_name_attr,
    const tfrt::ExecutionContext& exec_ctx) {
  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    return tfrt::EmitErrorAsync(
        exec_ctx,
        "KernelFallbackCompatRequestState not found in RequestContext.");
  }

  auto* runner_table = fallback_request_state->runner_table();
  DCHECK(runner_table);

  auto attr_builder = [op_attr_array, op_func_attr_array](
                          tensorflow::AttrValueMap* attr_value_map) {
    return SetUpAttrValueMap(op_attr_array, op_func_attr_array,
                             attr_value_map);
  };

  auto op_name = StripTfPrefix(op_name_attr.GetValue());

  auto statusor_runner = OpKernelRunner::Create(
      op_name, ToAbslStringView(device.GetValue()), num_args.GetValue(),
      attr_builder, fallback_request_state->device_manager(),
      fallback_request_state->process_function_library_runtime());
  if (!statusor_runner.ok())
    return tfrt::EmitErrorAsync(
        exec_ctx, statusor_runner.status().error_message(),
        tfrt::ConvertTfErrorCodeToTfrtErrorCode(statusor_runner.status()));

  if (!runner_table->Insert(op_key.GetValue(),
                            std::move(statusor_runner).ValueOrDie())) {
    return tfrt::EmitErrorAsync(
        exec_ctx,
        absl::StrCat("KernelFallbackCreateOp: OpKernelRunner already exists: ",
                     op_name_attr.GetValue().str()));
  }

  return tfrt::MakeAvailableAsyncValueRef<tfrt::Chain>();
}

// FallbackSetResource is the fallback kernel that sets the tensor value in the
// fallback's resource array.
llvm::Expected<tfrt::Chain> FallbackSetResource(
    tfrt::Argument<tfrt::Chain> in_ch,
    tfrt::Argument<tensorflow::tfrt_stub::FallbackTensor> arg,
    tfrt::StringAttr device, tfrt::I64Attr index_attr,
    const tfrt::ExecutionContext& exec_ctx) {
  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    return tfrt::MakeStringError(
        "KernelFallbackCompatRequestState not found in RequestContext.");
  }

  auto* resource_array = fallback_request_state->resource_array();
  DCHECK(resource_array);

  int64_t index = index_attr.GetValue();

  // Set the resource tensor to be immutable, so that we don't need reference
  // counting on it and it cannot be buffer-forwarded.
  resource_array->SetResource(
      index,
      tensorflow::tfrt_stub::ImmutableTensor::Create(arg.get().tensor()));

  return tfrt::Chain();
}

// FallbackGetResource is the fallback kernel that retrieves the tensor value
// in the fallback's resource array.
void FallbackGetResource(tfrt::Argument<tfrt::Chain> in_ch,
                         tfrt::Result<tfrt::Chain> out_ch,
                         tfrt::RemainingResults results,
                         tfrt::StringAttr device, tfrt::ArrayAttr indices_attr,
                         const tfrt::ExecutionContext& exec_ctx) {
  tensorflow::profiler::TraceMe trace_me("tfrt_fallback_async.get_resource");
  trace_me.AppendMetadata([request_id = exec_ctx.request_ctx()->id()]() {
    return tensorflow::profiler::TraceMeEncode({{"id", request_id}});
  });

  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    tfrt::RCReference<tfrt::AsyncValue> error = tfrt::EmitErrorAsync(
        exec_ctx,
        "KernelFallbackCompatRequestState not found in RequestContext.");
    out_ch.Set(std::move(error));
    return;
  }

  auto* resource_array = fallback_request_state->resource_array();
  DCHECK(resource_array);

  llvm::ArrayRef<int64_t> indices = indices_attr.GetValue<int64_t>();

  for (int i = 0; i < indices.size(); ++i) {
    results[i] = tfrt::FormRef(resource_array->GetResource(indices[i]));
  }

  out_ch.Set(in_ch);
}

// The implementation of tfrt_fallback_async.executeop kernel. It executes a
// non-side-effecting TF op with the name of `op_name` in fallback. All
// relevant TF attributes are passed in `op_attr_array`.
void FallbackAsyncExecuteOp(tfrt::AsyncKernelFrame* frame) {
  FallbackKernelAttributeFrame attr_frame(frame);
#ifndef NDEBUG
  frame->GetExecutionContext()
      .host()
      ->GetOrCreateSharedContext<OpLogger>()
      .LogOp(attr_frame.op_name().GetValue());
#endif
  // Create op_chain only when cost measurement is enabled. It is used for
  // measuring async op's actual latency.
  if (frame->GetExecutionContext().request_ctx()->IsCostMeasurementEnabled()) {
    auto op_chain = tfrt::MakeUnconstructedAsyncValueRef<tfrt::Chain>();
    KernelFallbackExecuteOp(frame->GetArguments(), frame->GetResults(),
                            &op_chain, attr_frame,
                            frame->GetExecutionContext());
  } else {
    KernelFallbackExecuteOp(frame->GetArguments(), frame->GetResults(),
                            /*op_chain=*/nullptr, attr_frame,
                            frame->GetExecutionContext());
  }
}

// The implementation of tfrt_fallback_async.executeop.seq kernel. It executes
// a side-effecting TF op with the name of `op_name` in fallback. All relevant
// TF attributes are passed in `op_attr_array`. `in_op_chain` and
// `out_op_chain` are used for side-effect visibility.
void FallbackAsyncExecuteOpSeq(tfrt::AsyncKernelFrame* frame) {
  auto all_args = frame->GetArguments();
  DCHECK_GT(all_args.size(), 0);
  tfrt::AsyncValueRef<tfrt::Chain> op_chain(tfrt::FormRef(all_args[0]));
  llvm::ArrayRef<tfrt::AsyncValue*> args = all_args.drop_front();

  auto all_results = frame->GetResults();
  DCHECK_GT(all_results.size(), 0);
  auto& out_op_chain = all_results[0];
  llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results =
      all_results.drop_front();

  KernelFallbackExecuteOp(args, results, &op_chain,
                          FallbackKernelAttributeFrame(frame),
                          frame->GetExecutionContext());
  out_op_chain = std::move(op_chain);
}

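// A Device wrapper that forwards everything to the wrapped device except
// GetAllocator(), which returns the caller-provided allocator. It is used by
// the *.allocator execute kernels below so that op outputs are allocated from
// a custom allocator.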
class DeviceWithCustomAllocator : public tensorflow::Device {
 public:
  DeviceWithCustomAllocator(tensorflow::Device* device,
                            tensorflow::Allocator* allocator)
      : Device(device->env(), device->attributes()),
        device_(device),
        allocator_(allocator) {
    DCHECK(device_);
    DCHECK(allocator_);
  }

  Allocator* GetAllocator(AllocatorAttributes attr) override {
    return allocator_;
  }

  const DeviceBase* UnderlyingDevice() const override {
    return device_->UnderlyingDevice();
  }
  DeviceBase* UnderlyingDevice() override {
    return device_->UnderlyingDevice();
  }

  const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override {
    return device_->tensorflow_cpu_worker_threads();
  }

  Allocator* GetScopedAllocator(AllocatorAttributes attr,
                                int64_t step_id) override {
    return device_->GetScopedAllocator(attr, step_id);
  }

  ScopedAllocatorMgr* GetScopedAllocatorMgr() const override {
    return device_->GetScopedAllocatorMgr();
  }

  const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
    return device_->eigen_cpu_device();
  }

  thread::ThreadPool* tensorflow_device_thread_pool() override {
    return device_->tensorflow_device_thread_pool();
  }

  bool has_eigen_cpu_device() const override {
    return device_->has_eigen_cpu_device();
  }

  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor) override {
    return device_->MakeTensorFromProto(tensor_proto, alloc_attrs, tensor);
  }

  void CopyTensorInSameDevice(const Tensor* input_tensor,
                              Tensor* output_tensor,
                              const DeviceContext* device_context,
                              StatusCallback done) override {
    device_->CopyTensorInSameDevice(input_tensor, output_tensor,
                                    device_context, std::move(done));
  }

  Status Sync() override { return device_->Sync(); }

  // Returns the resource manager associated w/ this device.
  ResourceMgr* resource_manager() override {
    return device_->resource_manager();
  }

 private:
  tensorflow::Device* device_ = nullptr;
  tensorflow::Allocator* allocator_ = nullptr;
};

void KernelFallbackExecuteOpCustomAllocatorInternal(
    llvm::ArrayRef<tfrt::AsyncValue*> args,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    const tfrt::ExecutionContext& exec_ctx,
    const FallbackKernelAttributeFrame& attr_frame) {
  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    KernelFallbackEmitError(
        exec_ctx, /*fallback_request_state=*/nullptr,
        attr_frame.op_name().GetValue(), op_chain, results,
        tensorflow::errors::NotFound(
            "KernelFallbackCompatRequestState not found in RequestContext."));
    return;
  }

  auto* runner_table = fallback_request_state->runner_table();
  DCHECK(runner_table);

  auto* kernel_runner = runner_table->Get(attr_frame.op_key().GetValue());
  DCHECK(kernel_runner);
  DCHECK_EQ(kernel_runner->op_kernel()->name(),
            StripTfPrefix(attr_frame.op_name().GetValue()));

  DCHECK_GT(args.size(), 0);
  auto* allocator = args.front()->get<tensorflow::Allocator*>();
  args = args.drop_front();

  DeviceWithCustomAllocator device_with_custom_allocator(
      GetDeviceFromFallbackState(*fallback_request_state, *kernel_runner),
      allocator);

  // Unlike FallbackAsyncExecuteOp, async execution is not allowed here
  // because the lifetime of the wrapper device cannot be extended beyond this
  // call.
  //
  // TODO(b/200575143): Consider allowing async execution and extending the
  // lifetime of the wrapping device.
  KernelFallbackExecuteOpInternal(args, results,
                                  /*op_chain=*/op_chain, attr_frame, exec_ctx,
                                  *fallback_request_state, *kernel_runner,
                                  /*is_async=*/false,
                                  &device_with_custom_allocator);
}

void FallbackAsyncExecuteOpWithAllocator(tfrt::AsyncKernelFrame* frame) {
  auto args = frame->GetArguments();
  auto results = frame->GetResults();
  FallbackKernelAttributeFrame attr_frame(frame);
  KernelFallbackExecuteOpCustomAllocatorInternal(
      args, results, /*op_chain=*/nullptr, frame->GetExecutionContext(),
      attr_frame);
}

void FallbackAsyncExecuteOpSeqWithAllocator(tfrt::AsyncKernelFrame* frame) {
  auto args = frame->GetArguments();
  DCHECK_GT(args.size(), 0);
  tfrt::AsyncValueRef<tfrt::Chain> op_chain(tfrt::FormRef(args.front()));
  args = args.drop_front();

  auto results = frame->GetResults();
  DCHECK_GT(results.size(), 0);
  auto& out_op_chain = results.front();
  results = results.drop_front();

  FallbackKernelAttributeFrame attr_frame(frame);
  KernelFallbackExecuteOpCustomAllocatorInternal(
      args, results, &op_chain, frame->GetExecutionContext(), attr_frame);

  out_op_chain = std::move(op_chain);
}

void FallbackCopyTensorIfSmall(
    tfrt::Argument<tensorflow::tfrt_stub::FallbackTensor> arg,
    tfrt::RemainingResults results) {
  const auto& fallback_tensor = arg.get();
  const auto& tensor = fallback_tensor.tensor();

  if (!fallback_tensor.is_immutable()) {
    // Create a new TensorBuffer which contains a new atomic counter for each
    // result, to avoid downstream threads contending on the original atomic
    // counter.
    for (int i = 0; i < results.size(); ++i) {
      auto immutable_tensor =
          tensorflow::tfrt_stub::ImmutableTensor::Create(tensor);
      results[i] = tfrt::MakeAvailableAsyncValueRef<
          tensorflow::tfrt_stub::FallbackTensor>(
          std::move(immutable_tensor.tensor()));
    }
  } else {
    // For immutable tensors, we just need to copy the pointer. Note that we
    // still create a new AsyncValue in this case, to avoid atomic contention
    // on AsyncValue's refcount.
    for (int i = 0; i < results.size(); ++i) {
      results[i] = tfrt::MakeAvailableAsyncValueRef<
          tensorflow::tfrt_stub::FallbackTensor>(fallback_tensor);
    }
  }
}

llvm::Expected<tensorflow::tfrt_stub::FallbackTensor> ConstTensorProto(
    tfrt::StringAttr serialized_tensor_proto) {
  tensorflow::TensorProto tensor_proto;
  if (!tensor_proto.ParseFromString(
          serialized_tensor_proto.GetValue().str())) {
    return tfrt::MakeStringError("Failed to parse const tensor proto");
  }

  tensorflow::Tensor tensor;
  if (!tensor.FromProto(tensor_proto)) {
    return tfrt::MakeStringError("Failed to create tensor from tensor proto: ",
                                 tensor_proto.ShortDebugString());
  }

  return tensorflow::tfrt_stub::FallbackTensor(std::move(tensor));
}

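// A test-only allocator that logs each raw allocation before delegating to
// the wrapped CPU allocator. It is exposed through the
// tfrt_fallback_async.get_test_allocator kernel registered below.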
class TestAllocator : public tensorflow::AllocatorWrapper {
 public:
  TestAllocator()
      : tensorflow::AllocatorWrapper(tensorflow::cpu_allocator()) {}

  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
    std::printf("Using TestAllocator\n");
    fflush(stdout);
    return wrapped()->AllocateRaw(alignment, num_bytes);
  }

  void* AllocateRaw(size_t alignment, size_t num_bytes,
                    const AllocationAttributes& allocation_attr) override {
    std::printf("Using TestAllocator\n");
    fflush(stdout);
    return wrapped()->AllocateRaw(alignment, num_bytes, allocation_attr);
  }
};

tensorflow::Allocator* GetTestAllocator() {
  static auto* const test_allocator = new TestAllocator;
  return test_allocator;
}

void RegisterKernelFallbackCompatKernels(tfrt::KernelRegistry* registry) {
  registry->AddKernel("tfrt_fallback_async.const_tensor_proto",
                      TFRT_KERNEL(ConstTensorProto));
  registry->AddKernel("tfrt_fallback_async.executeop", FallbackAsyncExecuteOp);
  registry->AddKernel("tfrt_fallback_async.executeop.seq",
                      FallbackAsyncExecuteOpSeq);
  registry->AddKernel("tfrt_fallback_async.executeop.allocator",
                      FallbackAsyncExecuteOpWithAllocator);
  registry->AddKernel("tfrt_fallback_async.executeop.seq.allocator",
                      FallbackAsyncExecuteOpSeqWithAllocator);
  registry->AddKernel("tfrt_fallback_async.copy_if_small",
                      TFRT_KERNEL(FallbackCopyTensorIfSmall));
  registry->AddKernel("tfrt_fallback_async.createop",
                      TFRT_KERNEL(KernelFallbackCreateOp));
  registry->AddKernel("tfrt_fallback_async.set_resource",
                      TFRT_KERNEL(FallbackSetResource));
  registry->AddKernel("tfrt_fallback_async.get_resource",
                      TFRT_KERNEL(FallbackGetResource));

  // TODO(chky): Move test kernels to test-only library.
  registry->AddKernel("tfrt_fallback_async.get_test_allocator",
                      TFRT_KERNEL(GetTestAllocator));
}

TFRT_STATIC_KERNEL_REGISTRATION(RegisterKernelFallbackCompatKernels);

}  // namespace
}  // namespace tfd
}  // namespace tensorflow