/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_utils.h"

#include "absl/container/inlined_vector.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h"
#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"

namespace tensorflow {
namespace tfd {

void SetUpParams(const tfrt_stub::OpKernelRunner& runner,
                 const KernelFallbackCompatRequestState& fallback_request_state,
                 tensorflow::Device* device,
                 tfrt_stub::OpKernelRunState& run_state) {
  auto& params = run_state.params;
  params.inputs = run_state.input_tf_tensor_values;
  params.device = device;
  params.op_kernel = runner.op_kernel();
  // Still use the original device's resource_manager.
  params.resource_manager = runner.resource_manager();
  params.input_alloc_attrs = runner.input_alloc_attrs();
  params.output_attr_array = runner.output_alloc_attrs().data();
  params.step_container = fallback_request_state.step_container();
  // The following two parameters are used to support executing tf.data via
  // fallback.
  params.function_library = runner.function_library_runtime();
  params.runner = fallback_request_state.runner();
  params.collective_executor = fallback_request_state.collective_executor();
  params.rendezvous = fallback_request_state.rendezvous();
  params.session_metadata = &fallback_request_state.session_metadata();
  params.cancellation_manager = fallback_request_state.cancellation_manager();
}
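
// A minimal usage sketch (illustrative only, not part of this file). It
// assumes the caller has already populated run_state.input_tf_tensor_values
// and that tfrt_stub::OpKernelRunner::Run accepts an OpKernelContext*. After
// SetUpParams fills in run_state.params, a caller would typically build an
// OpKernelContext from those params and execute the kernel through the runner:
//
//   SetUpParams(runner, fallback_request_state, device, run_state);
//   tensorflow::OpKernelContext context(&run_state.params,
//                                       runner.op_kernel()->num_outputs());
//   runner.Run(&context);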

// Return the device to be used for the fallback kernel execution. The device
// is guaranteed to be alive during the graph execution.
tensorflow::Device* GetDeviceFromFallbackState(
    const KernelFallbackCompatRequestState& fallback_request_state,
    const tfrt_stub::OpKernelRunner& kernel_runner) {
  // Return the user-specified custom device instead (e.g. to use a custom
  // thread pool).
  //
  // The device handling is similar to the TF1 code at the link below:
  // http://cs/?q=f:common_runtime%2Fexecutor.cc:692%20package:piper&rcl=351575626
  auto* device = kernel_runner.device();
  if (auto* custom_device = fallback_request_state.custom_device(device)) {
    return custom_device;
  }
  return device;
}
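
// A minimal usage sketch (illustrative only): callers typically resolve the
// device before populating the params, so that a user-provided custom device
// (e.g. one backed by a custom thread pool) takes precedence over the
// kernel's assigned device:
//
//   tensorflow::Device* device =
//       GetDeviceFromFallbackState(fallback_request_state, kernel_runner);
//   SetUpParams(kernel_runner, fallback_request_state, device, run_state);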

}  // namespace tfd
}  // namespace tensorflow