xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/eager/execute_node.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/core/common_runtime/eager/execute_node.h"
16 
17 #include "tensorflow/core/lib/core/errors.h"
18 
19 namespace tensorflow {
20 
21 #if !defined(IS_MOBILE_PLATFORM)
IsRemote(EagerContext * ctx,Device * input_device,TensorHandle * handle)22 bool ExecuteNodeArgs::IsRemote(EagerContext* ctx, Device* input_device,
23                                TensorHandle* handle) {
24   uint64 context_view_id = ctx->GetContextViewId();
25   if (handle->Type() == TensorHandle::REMOTE ||
26       handle->HasRemoteMirror(input_device, context_view_id)) {
27     if (!has_remote_inputs_) {
28       has_remote_inputs_ = true;
29     }
30     return true;
31   }
32   return false;
33 }
34 #endif  // IS_MOBILE_PLATFORM
35 
InitPackedHandle(const int index,EagerContext * ctx,Device * input_device,TensorHandle * packed_handle)36 Status ExecuteNodeArgs::InitPackedHandle(const int index, EagerContext* ctx,
37                                          Device* input_device,
38                                          TensorHandle* packed_handle) {
39   int num_handles = packed_handle->NumPackedHandles();
40   packed_args_.emplace(index, gtl::InlinedVector<TensorValue, 4>(num_handles));
41   TensorValue* packed_arg_flat = &(packed_args_[index][0]);
42   for (int i = 0; i < num_handles; ++i) {
43     TensorHandle* h = nullptr;
44     TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
45     // We have validated that h->device() is not a CustomDevice when
46     // constructing a pack TensorHandle.
47     const Status status = h->TensorValue(h->device(), &packed_arg_flat[i]);
48     if (!status.ok()) {
49 #if !defined(IS_MOBILE_PLATFORM)
50       if (IsRemote(ctx, input_device, h)) {
51         continue;
52       }
53 #endif  // IS_MOBILE_PLATFORM
54       if (h->Type() == TensorHandle::PACKED) {
55         return errors::InvalidArgument(
56             "Nested packed handles are not supported");
57       }
58       return status;
59     }
60   }
61   return OkStatus();
62 }
63 
// Resolves every op input into the TensorValue the kernel will consume.
//
// For each input handle, asks for its tensor on the kernel's canonical input
// device. Inputs that cannot be resolved locally are either remote (tracked
// by IsRemote, non-mobile builds only) or packed (expanded component-wise by
// InitPackedHandle); any other failure is returned to the caller.
//
// When at least one remote input was seen, installs serialize_remote_handle_,
// a callback used later to serialize a remote input (or one component of a
// packed input, selected by FunctionArgIndex::sub_index) into a
// RemoteTensorHandle proto.
Status ExecuteNodeArgs::Init(
    EagerContext* ctx, const gtl::InlinedVector<TensorHandle*, 4>& op_inputs,
    const core::RefCountPtr<KernelAndDevice>& kernel) {
  // If there are multiple references to a TensorHandle in 'op_inputs' we must
  // increment the reference count of the corresponding Tensor or risk it being
  // overwritten during kernel execution. The reference count is incremented
  // below when we insert a copy of the Tensor into protected_tensors, and will
  // be decremented once execution is complete.
  const int n_inputs = op_inputs.size();
  if (n_inputs > 0) {
    TensorHandle* const* op_inputs_flat = &op_inputs[0];
    // tensor_args_ is the base-class storage; assumed pre-sized to n_inputs
    // by the caller/constructor — TODO confirm.
    TensorValue* tensor_args_flat = &tensor_args_[0];
    for (int i = 0; i < n_inputs; ++i) {
      TensorHandle* in = op_inputs_flat[i];
      Device* d = kernel->InputDevice(i);
      Status s = in->TensorValue(ctx->CanonicalDevice(d), &tensor_args_flat[i]);
      if (!s.ok()) {
#if !defined(IS_MOBILE_PLATFORM)
        // Remote input: no local TensorValue; tracked for serialization below.
        if (IsRemote(ctx, d, in)) {
          continue;
        }
#endif
        // Only packed handles are recoverable from here; anything else is a
        // real resolution failure.
        if (in->Type() != TensorHandle::PACKED) {
          return s;
        }
        if (!has_packed_inputs_) {
          has_packed_inputs_ = true;
        }
        TF_RETURN_IF_ERROR(InitPackedHandle(i, ctx, d, in));
      }
    }
  }

#if !defined(IS_MOBILE_PLATFORM)
  if (has_remote_inputs_) {
    const bool is_function = kernel->IsFunction();
    // NOTE(review): this lambda captures op_inputs by reference and is stored
    // in serialize_remote_handle_; it must not be invoked after the caller's
    // input vector is destroyed — TODO confirm the owning node keeps the
    // inputs alive for the callback's lifetime.
    serialize_remote_handle_ =
        [ctx, &op_inputs, is_function](
            const FunctionArgIndex& index,
            eager::RemoteTensorHandle* handle) -> Status {
      TensorHandle* h = op_inputs[index.index];
      if (op_inputs[index.index]->Type() == TensorHandle::PACKED) {
        // For a packed input, serialize the component named by sub_index.
        TF_RETURN_IF_ERROR(
            op_inputs[index.index]->ExtractPackedHandle(index.sub_index, &h));
      }
      Device* device = h->device();
      // For a multi-device function, a remote RunComponentFunction request is
      // not sent through StreamingEnqueueAsync. It could arrive at a remote
      // worker before a remote execution request which produces an input of the
      // component function. So we wait until the remote input is ready before
      // serializing it.
      // (Variable name "wait_util_ready" is a pre-existing typo for
      // "wait_until_ready"; kept as-is in this doc-only pass.)
      const bool wait_util_ready = is_function;
      return ctx->RemoteMgr()->SerializeRemoteTensorHandle(
          h, wait_util_ready, handle, device, device->name());
    };
  }
#endif  // !IS_MOBILE_PLATFORM
  return OkStatus();
}
123 
GetLocalArg(const FunctionArgIndex & index,Tensor * val) const124 Status ExecuteNodeArgs::GetLocalArg(const FunctionArgIndex& index,
125                                     Tensor* val) const {
126   Status s = EagerKernelArgs::GetLocalArg(index, val);
127   if (s.ok()) {
128     return OkStatus();
129   }
130   if (packed_args_.contains(index.index)) {
131     Tensor* arg = packed_args_.at(index.index).at(index.sub_index).tensor;
132     if (arg) {
133       *val = *arg;
134       return OkStatus();
135     } else {
136       return errors::NotFound("Argument (", index.index, ",", index.sub_index,
137                               ") has no local tensor.");
138     }
139   } else {
140     return s;
141   }
142 }
143 
144 }  // namespace tensorflow
145