/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"

#include <algorithm>
#include <cstdio>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>

#include "llvm/ADT/StringRef.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/logging.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_tensor.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_utils.h"
#include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
#include "tensorflow/core/runtime_fallback/runtime/op_logger.h"
#include "tensorflow/core/runtime_fallback/util/attr_util.h"
#include "tensorflow/core/tfrt/fallback/cost_recorder.h"
#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"
#include "tensorflow/core/tfrt/utils/error_util.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tensorflow/core/tfrt/utils/tensor_util.h"
#include "tfrt/core_runtime/execute_op_impl.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/kernel_registry.h"  // from @tf_runtime
#include "tfrt/host_context/sync_kernel_frame.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/pointer_util.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime
#include "tfrt/tensor/tensor.h"  // from @tf_runtime

namespace tensorflow {
namespace tfd {
const char kOpKernelRunnerCacheResourceName[] =
    "OpKernelRunnerCacheResourceName";

namespace {

using ::tensorflow::tfrt_stub::OpKernelRunner;
using ::tensorflow::tfrt_stub::OpKernelRunnerTable;
using ::tensorflow::tfrt_stub::OpKernelRunState;
using ::tfrt::AsyncValue;
using ::tfrt::AsyncValueRef;
using ::tfrt::Chain;
using ::tfrt::OpAttrsRef;
using ::tfrt::RCReference;
using ::tfrt::string_view;

constexpr char kOpKernelRunnerTableResourceName[] =
    "OpKernelRunnerTableResourceName";

constexpr char kFallbackResourceArray[] = "FallbackResourceArray";

void KernelFallbackEmitError(
    const tfrt::ExecutionContext& exec_ctx,
    const KernelFallbackCompatRequestState* fallback_request_state,
    tfrt::string_view op_name, tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
    const tensorflow::Status& status) {
  // Set all results to error, with the correct TFRT error code according to
  // the error propagated from runtime fallback execution.
  auto model_info =
      fallback_request_state == nullptr
          ? "(missing model info) "
          : tfrt::StrCat(
                fallback_request_state->session_metadata().name(), " (",
                fallback_request_state->session_metadata().version(), ") ");
  auto error = EmitErrorAsync(
      exec_ctx,
      tfrt::StrCat(model_info, "error running kernel fallback kernel ", op_name,
                   ": ", status.error_message()),
      tfrt::ConvertTfErrorCodeToTfrtErrorCode(status));
  std::fill(results.begin(), results.end(), error);
  if (op_chain) *op_chain = std::move(error);
}

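// Returns the default runner, which simply executes the work item inline on
// the calling thread. A hypothetical caller (a sketch, not code from this
// file) would use it as:
//
//   (*GetDefaultRunner())([] { /* execute a kernel's work item */ });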
std::function<void(std::function<void()>)>* GetDefaultRunner() {
  static auto* const default_runner =
      new std::function<void(std::function<void()>)>(
          [](const std::function<void()>& f) { f(); });
  return default_runner;
}

}  // namespace

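// Sets up the KernelFallbackCompatRequestState in `builder`'s context data so
// that fallback kernels in this request can reach the TF device manager, the
// process function library runtime, and per-request resources. A minimal
// usage sketch, assuming a RequestContextBuilder has already been created
// (the variable names here are hypothetical):
//
//   TF_RETURN_IF_ERROR(SetUpKernelFallbackCompatRequestContext(
//       &request_context_builder, device_mgr, pflr,
//       /*user_intra_op_threadpool=*/nullptr,
//       /*model_metadata=*/absl::nullopt,
//       /*runner=*/nullptr));  // nullptr selects GetDefaultRunner().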
Status SetUpKernelFallbackCompatRequestContext(
    tfrt::RequestContextBuilder* builder,
    const tensorflow::DeviceMgr* device_manager,
    const tensorflow::ProcessFunctionLibraryRuntime* pflr,
    tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
    const absl::optional<SessionMetadata>& model_metadata,
    std::function<void(std::function<void()>)>* runner) {
  DCHECK(builder);
  DCHECK(device_manager);
  DCHECK(pflr);

  auto* runner_table =
      builder->resource_context()->GetOrCreateResource<OpKernelRunnerTable>(
          kOpKernelRunnerTableResourceName);

  auto* resource_array =
      builder->resource_context()->GetOrCreateResource<FallbackResourceArray>(
          kFallbackResourceArray);

  builder->context_data().emplace<KernelFallbackCompatRequestState>(
      runner ? runner : GetDefaultRunner(), device_manager, builder->id(),
      runner_table, resource_array, user_intra_op_threadpool, model_metadata,
      pflr);

  return OkStatus();
}

Status SetUpKernelFallbackCompatRequestContext(
    tfrt::RequestContextBuilder* builder, OpKernelRunnerTable* runner_table,
    tensorflow::EagerContext* eager_context,
    tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool,
    const absl::optional<SessionMetadata>& model_metadata) {
  auto* resource_array =
      builder->resource_context()->GetOrCreateResource<FallbackResourceArray>(
          kFallbackResourceArray);

  if (runner_table == nullptr)
    runner_table =
        builder->resource_context()->GetOrCreateResource<OpKernelRunnerTable>(
            kOpKernelRunnerTableResourceName);

  auto step_id = builder->id();

  auto& fallback_request_state =
      builder->context_data().emplace<KernelFallbackCompatRequestState>(
          GetDefaultRunner(), eager_context->local_device_mgr(), step_id,
          tfrt::OwnedOrUnownedPtr<ScopedStepContainer>{
              eager_context->StepContainer()},
          eager_context->GetCollectiveExecutorHandle(),
          tensorflow::core::RefCountPtr<tensorflow::Rendezvous>(
              eager_context->RendezvousCreator()(step_id)),
          runner_table, resource_array, user_intra_op_threadpool,
          model_metadata, eager_context->pflr());

  fallback_request_state.set_log_device_placement(
      eager_context->LogDevicePlacement());

  return OkStatus();
}

static llvm::Expected<gtl::InlinedVector<tensorflow::Tensor, 4>>
ConvertInputTensors(llvm::ArrayRef<tfrt::Tensor*> arguments,
                    const tfrt::ExecutionContext& exec_ctx) {
  gtl::InlinedVector<tensorflow::Tensor, 4> input_tf_tensors;
  input_tf_tensors.reserve(arguments.size());
  for (tfrt::Tensor* argument : arguments) {
    auto expected_tf_tensor = TFRTTensorToTFTensor(*argument, exec_ctx.host());
    if (!expected_tf_tensor) {
      return tfrt::MakeStringError(
          tfrt::StrCat(expected_tf_tensor.takeError()));
    }
    input_tf_tensors.push_back(std::move(expected_tf_tensor.get()));
  }

  return input_tf_tensors;
}

static Status ValidateInputTypes(
    tfrt::string_view op_name,
    const gtl::InlinedVector<tensorflow::Tensor, 4>& input_tf_tensors,
    const DataTypeVector& input_types) {
  const size_t n_inputs = input_tf_tensors.size();

  if (input_types.size() != n_inputs) {
    return tensorflow::errors::InvalidArgument("expected ", input_types.size(),
                                               " inputs, got ", n_inputs);
  }

  for (size_t i = 0; i < n_inputs; ++i) {
    if (input_tf_tensors[i].dtype() != input_types[i]) {
      return tensorflow::errors::InvalidArgument(
          "cannot compute ", op_name.str(), " as input #", i, "(zero-based)",
          " was expected to be a ", DataTypeString(input_types[i]),
          " tensor but is a ", DataTypeString(input_tf_tensors[i].dtype()),
          " tensor");
    }
  }

  return OkStatus();
}

namespace {

// Keeps state needed by kernel execution in thread-local storage to avoid
// repeatedly reallocating and destroying it.
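// Note that the thread-local state is reused across kernel invocations on the
// same thread, so the call sites below are responsible for clearing the
// per-invocation data (see the gtl::MakeCleanup calls that clear
// run_state.input_tf_tensors).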
OpKernelRunState& GetThreadLocalOpKernelRunState() {
  thread_local OpKernelRunState run_state;
  return run_state;
}

}  // namespace

// Execute a tensorflow::OpKernel asynchronously. `kernel_runner` and
// `input_tf_tensors` are expected to be alive during the call to this
// function. Set result AsyncValues in `results` and return a Chain that
// indicates the completion of the execution, or an error otherwise.
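// The completion flow below: the kernel's done-callback may fire on whatever
// thread the kernel finishes on; on success the output payloads are then
// populated, and the async values marked available, from a TFRT work queue
// thread via tfrt::EnqueueWork.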
template <typename TensorType>
static void KernelFallbackExecuteCompatAsyncInternal(
    const tfrt::ExecutionContext& exec_ctx, OpKernelRunState* run_state,
    const OpKernelRunner& kernel_runner,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results) {
  auto chain =
      tfrt::MakeUnconstructedAsyncValueRef<tfrt::Chain>(exec_ctx.host());
  if (op_chain) *op_chain = chain.CopyRef();

  // Allocate unconstructed result tensors and set them in the output
  // `results`.
  llvm::SmallVector<AsyncValueRef<TensorType>, 4> result_refs;
  result_refs.reserve(results.size());
  for (auto& result : results) {
    result_refs.emplace_back(
        tfrt::MakeUnconstructedAsyncValueRef<TensorType>(exec_ctx.host()));
    result = result_refs.back().CopyRef();
  }

  struct AsyncState {
    explicit AsyncState(const OpKernelRunState& rs, int num_outputs)
        : run_state(rs.input_tf_tensor_values, rs.params),
          context(&run_state.params, num_outputs) {}

    OpKernelRunState run_state;
    OpKernelContext context;

    tfrt::AsyncValueRef<tfrt::Chain> chain;
    llvm::SmallVector<tfrt::AsyncValueRef<TensorType>, 4> result_refs;
  };

  DCHECK_EQ(results.size(), kernel_runner.op_kernel()->num_outputs());
  auto async_state = std::make_shared<AsyncState>(*run_state, results.size());
  async_state->chain = chain.CopyRef();
  async_state->result_refs = std::move(result_refs);

  auto* context_ptr = &async_state->context;

  auto done_callback = [async_state = std::move(async_state), exec_ctx]() {
    auto& context = async_state->context;

    if (!context.status().ok()) {
      auto diag = tfrt::EmitError(
          exec_ctx,
          {tfrt::StrCat("error running kernel fallback kernel ",
                        context.op_kernel().name(), ": ",
                        context.status().error_message())},
          tfrt::ConvertTfErrorCodeToTfrtErrorCode(context.status()));
      for (auto& result : async_state->result_refs) result.SetError(diag);
      async_state->chain.SetError(diag);
      return;
    }

    // Set payload and mark async values available in TFRT's thread.
    tfrt::EnqueueWork(exec_ctx, [async_state = std::move(async_state)]() {
      auto& context = async_state->context;
      for (int i = 0; i < context.num_outputs(); ++i) {
        async_state->result_refs[i].emplace(
            std::move(*context.mutable_output(i)));
      }
      async_state->chain.emplace();
    });
  };

  kernel_runner.RunAsync(context_ptr, std::move(done_callback));
}

// Execute a tensorflow::OpKernel synchronously. `kernel_runner` and
// `input_tf_tensors` are expected to be alive during the call to this
// function. Set result AsyncValues in `results` and return OK status on
// successfully finishing the execution. TensorType is expected to be
// constructible from tensorflow::Tensor.
template <typename TensorType>
static void KernelFallbackExecuteCompatSyncInternal(
    const tfrt::ExecutionContext& exec_ctx,
    const KernelFallbackCompatRequestState* fallback_request_state,
    OpKernelRunState* run_state, const OpKernelRunner& kernel_runner,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results) {
  DCHECK_EQ(results.size(), kernel_runner.op_kernel()->num_outputs());
  OpKernelContext context(&run_state->params, results.size());
  kernel_runner.Run(&context);

  if (!context.status().ok()) {
    KernelFallbackEmitError(exec_ctx, fallback_request_state,
                            kernel_runner.op_kernel()->name(), op_chain,
                            results, context.status());
    return;
  }

  for (int i = 0; i < context.num_outputs(); ++i) {
    results[i] = tfrt::MakeAvailableAsyncValueRef<TensorType>(
        std::move(*context.mutable_output(i)));
  }

  if (op_chain) *op_chain = tfrt::MakeAvailableAsyncValueRef<tfrt::Chain>();
}

static std::string PrintTfrtOpAttrsToString(const tfrt::OpAttrsRef& attrs) {
  std::string str;
  llvm::raw_string_ostream ss(str);
  attrs.Print(ss);
  ss.flush();
  return str;
}

tfrt::AsyncValueRef<tfrt::Chain> KernelFallbackExecuteCompatCoreRuntimeDispatch(
    const tfrt::ExecutionContext& exec_ctx, tfrt::string_view op_name,
    tfrt::string_view device_name, llvm::ArrayRef<tfrt::Tensor*> arguments,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
    const KernelFallbackCompatRequestState& fallback_request_state,
    const OpKernelRunner& op_kernel_runner) {
  auto op_chain = tfrt::GetReadyChain();
  tensorflow::Status status;

  auto expected_input_tf_tensors = ConvertInputTensors(arguments, exec_ctx);
  if (!expected_input_tf_tensors) {
    status = tensorflow::errors::Internal(
        tfrt::StrCat(expected_input_tf_tensors.takeError()));
    KernelFallbackEmitError(exec_ctx, &fallback_request_state, op_name,
                            &op_chain, results, status);
    return op_chain;
  }

  auto& run_state = GetThreadLocalOpKernelRunState();
  auto clean_up_inputs =
      gtl::MakeCleanup([&]() { run_state.input_tf_tensors.clear(); });

  auto& input_tf_tensors = run_state.input_tf_tensors;
  input_tf_tensors = std::move(expected_input_tf_tensors.get());

  // Check if input tensor dtypes are valid.
  status = ValidateInputTypes(op_name, input_tf_tensors,
                              op_kernel_runner.op_kernel()->input_types());

  // TODO(b/176997538): Skip checking dtypes for tf._BatchFunctionFallback op
  // due to b/176997538. Remove the skipping once the SavedModel lowering
  // problem is fixed.
  if (!status.ok() && !op_name.equals("_BatchFunctionFallback")) {
    KernelFallbackEmitError(exec_ctx, &fallback_request_state, op_name,
                            &op_chain, results, status);
    return op_chain;
  }

  auto& input_tf_tensor_values = run_state.input_tf_tensor_values;
  input_tf_tensor_values.resize(input_tf_tensors.size());
  for (int i = 0; i < input_tf_tensors.size(); ++i) {
    input_tf_tensor_values[i].tensor = &input_tf_tensors[i];
  }

  auto* device =
      GetDeviceFromFallbackState(fallback_request_state, op_kernel_runner);

  SetUpParams(op_kernel_runner, fallback_request_state, device, run_state);

  if (op_kernel_runner.IsAsync()) {
    KernelFallbackExecuteCompatAsyncInternal<KernelFallbackTensor>(
        exec_ctx, &run_state, op_kernel_runner, &op_chain, results);
  } else {
    KernelFallbackExecuteCompatSyncInternal<KernelFallbackTensor>(
        exec_ctx, &fallback_request_state, &run_state, op_kernel_runner,
        &op_chain, results);
  }

  return op_chain;
}

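// Strips the "tf." prefix that op names carry in the BEF attributes, e.g.
// StripTfPrefix("tf.MatMul") yields "MatMul". A name without the prefix is
// returned unchanged, since absl::StripPrefix is a no-op in that case.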
static absl::string_view StripTfPrefix(tfrt::string_view op_name) {
  return absl::StripPrefix(ToAbslStringView(op_name), "tf.");
}

// Generates metadata for an execution op event.
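// The result is in TraceMe's encoded key/value form, roughly (a sketch with
// hypothetical values):
//   id=42,long_name=<debug info>,inputs=float[2,2];int32[];,attributes=...
// Inputs and attributes are only included when detailed TF op tracing is
// enabled (profiler::TfOpDetailsEnabled()).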
std::string GetTracingMetadata(llvm::ArrayRef<tfrt::AsyncValue*> args,
                               const tfrt::ExecutionContext& exec_ctx,
                               const OpKernelRunner& kernel_runner) {
  auto request_id = exec_ctx.request_ctx()->id();
  // Get the long name from the location's debug info.
  auto debug_info = exec_ctx.location().GetDebugInfo();
  auto long_name = debug_info.has_value() ? debug_info.getValue().info : "";

  if (!profiler::TfOpDetailsEnabled()) {
    return profiler::TraceMeEncode(
        {{"id", request_id}, {"long_name", ToAbslStringView(long_name)}});
  }

  // Get the input tensors.
  std::string input_string;
  llvm::raw_string_ostream input_string_stream(input_string);

  for (size_t i = 0; i < args.size(); ++i) {
    const auto& tensor = args[i]->get<Tensor>();
    input_string_stream << DataTypeString(tensor.dtype())
                        << tensor.shape().DebugString();
    input_string_stream << ";";
  }

  // Get the attributes.
  std::string attr_string;
  llvm::raw_string_ostream attr_string_stream(attr_string);

  for (const auto& entry : kernel_runner.op_kernel()->def().attr()) {
    attr_string_stream << entry.first << ": {" << entry.second.DebugString();
    if (!attr_string.empty() && attr_string[attr_string.size() - 1] == '\n') {
      attr_string[attr_string.size() - 1] = '}';
    }
    attr_string_stream << ";\n";
  }

  return profiler::TraceMeEncode({
      {"id", request_id},
      {"long_name", ToAbslStringView(long_name)},
      {"inputs", input_string},
      {"attributes", attr_string},
  });
}

namespace {

class FallbackKernelAttributeFrame {
 public:
  explicit FallbackKernelAttributeFrame(tfrt::AsyncKernelFrame* frame)
      : frame_(frame) {
    DCHECK(frame_);
  }

  tfrt::StringAttr device() const {
    return tfrt::StringAttr(frame_->GetAttribute(kDeviceAttrPosition));
  }

  tfrt::AggregateAttr op_attr() const {
    return tfrt::AggregateAttr(frame_->GetAttribute(kOpAttrPosition));
  }

  tfrt::AggregateAttr op_func_attr() const {
    return tfrt::AggregateAttr(frame_->GetAttribute(kOpFuncAttrPosition));
  }

  tfrt::I64Attr op_key() const {
    return tfrt::I64Attr(frame_->GetAttribute(kOpKeyAttrPosition));
  }

  tfrt::StringAttr op_name() const {
    return tfrt::StringAttr(frame_->GetAttribute(kOpNameAttrPosition));
  }

 private:
  static constexpr int kDeviceAttrPosition = 0;
  static constexpr int kOpAttrPosition = 1;
  static constexpr int kOpFuncAttrPosition = 2;
  static constexpr int kOpKeyAttrPosition = 3;
  static constexpr int kOpNameAttrPosition = 4;

  tfrt::AsyncKernelFrame* frame_ = nullptr;
};

// The BEF kernel for kernel fallback compat mode. The arguments and results
// are expected to be tensorflow::tfrt_stub::FallbackTensor.
TF_ATTRIBUTE_ALWAYS_INLINE static void KernelFallbackExecuteOpInternal(
    llvm::ArrayRef<tfrt::AsyncValue*> args,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    const FallbackKernelAttributeFrame& frame,
    const tfrt::ExecutionContext& exec_ctx,
    const KernelFallbackCompatRequestState& fallback_request_state,
    const OpKernelRunner& kernel_runner, bool is_async,
    tensorflow::Device* device) {
  tensorflow::profiler::TraceMe trace_me([&]() -> std::string {
    if (kernel_runner.op_kernel()) {
      return tensorflow::profiler::TraceMeOp(
          kernel_runner.op_kernel()->name_view(),
          kernel_runner.op_kernel()->type_string_view());
    }
    return std::string(ToAbslStringView(frame.op_name().GetValue()));
  });

  trace_me.AppendMetadata(
      [&]() { return GetTracingMetadata(args, exec_ctx, kernel_runner); });

  if (fallback_request_state.log_device_placement() || VLOG_IS_ON(1)) {
    string msg =
        strings::StrCat("Executing op ", frame.op_name().GetValue().str(),
                        " in device ", frame.device().GetValue().str());
    if (!logging::LogToListeners(msg)) {
      LOG(INFO) << msg;
    }
  }

  auto& run_state = GetThreadLocalOpKernelRunState();
  auto clean_up_inputs =
      gtl::MakeCleanup([&]() { run_state.input_tf_tensors.clear(); });

  // Prepare the input tensors.
  auto& input_tf_tensors = run_state.input_tf_tensors;
  auto& input_tf_tensor_values = run_state.input_tf_tensor_values;
  DCHECK(input_tf_tensors.empty());
  input_tf_tensor_values.resize(args.size());
  for (int i = 0; i < args.size(); ++i) {
    auto* arg = args[i];
    auto& fallback_tensor = arg->get<tensorflow::tfrt_stub::FallbackTensor>();
    // If the argument is immutable or unique, we can just keep the reference
    // without copying, which involves expensive atomic reference counting. And
    // if the argument is unique but mutable, then tensorflow optimizations
    // like buffer forwarding can be utilized. Otherwise, we conservatively
    // copy the tensor.
    if (!fallback_tensor.is_immutable() && !arg->IsUnique()) {
      input_tf_tensors.push_back(fallback_tensor.tensor());
    }
    input_tf_tensor_values[i].tensor = &fallback_tensor.tensor();
  }

  SetUpParams(kernel_runner, fallback_request_state, device, run_state);

  bool is_cost_measurement_enabled =
      exec_ctx.request_ctx()->IsCostMeasurementEnabled();
  auto run_start_time =
      is_cost_measurement_enabled ? Env::Default()->NowMicros() : 0;
  if (is_async) {
    KernelFallbackExecuteCompatAsyncInternal<
        tensorflow::tfrt_stub::FallbackTensor>(
        exec_ctx, &run_state, kernel_runner, op_chain, results);
  } else {
    KernelFallbackExecuteCompatSyncInternal<
        tensorflow::tfrt_stub::FallbackTensor>(
        exec_ctx, &fallback_request_state, &run_state, kernel_runner, op_chain,
        results);
  }
  if (is_cost_measurement_enabled) {
    op_chain->AndThen([run_start_time, exec_ctx, frame] {
      // Adds 1 to make sure it's a positive integer.
      auto execution_time = Env::Default()->NowMicros() - run_start_time + 1;
      // The cost is keyed by op_key, which distinguishes instances of the
      // same operation with different shapes.
      exec_ctx.host()
          ->GetOrCreateSharedContext<tensorflow::tfrt_stub::CostRecorder>()
          .RecordCost(frame.op_key().GetValue(), execution_time);
    });
  }
}

TF_ATTRIBUTE_ALWAYS_INLINE static void KernelFallbackExecuteOp(
    llvm::ArrayRef<tfrt::AsyncValue*> args,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    const FallbackKernelAttributeFrame& frame,
    const tfrt::ExecutionContext& exec_ctx) {
  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    KernelFallbackEmitError(
        exec_ctx, /*fallback_request_state=*/nullptr,
        frame.op_name().GetValue(), op_chain, results,
        tensorflow::errors::NotFound(
            "KernelFallbackCompatRequestState not found in RequestContext."));
    return;
  }

  auto* runner_table = fallback_request_state->runner_table();
  DCHECK(runner_table);

  auto* kernel_runner = runner_table->Get(frame.op_key().GetValue());
  DCHECK(kernel_runner);
  DCHECK_EQ(kernel_runner->op_kernel()->name(),
            StripTfPrefix(frame.op_name().GetValue()));

  auto* device =
      GetDeviceFromFallbackState(*fallback_request_state, *kernel_runner);

  KernelFallbackExecuteOpInternal(args, results, op_chain, frame, exec_ctx,
                                  *fallback_request_state, *kernel_runner,
                                  kernel_runner->IsAsync(), device);
}

// The BEF kernel for creating tensorflow::OpKernel to be used in kernel
// fallback compat mode.
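// For illustration, its textual form looks roughly like the following (a
// sketch only; the exact syntax is defined by the tfrt_fallback_async
// dialect, and the key/device/attribute values here are hypothetical):
//
//   %ch1 = tfrt_fallback_async.createop(%ch0) key(0) device("/device:CPU:0")
//          "tf.AddV2"() {T = f32} num_args(2)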
tfrt::AsyncValueRef<tfrt::Chain> KernelFallbackCreateOp(
    const tfrt::Chain& in_ch, tfrt::StringAttr device, tfrt::I64Attr num_args,
    tfrt::AggregateAttr op_attr_array, tfrt::AggregateAttr op_func_attr_array,
    tfrt::I64Attr op_key, tfrt::StringAttr op_name_attr,
    const tfrt::ExecutionContext& exec_ctx) {
  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    return tfrt::EmitErrorAsync(
        exec_ctx,
        "KernelFallbackCompatRequestState not found in RequestContext.");
  }

  auto* runner_table = fallback_request_state->runner_table();
  DCHECK(runner_table);

  auto attr_builder = [op_attr_array, op_func_attr_array](
                          tensorflow::AttrValueMap* attr_value_map) {
    return SetUpAttrValueMap(op_attr_array, op_func_attr_array, attr_value_map);
  };

  auto op_name = StripTfPrefix(op_name_attr.GetValue());

  auto statusor_runner = OpKernelRunner::Create(
      op_name, ToAbslStringView(device.GetValue()), num_args.GetValue(),
      attr_builder, fallback_request_state->device_manager(),
      fallback_request_state->process_function_library_runtime());
  if (!statusor_runner.ok())
    return tfrt::EmitErrorAsync(
        exec_ctx, statusor_runner.status().error_message(),
        tfrt::ConvertTfErrorCodeToTfrtErrorCode(statusor_runner.status()));

  if (!runner_table->Insert(op_key.GetValue(),
                            std::move(statusor_runner).ValueOrDie())) {
    return tfrt::EmitErrorAsync(
        exec_ctx,
        absl::StrCat("KernelFallbackCreateOp: OpKernelRunner already exists: ",
                     op_name_attr.GetValue().str()));
  }

  return tfrt::MakeAvailableAsyncValueRef<tfrt::Chain>();
}

// FallbackSetResource is the fallback kernel that sets the tensor value in the
// fallback's resource array.
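// A rough sketch of its textual form (hypothetical values; see the
// tfrt_fallback_async dialect definition for the exact syntax):
//
//   %ch1 = tfrt_fallback_async.set_resource(%ch0, %tensor)
//          {device = "/device:CPU:0", index = 0 : i64}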
llvm::Expected<tfrt::Chain> FallbackSetResource(
    tfrt::Argument<tfrt::Chain> in_ch,
    tfrt::Argument<tensorflow::tfrt_stub::FallbackTensor> arg,
    tfrt::StringAttr device, tfrt::I64Attr index_attr,
    const tfrt::ExecutionContext& exec_ctx) {
  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    return tfrt::MakeStringError(
        "KernelFallbackCompatRequestState not found in RequestContext.");
  }

  auto* resource_array = fallback_request_state->resource_array();
  DCHECK(resource_array);

  int64_t index = index_attr.GetValue();

  // Make the resource tensor immutable, so that we don't need reference
  // counting on it and it cannot be buffer-forwarded.
  resource_array->SetResource(
      index,
      tensorflow::tfrt_stub::ImmutableTensor::Create(arg.get().tensor()));

  return tfrt::Chain();
}

// FallbackGetResource is the fallback kernel that retrieves the tensor value
// in the fallback's resource array.
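// A rough sketch of its textual form (hypothetical values; see the
// tfrt_fallback_async dialect definition for the exact syntax):
//
//   %ch1, %t0, %t1 = tfrt_fallback_async.get_resource(%ch0)
//                    {device = "/device:CPU:0", indices = [0, 1]}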
void FallbackGetResource(tfrt::Argument<tfrt::Chain> in_ch,
                         tfrt::Result<tfrt::Chain> out_ch,
                         tfrt::RemainingResults results,
                         tfrt::StringAttr device, tfrt::ArrayAttr indices_attr,
                         const tfrt::ExecutionContext& exec_ctx) {
  tensorflow::profiler::TraceMe trace_me("tfrt_fallback_async.get_resource");
  trace_me.AppendMetadata([request_id = exec_ctx.request_ctx()->id()]() {
    return tensorflow::profiler::TraceMeEncode({{"id", request_id}});
  });

  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    tfrt::RCReference<tfrt::AsyncValue> error = tfrt::EmitErrorAsync(
        exec_ctx,
        "KernelFallbackCompatRequestState not found in RequestContext.");
    out_ch.Set(std::move(error));
    return;
  }

  auto* resource_array = fallback_request_state->resource_array();
  DCHECK(resource_array);

  llvm::ArrayRef<int64_t> indices = indices_attr.GetValue<int64_t>();

  for (int i = 0; i < indices.size(); ++i) {
    results[i] = tfrt::FormRef(resource_array->GetResource(indices[i]));
  }

  out_ch.Set(in_ch);
}

// The implementation of the tfrt_fallback_async.executeop kernel. It executes
// a non-side-effecting TF op with the name of `op_name` in fallback. All
// relevant TF attributes are passed in `op_attr_array`.
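// For illustration, an invocation looks roughly like the following (a sketch
// only; key/device/attribute values are hypothetical, and the exact syntax is
// defined by the tfrt_fallback_async dialect):
//
//   %out = tfrt_fallback_async.executeop key(0) device("/device:CPU:0")
//          "tf.AddV2"(%a, %b) {T = f32} : 1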
void FallbackAsyncExecuteOp(tfrt::AsyncKernelFrame* frame) {
  FallbackKernelAttributeFrame attr_frame(frame);
#ifndef NDEBUG
  frame->GetExecutionContext()
      .host()
      ->GetOrCreateSharedContext<OpLogger>()
      .LogOp(attr_frame.op_name().GetValue());
#endif
  // Create op_chain only when cost measurement is enabled. It is used for
  // measuring the async op's actual latency.
  if (frame->GetExecutionContext().request_ctx()->IsCostMeasurementEnabled()) {
    auto op_chain = tfrt::MakeUnconstructedAsyncValueRef<tfrt::Chain>();
    KernelFallbackExecuteOp(frame->GetArguments(), frame->GetResults(),
                            &op_chain, attr_frame,
                            frame->GetExecutionContext());
  } else {
    KernelFallbackExecuteOp(frame->GetArguments(), frame->GetResults(),
                            /*op_chain=*/nullptr, attr_frame,
                            frame->GetExecutionContext());
  }
}

// The implementation of the tfrt_fallback_async.executeop.seq kernel. It
// executes a side-effecting TF op with the name of `op_name` in fallback. All
// relevant TF attributes are passed in `op_attr_array`. `in_op_chain` and
// `out_op_chain` are used for side-effect visibility.
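// Roughly, under the same hedged sketch conventions as above, with the chain
// threaded through:
//
//   %out_ch, %out = tfrt_fallback_async.executeop.seq(%in_ch) key(0)
//                   device("/device:CPU:0") "tf.AddV2"(%a, %b) {T = f32} : 1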
void FallbackAsyncExecuteOpSeq(tfrt::AsyncKernelFrame* frame) {
  auto all_args = frame->GetArguments();
  DCHECK_GT(all_args.size(), 0);
  tfrt::AsyncValueRef<tfrt::Chain> op_chain(tfrt::FormRef(all_args[0]));
  llvm::ArrayRef<tfrt::AsyncValue*> args = all_args.drop_front();

  auto all_results = frame->GetResults();
  DCHECK_GT(all_results.size(), 0);
  auto& out_op_chain = all_results[0];
  llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results =
      all_results.drop_front();

  KernelFallbackExecuteOp(args, results, &op_chain,
                          FallbackKernelAttributeFrame(frame),
                          frame->GetExecutionContext());
  out_op_chain = std::move(op_chain);
}

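// A Device wrapper that forwards all calls to the wrapped device, except for
// GetAllocator, which returns the custom allocator supplied at construction.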
class DeviceWithCustomAllocator : public tensorflow::Device {
 public:
  DeviceWithCustomAllocator(tensorflow::Device* device,
                            tensorflow::Allocator* allocator)
      : Device(device->env(), device->attributes()),
        device_(device),
        allocator_(allocator) {
    DCHECK(device_);
    DCHECK(allocator_);
  }

  Allocator* GetAllocator(AllocatorAttributes attr) override {
    return allocator_;
  }

  const DeviceBase* UnderlyingDevice() const override {
    return device_->UnderlyingDevice();
  }
  DeviceBase* UnderlyingDevice() override {
    return device_->UnderlyingDevice();
  }

  const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override {
    return device_->tensorflow_cpu_worker_threads();
  }

  Allocator* GetScopedAllocator(AllocatorAttributes attr,
                                int64_t step_id) override {
    return device_->GetScopedAllocator(attr, step_id);
  }

  ScopedAllocatorMgr* GetScopedAllocatorMgr() const override {
    return device_->GetScopedAllocatorMgr();
  }

  const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
    return device_->eigen_cpu_device();
  }

  thread::ThreadPool* tensorflow_device_thread_pool() override {
    return device_->tensorflow_device_thread_pool();
  }

  bool has_eigen_cpu_device() const override {
    return device_->has_eigen_cpu_device();
  }

  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor) override {
    return device_->MakeTensorFromProto(tensor_proto, alloc_attrs, tensor);
  }

  void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor,
                              const DeviceContext* device_context,
                              StatusCallback done) override {
    device_->CopyTensorInSameDevice(input_tensor, output_tensor, device_context,
                                    std::move(done));
  }

  Status Sync() override { return device_->Sync(); }

  // Returns the resource manager associated with this device.
  ResourceMgr* resource_manager() override {
    return device_->resource_manager();
  }

 private:
  tensorflow::Device* device_ = nullptr;
  tensorflow::Allocator* allocator_ = nullptr;
};

void KernelFallbackExecuteOpCustomAllocatorInternal(
    llvm::ArrayRef<tfrt::AsyncValue*> args,
    llvm::MutableArrayRef<tfrt::RCReference<tfrt::AsyncValue>> results,
    tfrt::AsyncValueRef<tfrt::Chain>* op_chain,
    const tfrt::ExecutionContext& exec_ctx,
    const FallbackKernelAttributeFrame& attr_frame) {
  const auto* fallback_request_state =
      exec_ctx.request_ctx()
          ->GetDataIfExists<KernelFallbackCompatRequestState>();
  if (!fallback_request_state) {
    KernelFallbackEmitError(
        exec_ctx, /*fallback_request_state=*/nullptr,
        attr_frame.op_name().GetValue(), op_chain, results,
        tensorflow::errors::NotFound(
            "KernelFallbackCompatRequestState not found in RequestContext."));
    return;
  }

  auto* runner_table = fallback_request_state->runner_table();
  DCHECK(runner_table);

  auto* kernel_runner = runner_table->Get(attr_frame.op_key().GetValue());
  DCHECK(kernel_runner);
  DCHECK_EQ(kernel_runner->op_kernel()->name(),
            StripTfPrefix(attr_frame.op_name().GetValue()));

  DCHECK_GT(args.size(), 0);
  auto* allocator = args.front()->get<tensorflow::Allocator*>();
  args = args.drop_front();

  DeviceWithCustomAllocator device_with_custom_allocator(
      GetDeviceFromFallbackState(*fallback_request_state, *kernel_runner),
      allocator);

  // Unlike FallbackAsyncExecuteOp, async execution is not allowed here,
  // because the lifetime of the wrapper device cannot be extended.
  //
  // TODO(b/200575143): Consider allowing async execution and extending the
  // lifetime of the wrapping device.
  KernelFallbackExecuteOpInternal(args, results,
                                  /*op_chain=*/op_chain, attr_frame, exec_ctx,
                                  *fallback_request_state, *kernel_runner,
                                  /*is_async=*/false,
                                  &device_with_custom_allocator);
}

void FallbackAsyncExecuteOpWithAllocator(tfrt::AsyncKernelFrame* frame) {
  auto args = frame->GetArguments();
  auto results = frame->GetResults();
  FallbackKernelAttributeFrame attr_frame(frame);
  KernelFallbackExecuteOpCustomAllocatorInternal(
      args, results, /*op_chain=*/nullptr, frame->GetExecutionContext(),
      attr_frame);
}

void FallbackAsyncExecuteOpSeqWithAllocator(tfrt::AsyncKernelFrame* frame) {
  auto args = frame->GetArguments();
  DCHECK_GT(args.size(), 0);
  tfrt::AsyncValueRef<tfrt::Chain> op_chain(tfrt::FormRef(args.front()));
  args = args.drop_front();

  auto results = frame->GetResults();
  DCHECK_GT(results.size(), 0);
  auto& out_op_chain = results.front();
  results = results.drop_front();

  FallbackKernelAttributeFrame attr_frame(frame);
  KernelFallbackExecuteOpCustomAllocatorInternal(
      args, results, &op_chain, frame->GetExecutionContext(), attr_frame);

  out_op_chain = std::move(op_chain);
}

void FallbackCopyTensorIfSmall(
    tfrt::Argument<tensorflow::tfrt_stub::FallbackTensor> arg,
    tfrt::RemainingResults results) {
  const auto& fallback_tensor = arg.get();
  const auto& tensor = fallback_tensor.tensor();

  if (!fallback_tensor.is_immutable()) {
    // Create a new TensorBuffer, which contains a new atomic counter for each
    // result, to avoid downstream threads contending for the original atomic
    // counter.
    for (int i = 0; i < results.size(); ++i) {
      auto immutable_tensor =
          tensorflow::tfrt_stub::ImmutableTensor::Create(tensor);
      results[i] = tfrt::MakeAvailableAsyncValueRef<
          tensorflow::tfrt_stub::FallbackTensor>(
          std::move(immutable_tensor.tensor()));
    }
  } else {
    // For immutable tensors, we just need to copy the pointer. Note that we
    // still create a new AsyncValue in this case, to avoid atomic contention
    // on the AsyncValue's refcount.
    for (int i = 0; i < results.size(); ++i) {
      results[i] = tfrt::MakeAvailableAsyncValueRef<
          tensorflow::tfrt_stub::FallbackTensor>(fallback_tensor);
    }
  }
}

llvm::Expected<tensorflow::tfrt_stub::FallbackTensor> ConstTensorProto(
    tfrt::StringAttr serialized_tensor_proto) {
  tensorflow::TensorProto tensor_proto;
  if (!tensor_proto.ParseFromString(serialized_tensor_proto.GetValue().str())) {
    return tfrt::MakeStringError("Failed to parse const tensor proto");
  }

  tensorflow::Tensor tensor;
  if (!tensor.FromProto(tensor_proto)) {
    return tfrt::MakeStringError("Failed to create tensor from tensor proto: ",
                                 tensor_proto.ShortDebugString());
  }

  return tensorflow::tfrt_stub::FallbackTensor(std::move(tensor));
}

class TestAllocator : public tensorflow::AllocatorWrapper {
 public:
  TestAllocator() : tensorflow::AllocatorWrapper(tensorflow::cpu_allocator()) {}

  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
    std::printf("Using TestAllocator\n");
    fflush(stdout);
    return wrapped()->AllocateRaw(alignment, num_bytes);
  }

  void* AllocateRaw(size_t alignment, size_t num_bytes,
                    const AllocationAttributes& allocation_attr) override {
    std::printf("Using TestAllocator\n");
    fflush(stdout);
    return wrapped()->AllocateRaw(alignment, num_bytes, allocation_attr);
  }
};

tensorflow::Allocator* GetTestAllocator() {
  static auto* const test_allocator = new TestAllocator;
  return test_allocator;
}

void RegisterKernelFallbackCompatKernels(tfrt::KernelRegistry* registry) {
  registry->AddKernel("tfrt_fallback_async.const_tensor_proto",
                      TFRT_KERNEL(ConstTensorProto));
  registry->AddKernel("tfrt_fallback_async.executeop", FallbackAsyncExecuteOp);
  registry->AddKernel("tfrt_fallback_async.executeop.seq",
                      FallbackAsyncExecuteOpSeq);
  registry->AddKernel("tfrt_fallback_async.executeop.allocator",
                      FallbackAsyncExecuteOpWithAllocator);
  registry->AddKernel("tfrt_fallback_async.executeop.seq.allocator",
                      FallbackAsyncExecuteOpSeqWithAllocator);
  registry->AddKernel("tfrt_fallback_async.copy_if_small",
                      TFRT_KERNEL(FallbackCopyTensorIfSmall));
  registry->AddKernel("tfrt_fallback_async.createop",
                      TFRT_KERNEL(KernelFallbackCreateOp));
  registry->AddKernel("tfrt_fallback_async.set_resource",
                      TFRT_KERNEL(FallbackSetResource));
  registry->AddKernel("tfrt_fallback_async.get_resource",
                      TFRT_KERNEL(FallbackGetResource));

  // TODO(chky): Move test kernels to test-only library.
  registry->AddKernel("tfrt_fallback_async.get_test_allocator",
                      TFRT_KERNEL(GetTestAllocator));
}

TFRT_STATIC_KERNEL_REGISTRATION(RegisterKernelFallbackCompatKernels);

}  // namespace
}  // namespace tfd
}  // namespace tensorflow