// xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/training_op_helpers.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
#define TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/refcount.h"

namespace tensorflow {

28 // Must be called before performing a sparse operation on a variable. Ensures
29 // that no concurrent dense operations can happen while holding the variable's
30 // lock.
31 template <typename Device, typename T>
EnsureSparseVariableAccess(OpKernelContext * ctx,Var * var)32 Status EnsureSparseVariableAccess(OpKernelContext* ctx, Var* var) {
33   if (var->copy_on_read_mode.load()) {
34     return OkStatus();
35   }
36   mutex_lock ml(*var->mu());
37   // Once copy-on-read mode is True the refcount is guaranteed to be 1. This can
38   // also happen if there are no concurrent reads of the variable and
39   // copy-on-read mode is false.
40   if (var->tensor()->RefCountIsOne()) {
41     var->copy_on_read_mode.store(true);
42     return OkStatus();
43   }
44   Tensor tmp;
45   if (std::is_same<T, Variant>::value) {
46     AllocatorAttributes attr;
47     attr.set_on_host(true);
48     TF_RETURN_IF_ERROR(ctx->allocate_temp(var->tensor()->dtype(),
49                                           var->tensor()->shape(), &tmp, attr));
50 
51     const auto elements_in = var->tensor()->flat<Variant>();
52     auto elements_out = tmp.flat<Variant>();
53     for (int64_t i = 0; i < elements_in.size(); ++i) {
54       elements_out(i) = elements_in(i);
55     }
56   } else {
57     AllocatorAttributes attr;
58     attr.set_gpu_compatible(true);
59     attr.set_nic_compatible(true);
60     TF_RETURN_IF_ERROR(ctx->allocate_temp(var->tensor()->dtype(),
61                                           var->tensor()->shape(), &tmp, attr));
62     functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
63     copy_functor(ctx->eigen_device<Device>(), tmp.flat<T>(),
64                  const_cast<const Tensor*>(var->tensor())->flat<T>());
65   }
66   *var->tensor() = tmp;
67   var->copy_on_read_mode.store(true);
68   return OkStatus();
69 }
70 
71 // Utility structure that releases a sequence of borrowed mutexes when it is
72 // deleted.
73 struct VariableInputLockHolder {
74  public:
VariableInputLockHolderVariableInputLockHolder75   VariableInputLockHolder(
76       std::vector<Var*> vars, std::unique_ptr<std::vector<mutex_lock>> locks,
77       std::unique_ptr<std::vector<tf_shared_lock>> shared_locks)
78       : vars_(std::move(vars)),
79         locks_(std::move(locks)),
80         shared_locks_(std::move(shared_locks)) {}
81 
VariableInputLockHolderVariableInputLockHolder82   VariableInputLockHolder(VariableInputLockHolder&& other)
83       : vars_(std::move(other.vars_)),
84         locks_(std::move(other.locks_)),
85         shared_locks_(std::move(other.shared_locks_)) {}
86 
~VariableInputLockHolderVariableInputLockHolder87   ~VariableInputLockHolder() {
88     // Release the locks before unreffing the Vars, because each lock
89     // is potentially borrowed from a Var in vars_.
90     locks_.reset();
91     for (Var* var : vars_) {
92       var->Unref();
93     }
94   }
95 
96  private:
97   std::vector<Var*> vars_;
98   // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly,
99   // because a `std::vector<mutex_lock>` is not movable on all platforms.
100   std::unique_ptr<std::vector<mutex_lock>> locks_;
101   std::unique_ptr<std::vector<tf_shared_lock>> shared_locks_;
102 };
103 
104 // Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
105 //
106 // If `input` corresponds to a `DT_RESOURCE`-type variable input,
107 // `*maybe_resource` will be updated to contain the underlying resource, and the
108 // caller will be responsible for calling `Unref()` on that resource.
109 template <typename Device, typename T>
GetTrainingVariableMutex(OpKernelContext * ctx,int input,bool sparse,Var ** maybe_resource)110 mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, bool sparse,
111                                 Var** maybe_resource) {
112   *maybe_resource = nullptr;
113   if (ctx->input_dtype(input) == DT_RESOURCE) {
114     if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
115       if (sparse) {
116         EnsureSparseVariableAccess<Device, T>(ctx, *maybe_resource)
117             .IgnoreError();
118       }
119       return (*maybe_resource)->mu();
120     } else {
121       ctx->CtxFailureWithWarning(
122           errors::Internal("Invalid variable reference."));
123       return nullptr;
124     }
125   }
126   return ctx->input_ref_mutex(input);
127 }
128 
129 // MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
130 // in address order to mitigate deadlock.  Returns a structure that, when
131 // deleted, will release the acquired mutexes. Safe to pass duplicates - will
132 // only lock each distinct mutex once. If sparse is true will ensure the
133 // variable gets switched to copy-on-read mode before trying to acquire the
134 // locks. If do_lock is false, returns immediately for reference variables. For
135 // resource variables in copy-on-read-mode it will grab a shared lock if do_lock
136 // is false, exclusive lock otherwise.  Note that this silently doesn't lock
137 // mutexes for invalid variable references; in all usages this is followed by
138 // GetInputTensor which will signal a failure.
139 template <typename Device, typename T>
MaybeLockVariableInputMutexesInOrder(OpKernelContext * ctx,bool do_lock,bool sparse,const std::vector<int> & input_ids)140 VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
141     OpKernelContext* ctx, bool do_lock, bool sparse,
142     const std::vector<int>& input_ids) {
143   bool any_resource = false;
144   for (auto i : input_ids) {
145     if (ctx->input_dtype(i) == DT_RESOURCE) {
146       any_resource = true;
147       break;
148     }
149   }
150   if (!do_lock && !any_resource) {
151     return VariableInputLockHolder({}, {}, {});
152   }
153   std::vector<Var*> vars;
154   std::vector<mutex*> mutexes;
155   std::vector<int> acquire_order;
156   for (auto input : input_ids) {
157     Var* var;
158     mutex* mutex =
159         GetTrainingVariableMutex<Device, T>(ctx, input, sparse, &var);
160     if (var) vars.push_back(var);
161     // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
162     if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
163       acquire_order.push_back(mutexes.size());
164       mutexes.push_back(mutex);
165     }
166   }
167   std::sort(acquire_order.begin(), acquire_order.end(),
168             [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
169 
170   auto locks = std::make_unique<std::vector<mutex_lock>>();
171   auto shared_locks = std::make_unique<std::vector<tf_shared_lock>>();
172   locks->reserve(acquire_order.size());
173 
174   for (auto acquire : acquire_order) {
175     mutex* mu = mutexes[acquire];
176     if (mu != nullptr) {
177       if (!sparse || do_lock) {
178         locks->emplace_back(*mu);
179       } else {
180         shared_locks->emplace_back(*mu);
181       }
182     }
183   }
184   return VariableInputLockHolder(std::move(vars), std::move(locks),
185                                  std::move(shared_locks));
186 }
187 
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
                                     int output);

191 // This is for use with ResourceVariables to ensure *tensor has a
192 // reference count of 1 before you update it.
193 // REQUIRES: If you pass in variable->tensor(), *variable->mu() must be held.
194 template <typename Device, typename T>
PrepareToUpdateVariable(OpKernelContext * ctx,Tensor * tensor,bool copy_on_read_mode)195 Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor,
196                                bool copy_on_read_mode) {
197   if (copy_on_read_mode || !tensor->RefCountIsOne()) {
198     // Tensor's buffer is in use by some read, so we need to copy before
199     // updating.
200     Tensor tmp;
201     if (std::is_same<T, Variant>::value) {
202       AllocatorAttributes attr;
203       attr.set_on_host(true);
204       TF_RETURN_IF_ERROR(
205           ctx->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr));
206 
207       const auto elements_in = tensor->flat<Variant>();
208       auto elements_out = tmp.flat<Variant>();
209       for (int64_t i = 0; i < elements_in.size(); ++i) {
210         elements_out(i) = elements_in(i);
211       }
212     } else {
213       AllocatorAttributes attr;
214       attr.set_gpu_compatible(true);
215       attr.set_nic_compatible(true);
216       TF_RETURN_IF_ERROR(
217           ctx->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr));
218       functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
219       copy_functor(ctx->eigen_device<Device>(), tmp.flat<T>(),
220                    const_cast<const Tensor*>(tensor)->flat<T>());
221     }
222     *tensor = tmp;
223   }
224   return OkStatus();
225 }
226 
227 // This gives you `*out`, a tensor you can update, corresponding to a variable
228 // passed as input index `input`.  This handles the differences between
229 // reference and resource variables. For reference variables we can just grab
230 // the tensor, grabbing the lock if lock_held is False.
231 //
232 // For resource variables we, if sparse is true, ensure it's in copy-on-read
233 // mode, and then, regardless of the value of sparse, ensure its refcount is 1
234 // (by potentially copying its contents). In this case lock_held is ignored.
235 template <typename Device, typename T>
GetInputTensorFromVariable(OpKernelContext * ctx,int input,bool lock_held,bool sparse,Tensor * out)236 Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
237                                   bool lock_held, bool sparse, Tensor* out) {
238   if (ctx->input_dtype(input) == DT_RESOURCE) {
239     core::RefCountPtr<Var> var;
240     TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
241     if (sparse) {
242       TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var.get()));
243       *out = *var->tensor();
244       return OkStatus();
245     }
246     TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, T>(
247         ctx, var->tensor(), var->copy_on_read_mode.load()));
248     *out = *var->tensor();
249     return OkStatus();
250   }
251   *out = ctx->mutable_input(input, lock_held);
252   return OkStatus();
253 }
254 
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_