xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/resource_variable_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Our general strategy for preventing conflicts between concurrent
17 // reads and writes of resource variables is to:
18 // * For read operations, we:
19 //   - acquire the variable's mutex (in "shared" mode);
20 //   - make a (shallow) copy of the Tensor object, which increments
21 //     the reference count on the variable's TensorBuffer;
22 //   - release the variable's mutex;
23 //   - use the copy of the Tensor object to do the read.
24 // * For write operations, we:
25 //   - acquire the variable's mutex (in "exclusive" mode);
26 //   - check the reference count of variable's TensorBuffer and
27 //     if it is >1, make a deep copy of the variable's Tensor;
28 //   - mutate the variable's Tensor;
29 //   - and release the variable's mutex.
30 // This allows several read operations to all use the same
31 // TensorBuffer without needing to copy. When it comes time to write
32 // it will only make a copy if there is an outstanding read using the
33 // buffer. Write operations are serialized by the variable's mutex.
34 //
35 // For sparse operations (scatter, gather, sparse optimizer updates),
36 // we need to avoid copies, since there may not be enough memory for
// two copies of the whole tensor. To support this, we make two
38 // modifications to the above strategy:
39 // * For sparse reads (gather), we hold the variable's mutex (still in
40 //   "shared" mode) for the duration of the whole read. This means
41 //   that as long as you only do sparse read operations no write will
42 //   see the reference count >1.
43 // * For sparse write operations where the user explicitly specifies
44 //   that they want to perform the write without locks held
45 //   (use_locking=false), we never copy even if the variable's
46 //   reference count is >1.
47 
48 #define EIGEN_USE_THREADS
49 
50 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
51 #define EIGEN_USE_GPU
52 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
53 #include "tensorflow/core/platform/stream_executor.h"
54 #endif
55 
56 #include <memory>
57 #include <vector>
58 
59 #include "absl/strings/str_join.h"
60 #include "tensorflow/core/common_runtime/device.h"
61 #include "tensorflow/core/framework/bounds_check.h"
62 #include "tensorflow/core/framework/op_kernel.h"
63 #include "tensorflow/core/framework/register_types.h"
64 #include "tensorflow/core/framework/resource_mgr.h"
65 #include "tensorflow/core/framework/tensor_shape.h"
66 #include "tensorflow/core/framework/tensor_types.h"
67 #include "tensorflow/core/framework/variant_op_registry.h"
68 #include "tensorflow/core/kernels/dense_update_functor.h"
69 #include "tensorflow/core/kernels/gather_functor.h"
70 #include "tensorflow/core/kernels/gather_nd_op.h"
71 #include "tensorflow/core/kernels/resource_variable_ops.h"
72 #include "tensorflow/core/kernels/resource_variable_util.h"
73 #include "tensorflow/core/kernels/scatter_functor.h"
74 #include "tensorflow/core/kernels/training_op_helpers.h"
75 #include "tensorflow/core/kernels/variable_ops.h"
76 #include "tensorflow/core/lib/core/errors.h"
77 #include "tensorflow/core/lib/core/refcount.h"
78 #include "tensorflow/core/platform/casts.h"
79 #include "tensorflow/core/platform/mem.h"
80 #include "tensorflow/core/platform/mutex.h"
81 #include "tensorflow/core/platform/types.h"
82 #include "tensorflow/core/util/determinism.h"
83 #include "tensorflow/core/util/util.h"
84 
85 namespace tensorflow {
86 
// CPU kernel for _VarHandlesOp, implemented by ResourceHandlesOp<Var>
// (declared outside this file). No dtype constraint is applied on the CPU.
REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU),
                        ResourceHandlesOp<Var>);
89 
ReadVariableOp(OpKernelConstruction * c)90 ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
91   OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
92 }
93 
94 namespace {
95 
CopyVariable(int output_idx,OpKernelContext * ctx,const Tensor * t)96 Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
97   Tensor* output;
98   Notification n;
99   Status status;
100   AllocatorAttributes attr;
101   if (t->dtype() == DT_VARIANT) {
102     attr.set_on_host(true);
103   }
104   TF_RETURN_IF_ERROR(
105       ctx->allocate_output(output_idx, t->shape(), &output, attr));
106   if (t->dtype() == DT_VARIANT) {
107     output->flat<Variant>() = t->flat<Variant>();
108   } else if (ctx->op_device_context() != nullptr) {
109     // TODO(apassos): remove the down_cast by just returning Device* from
110     // OpKernelContext
111     Device* device = down_cast<Device*>(ctx->device());
112     ctx->op_device_context()->CopyTensorInSameDevice(
113         t, device, output, [&n, &status](const Status& s) {
114           status = s;
115           n.Notify();
116         });
117     n.WaitForNotification();
118     return status;
119   } else {
120     switch (t->dtype()) {
121 #define HANDLER(type)                       \
122   case DataTypeToEnum<type>::value:         \
123     output->flat<type>() = t->flat<type>(); \
124     break;
125       TF_CALL_ALL_TYPES(HANDLER);
126 #undef HANDLER
127       default:
128         return errors::Internal("Unsupported dtype", t->dtype());
129     }
130   }
131   return OkStatus();
132 }
133 
134 }  // namespace
135 
Compute(OpKernelContext * ctx)136 void ReadVariableOp::Compute(OpKernelContext* ctx) {
137   core::RefCountPtr<Var> variable;
138   const ResourceHandle& handle = HandleFromInput(ctx, 0);
139   const auto status = LookupResource(ctx, handle, &variable);
140   OP_REQUIRES(ctx, status.ok(),
141               errors::FailedPrecondition(
142                   "Could not find variable ", handle.name(), ". ",
143                   "This could mean that the variable has been deleted. ",
144                   "In TF1, it can also mean the variable is uninitialized. ",
145                   "Debug info: container=", handle.container(),
146                   ", status error message=", status.error_message()));
147 
148   tf_shared_lock ml(*variable->mu());
149   // We're acquiring a reference to the underlying buffer while
150   // holding a shared lock to guarantee ordering of reads and
151   // writes when in copy-on-write mode.
152   const Tensor* t = variable->tensor();
153   if (!variable->copy_on_read_mode.load()) {
154     OP_REQUIRES(
155         ctx, dtype_ == t->dtype(),
156         errors::InvalidArgument(
157             "Trying to read variable with wrong dtype. Expected ",
158             DataTypeString(dtype_), " got ", DataTypeString(t->dtype())));
159     ctx->set_output(0, *t);
160   } else {
161     OP_REQUIRES_OK(ctx, CopyVariable(0, ctx, t));
162   }
163 }
164 
ReadVariablesOp(OpKernelConstruction * c)165 ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
166   int n;
167   OP_REQUIRES_OK(c, c->GetAttr("N", &n));
168   OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_));
169   OP_REQUIRES(c, n == dtypes_.size(),
170               errors::InvalidArgument(
171                   "Mismatched number of arguments to ReadVariablesOp (", n,
172                   " vs. ", dtypes_.size(), ")"));
173 }
174 
Compute(OpKernelContext * ctx)175 void ReadVariablesOp::Compute(OpKernelContext* ctx) {
176   std::vector<core::RefCountPtr<Var>> variables(dtypes_.size());
177   std::vector<const ResourceHandle*> handles(dtypes_.size());
178   for (size_t i = 0; i < dtypes_.size(); ++i) {
179     handles[i] = &HandleFromInput(ctx, i);
180   }
181 
182   OP_REQUIRES_OK(ctx, LookupResources(ctx, handles, &variables));
183 
184   std::vector<string> uninitialized_vars;
185   for (int64_t i = 0; i < variables.size(); i++) {
186     if (variables[i] == nullptr) {
187       uninitialized_vars.push_back(handles[i]->name());
188     }
189   }
190 
191   OP_REQUIRES(ctx, uninitialized_vars.empty(),
192               errors::FailedPrecondition(
193                   "In ReadVariablesOp the following variables were "
194                   "found uninitialized: ",
195                   absl::StrJoin(uninitialized_vars, ", ")));
196 
197   for (size_t i = 0; i < dtypes_.size(); ++i) {
198     // We're acquiring a reference to the underlying buffer while
199     // holding a shared lock to guarantee ordering of reads and
200     // writes.
201     tf_shared_lock ml(*variables[i]->mu());
202     OP_REQUIRES(ctx, dtypes_[i] == variables[i]->tensor()->dtype(),
203                 errors::InvalidArgument(
204                     "Trying to read variable ", handles[i]->name(),
205                     " from Container: ", handles[i]->container(),
206                     " with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
207                     " got ", DataTypeString(variables[i]->tensor()->dtype())));
208     if (variables[i]->copy_on_read_mode.load()) {
209       OP_REQUIRES_OK(ctx, CopyVariable(i, ctx, variables[i]->tensor()));
210     } else {
211       const Tensor& t = *variables[i]->tensor();
212       ctx->set_output(i, t);
213     }
214   }
215 }
216 
// Kernels for ReadVariableOp and _ReadVariablesOp. On DEVICE_DEFAULT only the
// resource-handle input(s) are marked HostMemory; the value outputs are not.
REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
                        ReadVariableOp);
REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU),
                        ReadVariablesOp);

REGISTER_KERNEL_BUILDER(
    Name("ReadVariableOp").Device(DEVICE_DEFAULT).HostMemory("resource"),
    ReadVariableOp);
REGISTER_KERNEL_BUILDER(
    Name("_ReadVariablesOp").Device(DEVICE_DEFAULT).HostMemory("resources"),
    ReadVariablesOp);
228 
VarHandleOp(OpKernelConstruction * context)229 VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
230   OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
231   OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));
232 
233   OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype));
234   OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape));
235 
236   is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME;
237 
238   // Use const_tensor_ if the variable is non-anonymous.
239   if (!is_anonymous_) {
240     AllocatorAttributes attr;
241     attr.set_on_host(true);
242     OP_REQUIRES_OK(context, context->allocate_temp(DT_RESOURCE, TensorShape({}),
243                                                    &const_tensor_, attr));
244     const_tensor_.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
245         context, container_, name_,
246         std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
247   }
248 }
249 
Compute(OpKernelContext * ctx)250 void VarHandleOp::Compute(OpKernelContext* ctx) {
251   if (is_anonymous_) {
252     Var* resource = new Var(dtype_and_shape_.dtype);
253     ResourceMgr* mgr = ctx->resource_manager();
254     ResourceHandle handle = ResourceHandle::MakeRefCountingHandle<Var>(
255         resource, ctx->device()->name(),
256         std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
257         ctx->stack_trace());
258     // TODO(b/203901837): See if we can abolish all code paths that lookup
259     // anonymous variables and then stop publishing them to the manager.
260     OP_REQUIRES_OK(ctx, mgr->CreateUnowned<Var>(handle.container(),
261                                                 handle.name(), resource));
262 
263     AllocatorAttributes attr;
264     attr.set_on_host(true);
265     Tensor tensor;
266     OP_REQUIRES_OK(
267         ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &tensor, attr));
268 
269     tensor.scalar<ResourceHandle>()() = std::move(handle);
270 
271     ctx->set_output(0, tensor);
272   } else {
273     ctx->set_output(0, const_tensor_);
274   }
275 }
276 
// CPU kernel for VarHandleOp; no dtype constraint is applied on the CPU.
REGISTER_KERNEL_BUILDER(Name("VarHandleOp").Device(DEVICE_CPU), VarHandleOp);
278 
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Despite its name, this macro registers no kernels. For each `type` it
// declares the GPUDevice specialization of
// functor::DenseUpdate<GPUDevice, type, ASSIGN>::operator() and marks the
// instantiation `extern`, so this file uses a definition compiled elsewhere
// (presumably a GPU-compiled dense_update_functor source — confirm) instead of
// instantiating it here. AssignVariableOp<GPUDevice, T> below relies on it.
#define REGISTER_GPU_KERNELS(type)                             \
  namespace functor {                                          \
  template <>                                                  \
  void DenseUpdate<GPUDevice, type, ASSIGN>::operator()(       \
      const GPUDevice& d, typename TTypes<type>::Flat lhs,     \
      typename TTypes<type>::ConstFlat rhs);                   \
  extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
  }

TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_bfloat16(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
297 
// DEVICE_DEFAULT kernels for VarHandleOp, one per supported `dtype`. The
// produced handle ("resource") is always kept in host memory.
#define REGISTER_DEFAULT_KERNELS(type)                        \
  REGISTER_KERNEL_BUILDER(Name("VarHandleOp")                 \
                              .Device(DEVICE_DEFAULT)         \
                              .HostMemory("resource")         \
                              .TypeConstraint<type>("dtype"), \
                          VarHandleOp)
TF_CALL_GPU_ALL_TYPES(REGISTER_DEFAULT_KERNELS);
TF_CALL_bfloat16(REGISTER_DEFAULT_KERNELS);
TF_CALL_int64(REGISTER_DEFAULT_KERNELS);
TF_CALL_variant(REGISTER_DEFAULT_KERNELS);
TF_CALL_uint32(REGISTER_DEFAULT_KERNELS);
#undef REGISTER_DEFAULT_KERNELS
310 
// DEVICE_DEFAULT kernel for _VarHandlesOp; the handles output stays in host
// memory.
// NOTE(review): this dtype list omits DT_BFLOAT16 and DT_UINT32, although the
// DEVICE_DEFAULT VarHandleOp registration in this file explicitly includes
// them (TF_CALL_bfloat16 / TF_CALL_uint32). Confirm whether that is
// intentional before relying on _VarHandlesOp for such variables.
REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("resources")
                            .TypeConstraint("dtypes",
                                            {DT_INT64, DT_COMPLEX64,
                                             DT_COMPLEX128, DT_HALF, DT_FLOAT,
                                             DT_DOUBLE, DT_BOOL, DT_VARIANT}),
                        ResourceHandlesOp<Var>);
319 
// VariableShape kernels for int32 and int64 `out_type`. On DEVICE_DEFAULT both
// the handle input and the shape output are kept in host memory.
REGISTER_KERNEL_BUILDER(
    Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
    VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(Name("VariableShape")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64_t>("out_type"),
                        VariableShapeOp<int64_t>);

REGISTER_KERNEL_BUILDER(Name("VariableShape")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("out_type")
                            .HostMemory("output")
                            .HostMemory("input"),
                        VariableShapeOp<int32>);
REGISTER_KERNEL_BUILDER(Name("VariableShape")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int64_t>("out_type")
                            .HostMemory("output")
                            .HostMemory("input"),
                        VariableShapeOp<int64_t>);
340 
DestroyResourceOp(OpKernelConstruction * ctx)341 DestroyResourceOp::DestroyResourceOp(OpKernelConstruction* ctx)
342     : OpKernel(ctx) {
343   OP_REQUIRES_OK(ctx,
344                  ctx->GetAttr("ignore_lookup_error", &ignore_lookup_error_));
345 }
346 
Compute(OpKernelContext * ctx)347 void DestroyResourceOp::Compute(OpKernelContext* ctx) {
348   const ResourceHandle& p = HandleFromInput(ctx, 0);
349   Status status = DeleteResource(ctx, p);
350   if (ignore_lookup_error_ && errors::IsNotFound(status)) {
351     return;
352   }
353   OP_REQUIRES_OK(ctx, status);
354 }
355 
// DestroyResourceOp kernels; on DEVICE_DEFAULT the handle input stays in host
// memory.
REGISTER_KERNEL_BUILDER(Name("DestroyResourceOp").Device(DEVICE_CPU),
                        DestroyResourceOp);
REGISTER_KERNEL_BUILDER(
    Name("DestroyResourceOp").Device(DEVICE_DEFAULT).HostMemory("resource"),
    DestroyResourceOp);
361 
Compute(OpKernelContext * ctx)362 void DisableCopyOnReadOp::Compute(OpKernelContext* ctx) {
363   core::RefCountPtr<Var> variable;
364   const ResourceHandle& handle = HandleFromInput(ctx, 0);
365   const auto status = LookupResource(ctx, handle, &variable);
366   OP_REQUIRES(ctx, status.ok(),
367               errors::FailedPrecondition(
368                   "Could not find variable ", handle.name(), ". ",
369                   "This could mean that the variable has been deleted. ",
370                   "In TF1, it can also mean the variable is uninitialized. ",
371                   "Debug info: container=", handle.container(),
372                   ", status error message=", status.error_message()));
373   // If the variable is currently in copy-on-read mode, its refcount is 1
374   if (variable->copy_on_read_mode.load()) {
375     // Obtain an exclusive lock on the variable and change the access mode
376     mutex_lock ml(*variable->mu());
377     variable->copy_on_read_mode.store(false);
378   }
379 }
380 
// DisableCopyOnRead kernels; on DEVICE_DEFAULT the handle input stays in host
// memory.
REGISTER_KERNEL_BUILDER(Name("DisableCopyOnRead").Device(DEVICE_CPU),
                        DisableCopyOnReadOp);
REGISTER_KERNEL_BUILDER(
    Name("DisableCopyOnRead").Device(DEVICE_DEFAULT).HostMemory("resource"),
    DisableCopyOnReadOp);
386 
387 template <typename Device, typename T>
388 class AssignVariableOp : public OpKernel {
389  public:
AssignVariableOp(OpKernelConstruction * c)390   explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
391     OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
392     if (!c->GetAttr("_grappler_relax_allocator_constraints",
393                     &relax_constraints_)
394              .ok()) {
395       relax_constraints_ = false;
396     }
397     if (c->HasAttr("validate_shape")) {
398       OP_REQUIRES_OK(c, c->GetAttr("validate_shape", &validate_shape_));
399     }
400   }
401 
Compute(OpKernelContext * context)402   void Compute(OpKernelContext* context) override {
403     OP_REQUIRES(context, dtype_ == context->input(1).dtype(),
404                 errors::InvalidArgument(
405                     "Variable and value dtypes don't match; respectively, ",
406                     DataTypeString(dtype_), " and ",
407                     DataTypeString(context->input(1).dtype())));
408     core::RefCountPtr<Var> variable;
409     const Tensor& value = context->input(1);
410     // Note: every resource-variable-manipulating op assumes copy-on-write
411     // semantics, and creates a copy of the variable's Tensor if its refcount is
412     // bigger than 1 when we try to modify it. This means we never need to copy
413     // the original tensor for AssignVariableOp; even if there are other live
414     // users of it we know none can modify it so this is always safe (even in
415     // esoteric cases where the same tensor is used to initialize multiple
416     // variables or the tensor is a constant this is safe, as future writes will
417     // trigger copies).
418     OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
419                                 context, HandleFromInput(context, 0), &variable,
420                                 [this, &value](Var** ptr) {
421                                   *ptr = new Var(dtype_);
422                                   *(*ptr)->tensor() = value;
423                                   (*ptr)->is_initialized = true;
424                                   return OkStatus();
425                                 }));
426     mutex_lock ml(*variable->mu());
427     // (variable->tensor()->dtype() == DT_INVALID && !variable->is_initialized)
428     // check below is to allow an XLA specific situation wherein update can
429     // happen first by the AssignVariableOp,
430     // in which case the variable is still uninitialized.
431     // When using TF-XLA, this scenario is possible when the execution uses the
432     // 'fallback' path (which essentially invokes Tensorflow ops via
433     // partitioned_call).
434     OP_REQUIRES(context,
435                 (variable->tensor()->dtype() == DT_INVALID &&
436                  !variable->is_initialized) ||
437                     variable->tensor()->dtype() == dtype_,
438                 errors::InvalidArgument(
439                     "Trying to assign variable with wrong dtype. Expected ",
440                     DataTypeString(variable->tensor()->dtype()), " got ",
441                     DataTypeString(dtype_)));
442     if (validate_shape_) {
443       OP_REQUIRES(
444           context,
445           (!variable->is_initialized ||
446            variable->tensor()->shape().IsSameSize(value.shape())),
447           errors::InvalidArgument(
448               "Trying to assign to variable with tensor with wrong shape."
449               " Expected ",
450               variable->tensor()->shape().DebugString(), " got ",
451               value.shape().DebugString()));
452     }
453     if (variable->copy_on_read_mode.load()) {
454       AllocatorAttributes attr;
455       attr.set_gpu_compatible(true);
456       attr.set_nic_compatible(true);
457       OP_REQUIRES_OK(context,
458                      context->allocate_temp(value.dtype(), value.shape(),
459                                             variable->tensor(), attr));
460       functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
461       copy_functor(context->eigen_device<Device>(),
462                    variable->tensor()->flat<T>(), value.flat<T>());
463     } else {
464       *variable->tensor() = value;
465     }
466     variable->is_initialized = true;
467   }
468 
469  private:
470   DataType dtype_;
471   bool relax_constraints_;
472   bool validate_shape_ = false;
473 };
474 
// DT_VARIANT specialization of AssignVariableOp: stores input 1 into the
// variable referenced by input 0, creating an empty DT_VARIANT Var if needed.
//
// Unlike the generic kernel, this one does not read `validate_shape` and has no
// copy-on-read branch. The value is either adopted (when its buffer can be
// forwarded) or copied element by element on host.
template <typename Device>
class AssignVariableOp<Device, Variant> : public OpKernel {
 public:
  explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
    // This specialization is only meant to be instantiated for DT_VARIANT.
    OP_REQUIRES(c, dtype_ == DT_VARIANT,
                errors::Internal("Variant kernel called with dtype: ",
                                 DataTypeString(dtype_)));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& value = context->input(1);
    core::RefCountPtr<Var> variable;
    OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
                                context, HandleFromInput(context, 0), &variable,
                                [](Var** ptr) {
                                  // Created on host.
                                  *ptr = new Var(DT_VARIANT);
                                  return OkStatus();
                                }));

    // For purposes of forwarding DT_VARIANT, we want the least
    // restrictive attr; we already know the input is on host.
    AllocatorAttributes attr;

    // Copying is unnecessary if we are the last user of the value
    // tensor, we can just adopt the input tensor's buffer instead.
    // Note that Variant objects themselves always reside on host.
    //
    // We nevertheless want to signal to the runtime that the tensor
    // should reside in memory of the associated device, as Variant
    // tensors may be marked as sitting on either CPU or GPU.  This
    // helps to elide one or more copies.
    std::unique_ptr<Tensor> input_alias = context->forward_input(
        1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT,
        value.shape(),
        DEVICE_MEMORY /* HOST_MEMORY is only reserved for special cases */,
        attr);

    mutex_lock ml(*variable->mu());
    OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
                errors::InvalidArgument(
                    "Trying to assign variable with wrong dtype. Expected ",
                    DataTypeString(variable->tensor()->dtype()), " got ",
                    DataTypeString(DT_VARIANT)));
    // Marked initialized before the copy below completes. If allocate_temp
    // fails, the variable keeps the freshly constructed tensor of
    // value.shape() installed on the next line.
    variable->is_initialized = true;
    // NOTE(review): this unconditionally installs a newly constructed
    // DT_VARIANT tensor of value.shape(), which appears to make the
    // "re-use variable's buffer" check below succeed only when that fresh
    // tensor has no buffer of its own (e.g. zero elements, presumably). It
    // also allocates even when input_alias is used. Confirm the reset is
    // intentional.
    *variable->tensor() = Tensor(DT_VARIANT, value.shape());

    // Fast path: the input buffer could be forwarded, so adopt it directly.
    if (input_alias) {
      *variable->tensor() = *input_alias;
      return;
    }

    // Need to copy, but maybe we can re-use variable's buffer?
    if (!variable->tensor()->RefCountIsOne() ||
        !variable->tensor()->shape().IsSameSize(value.shape())) {
      // Allocation of DT_VARIANT is always on host.
      attr.set_on_host(true);
      OP_REQUIRES_OK(context, context->allocate_temp(DT_VARIANT, value.shape(),
                                                     variable->tensor(), attr));
    }

    // Element-wise copy of the Variant objects on host.
    const auto elements_in = value.flat<Variant>();
    auto elements_out = variable->tensor()->flat<Variant>();
    for (int64_t i = 0; i < elements_in.size(); ++i) {
      elements_out(i) = elements_in(i);
    }
  }

 private:
  DataType dtype_;
};
547 
// CPU kernels for AssignVariableOp. The DT_VARIANT case, if TF_CALL_ALL_TYPES
// covers it (presumably), resolves to the AssignVariableOp<Device, Variant>
// specialization above.
#define REGISTER_KERNELS(type)                                \
  REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")            \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<type>("dtype"), \
                          AssignVariableOp<Eigen::ThreadPoolDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

// GPU kernels; the resource handle input stays in host memory.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type)                           \
  REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")           \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<type>("dtype") \
                              .HostMemory("resource"),       \
                          AssignVariableOp<GPUDevice, type>);

TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_bfloat16(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// DEVICE_DEFAULT kernels.
// NOTE(review): these are instantiated with CPUDevice, so the copy-on-read
// path runs functor::DenseUpdate<CPUDevice, ...> on the variable's buffer.
// Confirm that is intended for non-CPU default devices.
#define REGISTER_KERNELS(type)                               \
  REGISTER_KERNEL_BUILDER(Name("AssignVariableOp")           \
                              .Device(DEVICE_DEFAULT)        \
                              .TypeConstraint<type>("dtype") \
                              .HostMemory("resource"),       \
                          AssignVariableOp<CPUDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
583 
584 template <typename Device, typename T, DenseUpdateType Op>
585 class AssignUpdateVariableOp : public OpKernel {
586  public:
AssignUpdateVariableOp(OpKernelConstruction * c)587   explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}
588 
Compute(OpKernelContext * context)589   void Compute(OpKernelContext* context) override {
590     core::RefCountPtr<Var> variable;
591     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
592                                            &variable));
593 
594     const Tensor& value = context->input(1);
595     // TODO(apassos): We could possibly avoid the copy done by
596     // PrepareToUpdateVariable() for commutative operations like Op ==
597     // ADD if value's refcount was 1.
598     mutex_lock ml(*variable->mu());
599     Tensor* var_tensor = variable->tensor();
600     OP_REQUIRES_OK(context, ValidateAssignUpdateVariableOpShapes(
601                                 var_tensor->shape(), value.shape()));
602     OP_REQUIRES_OK(
603         context, PrepareToUpdateVariable<Device, T>(
604                      context, var_tensor, variable->copy_on_read_mode.load()));
605     functor::DenseUpdate<Device, T, Op> update_functor;
606     update_functor(context->eigen_device<Device>(), var_tensor->flat<T>(),
607                    value.flat<T>());
608   }
609 };
610 
// CPU kernels for AssignAddVariableOp (ADD) and AssignSubVariableOp (SUB).
#define REGISTER_KERNELS(type)                                     \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AssignAddVariableOp")                                  \
          .Device(DEVICE_CPU)                                      \
          .TypeConstraint<type>("dtype"),                          \
      AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, ADD>); \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AssignSubVariableOp")                                  \
          .Device(DEVICE_CPU)                                      \
          .TypeConstraint<type>("dtype"),                          \
      AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, SUB>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

// GPU kernels; the resource handle input stays in host memory.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp")                    \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("resource")                    \
                              .TypeConstraint<type>("dtype"),            \
                          AssignUpdateVariableOp<GPUDevice, type, ADD>); \
  REGISTER_KERNEL_BUILDER(Name("AssignSubVariableOp")                    \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("resource")                    \
                              .TypeConstraint<type>("dtype"),            \
                          AssignUpdateVariableOp<GPUDevice, type, SUB>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
643 
644 class VarIsInitializedOp : public OpKernel {
645  public:
VarIsInitializedOp(OpKernelConstruction * c)646   explicit VarIsInitializedOp(OpKernelConstruction* c) : OpKernel(c) {}
647 
Compute(OpKernelContext * context)648   void Compute(OpKernelContext* context) override {
649     Tensor* output = nullptr;
650     OP_REQUIRES_OK(context,
651                    context->allocate_output(0, TensorShape({}), &output));
652     auto output_tensor = output->tensor<bool, 0>();
653     core::RefCountPtr<Var> variable;
654     Status s = LookupResource(context, HandleFromInput(context, 0), &variable);
655     if (!s.ok()) {
656       output_tensor() = false;
657       return;
658     }
659     mutex_lock ml(*variable->mu());
660     output_tensor() = variable->is_initialized;
661   }
662 };
663 
// VarIsInitializedOp kernels. On DEVICE_DEFAULT both the handle input and the
// bool output are kept in host memory.
REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp").Device(DEVICE_CPU),
                        VarIsInitializedOp);

REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("resource")
                            .HostMemory("is_initialized"),
                        VarIsInitializedOp);
672 
673 template <typename Device, typename T, typename Index>
674 class ResourceGatherOp : public OpKernel {
675  public:
ResourceGatherOp(OpKernelConstruction * c)676   explicit ResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {
677     OP_REQUIRES_OK(c, c->GetAttr("batch_dims", &batch_dims_));
678   }
679 
Compute(OpKernelContext * c)680   void Compute(OpKernelContext* c) override {
681     core::RefCountPtr<Var> v;
682     OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
683     OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
684     // NOTE: We hold the lock for the whole gather operation instead
685     // of increasing the reference count of v->tensor() to avoid a
686     // situation where a write to the same variable will see a
687     // reference count greater than one and make a copy of the
688     // (potentially very large) tensor buffer.
689     tf_shared_lock ml(*v->mu());
690     const Tensor& params = *v->tensor();
691     const Tensor& indices = c->input(1);
692     OP_REQUIRES(
693         c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
694         errors::InvalidArgument("params must be at least 1 dimensional"));
695     OP_REQUIRES(
696         c, params.shape().dims() >= batch_dims_,
697         errors::InvalidArgument("params must have at least ", batch_dims_,
698                                 " (batch_dims) dimensions but it has shape ",
699                                 params.shape().DebugString()));
700 
701     // Check that we have enough index space
702     const int64_t N = indices.NumElements();
703     OP_REQUIRES(
704         c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
705         errors::InvalidArgument("params.shape[0] too large for ",
706                                 DataTypeString(DataTypeToEnum<Index>::v()),
707                                 " indexing: ", params.dim_size(0), " > ",
708                                 std::numeric_limits<Index>::max()));
709 
710     // The result shape is params.shape[:batch_dims] +
711     // indices.shape[batch_dims:] + params.shape[batch_dims+1:].
712     TensorShape result_shape;
713     for (int i = 0; i < batch_dims_; ++i) {
714       result_shape.AddDim(params.dim_size(i));
715     }
716     for (int i = batch_dims_; i < indices.dims(); ++i) {
717       result_shape.AddDim(indices.dim_size(i));
718     }
719     for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
720       result_shape.AddDim(params.dim_size(i));
721     }
722 
723     Tensor* out = nullptr;
724     Tensor tmp;
725     if (params.dtype() == DT_VARIANT) {
726       tmp = Tensor(DT_VARIANT, result_shape);
727       c->set_output(0, tmp);
728       out = &tmp;
729     } else {
730       OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
731     }
732 
733     if (N > 0) {
734       Tensor tmp_indices;
735 
736       // Points to the original or updated (if batch_dims is set) indices.
737       const Tensor* op_indices = &indices;
738       if (batch_dims_ > 0) {
739         OP_REQUIRES_OK(c, c->allocate_temp(indices.dtype(), indices.shape(),
740                                            &tmp_indices));
741         functor::DenseUpdate<Device, Index, ASSIGN> copy_functor;
742         copy_functor(c->eigen_device<Device>(), tmp_indices.flat<Index>(),
743                      indices.flat<Index>());
744 
745         AddBatchOffsets(c, &tmp_indices, params);
746         if (!c->status().ok()) return;
747         op_indices = &tmp_indices;
748       }
749 
750       int64_t gather_dim_size = 1;
751       for (int idx = 0; idx <= batch_dims_; ++idx) {
752         gather_dim_size *= params.dim_size(idx);
753       }
754       int64_t inner_size = 1;
755       for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
756         inner_size *= params.dim_size(i);
757       }
758       auto params_flat = params.shaped<T, 3>({1, gather_dim_size, inner_size});
759       const auto indices_flat = op_indices->flat<Index>();
760       auto out_flat = out->shaped<T, 3>({1, N, out->NumElements() / N});
761 
762       functor::GatherFunctor<Device, T, Index> functor;
763       int64_t bad_i = functor(c, params_flat, indices_flat, out_flat);
764 
765       OP_REQUIRES(
766           c, bad_i < 0,
767           errors::InvalidArgument(
768               "indices", SliceDebugString(indices.shape(), bad_i), " = ",
769               indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
770     }
771   }
772 
773  private:
774   // Add the batch offset derived from params to each batch of indices.
775   // Example: batch_dims = 1, indices = [[0, 1, 2], [0, 1, 2]]
776   // If indexing into a params dimension of size 4, then the indices will become
777   // [0, 1, 2, 4, 5, 6]
AddBatchOffsets(OpKernelContext * ctx,Tensor * indices,const Tensor & params)778   void AddBatchOffsets(OpKernelContext* ctx, Tensor* indices,
779                        const Tensor& params) {
780     int64_t batch_size = 1;  // The size of all batch dimensions.
781     for (int idx = 0; idx < batch_dims_; ++idx) {
782       batch_size *= params.dim_size(idx);
783     }
784     OP_REQUIRES(
785         ctx, batch_size != 0,
786         errors::InvalidArgument(
787             "Inner size of indices would result in batch_size of 0 and a ",
788             "division by 0 in the implementation. This is illegal"));
789 
790     auto indices_flat = indices->flat<Index>();
791     int64_t const index_inner_size = indices->NumElements() / batch_size;
792     int64_t const batch_offset = params.dim_size(batch_dims_);
793     for (int64_t batch_idx = 0, dest_idx = 0; batch_idx < batch_size;
794          ++batch_idx) {
795       for (int64_t idx = 0; idx < index_inner_size; ++idx) {
796         indices_flat(dest_idx++) += batch_offset * batch_idx;
797       }
798     }
799   }
800 
801   int32 batch_dims_ = 0;
802 };
803 
804 #define REGISTER_GATHER_FULL(dev, type, index_type)                    \
805   REGISTER_KERNEL_BUILDER(Name("ResourceGather")                       \
806                               .Device(DEVICE_##dev)                    \
807                               .HostMemory("resource")                  \
808                               .TypeConstraint<type>("dtype")           \
809                               .TypeConstraint<index_type>("Tindices"), \
810                           ResourceGatherOp<dev##Device, type, index_type>)
811 
812 #define REGISTER_GATHER_ALL_INDICES(dev, type) \
813   REGISTER_GATHER_FULL(dev, type, int32);      \
814   REGISTER_GATHER_FULL(dev, type, int64_t)
815 
816 #define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)
817 
818 // Registration of the CPU implementations.
819 TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
820 TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
821 
822 // Registers GPU kernels.
823 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
824 #define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)
825 
826 TF_CALL_int64(REGISTER_GATHER_GPU);
827 TF_CALL_GPU_ALL_TYPES(REGISTER_GATHER_GPU);
828 
829 // Variant objects themselves sit on CPU, even if they contain data
830 // pointing to a device.
831 REGISTER_KERNEL_BUILDER(Name("ResourceGather")
832                             .Device(DEVICE_DEFAULT)
833                             .HostMemory("resource")
834                             .HostMemory("indices")
835                             .TypeConstraint<Variant>("dtype")
836                             .TypeConstraint<int32>("Tindices"),
837                         ResourceGatherOp<CPUDevice, Variant, int32>)
838 REGISTER_KERNEL_BUILDER(Name("ResourceGather")
839                             .Device(DEVICE_DEFAULT)
840                             .HostMemory("resource")
841                             .HostMemory("indices")
842                             .TypeConstraint<Variant>("dtype")
843                             .TypeConstraint<int64_t>("Tindices"),
844                         ResourceGatherOp<CPUDevice, Variant, int64>)
845 
846 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
847 
848 #undef REGISTER_GATHER_CPU
849 #undef REGISTER_GATHER_GPU
850 #undef REGISTER_GATHER_ALL_INDICES
851 #undef REGISTER_GATHER_FULL
852 
853 template <typename Device, typename T, typename Index>
854 class ResourceGatherNdOp : public OpKernel {
855  public:
ResourceGatherNdOp(OpKernelConstruction * c)856   explicit ResourceGatherNdOp(OpKernelConstruction* c) : OpKernel(c) {}
857 
Compute(OpKernelContext * c)858   void Compute(OpKernelContext* c) override {
859     core::RefCountPtr<Var> v;
860     OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
861     OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
862     // NOTE: We hold the lock for the whole gather operation instead
863     // of increasing the reference count of v->tensor() to avoid a
864     // situation where a write to the same variable will see a
865     // reference count greater than one and make a copy of the
866     // (potentially very large) tensor buffer.
867     tf_shared_lock ml(*v->mu());
868     const Tensor& params = *v->tensor();
869     const Tensor& indices = c->input(1);
870 
871     Tensor out;
872     OP_REQUIRES_OK(
873         c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out));
874     c->set_output(0, out);
875   }
876 };
877 
878 #define REGISTER_GATHER_ND_FULL(dev, type, index_type)                 \
879   REGISTER_KERNEL_BUILDER(Name("ResourceGatherNd")                     \
880                               .Device(DEVICE_##dev)                    \
881                               .HostMemory("resource")                  \
882                               .TypeConstraint<type>("dtype")           \
883                               .TypeConstraint<index_type>("Tindices"), \
884                           ResourceGatherNdOp<dev##Device, type, index_type>)
885 
886 #define REGISTER_GATHER_ND_ALL_INDICES(dev, type) \
887   REGISTER_GATHER_ND_FULL(dev, type, int32);      \
888   REGISTER_GATHER_ND_FULL(dev, type, int64_t)
889 
890 #define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type)
891 
892 // Registration of the CPU implementations.
893 TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
894 
895 // Registers GPU kernels.
896 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
897 #define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type)
898 
899 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);
900 
901 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
902 
903 #undef REGISTER_GATHER_ND_CPU
904 #undef REGISTER_GATHER_ND_GPU
905 #undef REGISTER_GATHER_ND_ALL_INDICES
906 #undef REGISTER_GATHER_ND_FULL
907 
908 namespace {
909 
910 template <typename Device>
isCPUDevice()911 bool isCPUDevice() {
912   return false;
913 }
914 
915 template <>
isCPUDevice()916 bool isCPUDevice<CPUDevice>() {
917   return true;
918 }
919 
920 template <typename T>
ValidateInput(const Tensor & updates)921 bool ValidateInput(const Tensor& updates) {
922   const auto updates_flat = updates.flat<T>();
923   for (int i = 0; i < updates.NumElements(); ++i) {
924     if (updates_flat(i) == T{}) return false;
925   }
926   return true;
927 }
928 
929 template <>
ValidateInput(const Tensor & updates)930 bool ValidateInput<Variant>(const Tensor& updates) {
931   return true;
932 }
933 
934 template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
935 Status DoScatter(OpKernelContext* c, Tensor* params, const Tensor& indices,
936                  const Tensor& updates, Index num_indices);
937 
938 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
939 
940 template <typename T>
CopyTensorToHost(OpKernelContext * c,const Tensor & device_tensor,Tensor * host_tensor)941 Status CopyTensorToHost(OpKernelContext* c, const Tensor& device_tensor,
942                         Tensor* host_tensor) {
943   AllocatorAttributes alloc_attr;
944   alloc_attr.set_on_host(true);
945   alloc_attr.set_gpu_compatible(true);
946   auto stream = c->op_device_context()->stream();
947   TF_RETURN_IF_ERROR(c->allocate_temp(
948       device_tensor.dtype(), device_tensor.shape(), host_tensor, alloc_attr));
949   se::DeviceMemoryBase device_ptr(
950       const_cast<Tensor&>(device_tensor).flat<T>().data(),
951       device_tensor.flat<T>().size() * sizeof(T));
952   stream->ThenMemcpy(host_tensor->flat<T>().data(), device_ptr,
953                      device_tensor.NumElements() * sizeof(T));
954   if (!stream) {
955     return errors::Internal("Failed to copy indices to host");
956   }
957   return OkStatus();
958 }
959 
960 // Copies inputs to the CPU, runs DoScatter on the CPU, then copies output
961 // back to GPU. This is useful because the CPU implementation is deterministic
962 // and the GPU implementation is not. Tensor inputs to this function must be on
963 // the GPU.
964 template <typename T, typename Index, scatter_op::UpdateOp Op>
DoScatterOnCpu(OpKernelContext * c,Tensor * params,const Tensor & indices,const Tensor & updates,Index num_indices)965 Status DoScatterOnCpu(OpKernelContext* c, Tensor* params, const Tensor& indices,
966                       const Tensor& updates, Index num_indices) {
967   auto stream = c->op_device_context()->stream();
968 
969   Tensor host_indices;
970   TF_RETURN_IF_ERROR(CopyTensorToHost<Index>(c, indices, &host_indices));
971   Tensor host_updates;
972   TF_RETURN_IF_ERROR(CopyTensorToHost<T>(c, updates, &host_updates));
973   Tensor host_params;
974   TF_RETURN_IF_ERROR(CopyTensorToHost<T>(c, *params, &host_params));
975 
976   TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
977   TF_RETURN_IF_ERROR(DoScatter<CPUDevice, T, Index, Op>(
978       c, &host_params, host_indices, host_updates, num_indices));
979 
980   // Copy 'host_params' to device.
981   se::DeviceMemoryBase params_ptr(params->flat<T>().data(),
982                                   params->flat<T>().size() * sizeof(T));
983   stream->ThenMemcpy(&params_ptr, host_params.flat<T>().data(),
984                      host_params.NumElements() * sizeof(T));
985   if (!stream) {
986     return errors::Internal("Failed to copy params to device");
987   }
988   // Deallocate host_params' buffer once the host-to-device copy is complete.
989   // host_params is captured by value in the lambda so that its buffer is only
990   // destructed once the lambda is destructed.
991   c->device()->tensorflow_accelerator_device_info()->event_mgr->ThenExecute(
992       stream, [host_params] {});
993   return OkStatus();
994 }
995 
996 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
997 
998 template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
DoScatter(OpKernelContext * c,Tensor * params,const Tensor & indices,const Tensor & updates,Index num_indices)999 Status DoScatter(OpKernelContext* c, Tensor* params, const Tensor& indices,
1000                  const Tensor& updates, Index num_indices) {
1001 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1002   if (std::is_same<Device, GPUDevice>::value &&
1003       tensorflow::OpDeterminismRequired()) {
1004     if (!DataTypeCanUseMemcpy(params->dtype())) {
1005       return errors::Unimplemented(
1006           "GPU Scatter ops for dtype ", DataTypeString(params->dtype()),
1007           " do not yet have a deterministic implementation");
1008     }
1009     return DoScatterOnCpu<T, Index, op>(c, params, indices, updates,
1010                                         num_indices);
1011   }
1012 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1013   auto indices_flat = indices.flat<Index>();
1014   auto params_flat = params->flat_outer_dims<T>();
1015   int64_t num_updates = updates.NumElements();
1016   auto updates_flat =
1017       updates.shaped<T, 2>({num_indices, num_updates / num_indices});
1018   functor::ScatterFunctor<Device, T, Index, op> functor;
1019   const Index bad_i = functor(c, c->template eigen_device<Device>(),
1020                               params_flat, updates_flat, indices_flat);
1021   if (bad_i >= 0) {
1022     return errors::InvalidArgument(
1023         "indices", SliceDebugString(indices.shape(), bad_i), " = ",
1024         indices_flat(bad_i), " is not in [0, ", params->dim_size(0), ")");
1025   }
1026   return OkStatus();
1027 }
1028 
1029 }  // namespace
1030 
1031 template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
1032 class ResourceScatterUpdateOp : public OpKernel {
1033  public:
ResourceScatterUpdateOp(OpKernelConstruction * c)1034   explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
1035     // We use the same kernel for many operations.
1036     // Each operation has a different set of attributes defined in its nodes.
1037     Status s = c->GetAttr("use_locking", &use_exclusive_lock_);
1038     if (!s.ok()) {
1039       use_exclusive_lock_ = false;
1040     }
1041   }
1042 
Compute(OpKernelContext * c)1043   void Compute(OpKernelContext* c) override {
1044     core::RefCountPtr<Var> v;
1045     OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
1046     OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
1047     const bool is_non_pod_dtype = c->input_dtype(0) == DT_RESOURCE ||
1048                                   c->input_dtype(0) == DT_STRING ||
1049                                   c->input_dtype(0) == DT_VARIANT;
1050     if (is_non_pod_dtype || use_exclusive_lock_) {
1051       mutex_lock ml(*v->mu());
1052       DoCompute(c);
1053     } else {
1054       // For POD dtypes, we can safely run the update without the mutex.
1055       tf_shared_lock ml(*v->mu());
1056       DoCompute(c);
1057     }
1058   }
1059 
1060  private:
1061   bool use_exclusive_lock_;
1062 
DoCompute(OpKernelContext * c)1063   void DoCompute(OpKernelContext* c) {
1064     core::RefCountPtr<Var> v;
1065     OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
1066     Tensor* params = v->tensor();
1067     const Tensor& indices = c->input(1);
1068     const Tensor& updates = c->input(2);
1069 
1070     // Check that rank(updates.shape) = rank(indices.shape + params.shape[1:])
1071     OP_REQUIRES(c,
1072                 updates.dims() == 0 ||
1073                     updates.dims() == indices.dims() + params->dims() - 1,
1074                 errors::InvalidArgument(
1075                     "Must have updates.shape = indices.shape + "
1076                     "params.shape[1:] or updates.shape = [], got ",
1077                     "updates.shape ", updates.shape().DebugString(),
1078                     ", indices.shape ", indices.shape().DebugString(),
1079                     ", params.shape ", params->shape().DebugString()));
1080 
1081     // Check that we have enough index space
1082     const int64_t N_big = indices.NumElements();
1083     OP_REQUIRES(
1084         c, N_big <= std::numeric_limits<Index>::max(),
1085         errors::InvalidArgument("indices has too many elements for ",
1086                                 DataTypeString(DataTypeToEnum<Index>::v()),
1087                                 " indexing: ", N_big, " > ",
1088                                 std::numeric_limits<Index>::max()));
1089     const Index N = static_cast<Index>(N_big);
1090     OP_REQUIRES(
1091         c, params->dim_size(0) <= std::numeric_limits<Index>::max(),
1092         errors::InvalidArgument("params.shape[0] too large for ",
1093                                 DataTypeString(DataTypeToEnum<Index>::v()),
1094                                 " indexing: ", params->dim_size(0), " > ",
1095                                 std::numeric_limits<Index>::max()));
1096 
1097     // Prevent division by 0
1098     if (isCPUDevice<Device>() && op == tensorflow::scatter_op::UpdateOp::DIV) {
1099       OP_REQUIRES(c, ValidateInput<T>(updates),
1100                   errors::InvalidArgument("updates must not contain 0"));
1101     }
1102 
1103     if (N > 0) {
1104       auto indices_flat = indices.flat<Index>();
1105       auto params_flat = params->flat_outer_dims<T>();
1106       if (TensorShapeUtils::IsScalar(updates.shape())) {
1107         const auto update = updates.scalar<T>();
1108 
1109         functor::ScatterScalarFunctor<Device, T, Index, op> functor;
1110         const Index bad_i = functor(c, c->template eigen_device<Device>(),
1111                                     params_flat, update, indices_flat);
1112         OP_REQUIRES(c, bad_i < 0,
1113                     errors::InvalidArgument(
1114                         "indices", SliceDebugString(indices.shape(), bad_i),
1115                         " = ", indices_flat(bad_i), " is not in [0, ",
1116                         params->dim_size(0), ")"));
1117       } else {
1118         OP_REQUIRES(
1119             c, TensorShapeUtils::StartsWith(updates.shape(), indices.shape()),
1120             errors::InvalidArgument(
1121                 "The shape of indices (", indices.shape().DebugString(),
1122                 ") must be a prefix of the shape of updates (",
1123                 updates.shape().DebugString(), ")"));
1124         OP_REQUIRES_OK(
1125             c, DoScatter<Device, T, Index, op>(c, params, indices, updates, N));
1126       }
1127     }
1128   }
1129 };
1130 
1131 #define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
1132   REGISTER_KERNEL_BUILDER(                                             \
1133       Name(name)                                                       \
1134           .Device(DEVICE_##dev)                                        \
1135           .HostMemory("resource")                                      \
1136           .TypeConstraint<type>("dtype")                               \
1137           .TypeConstraint<index_type>("Tindices"),                     \
1138       ResourceScatterUpdateOp<dev##Device, type, index_type, op>)
1139 
1140 #define REGISTER_SCATTER_KERNEL(type, dev, name, op)         \
1141   REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
1142   REGISTER_SCATTER_KERNEL_INDEX(type, int64_t, dev, name, op);
1143 
1144 #define REGISTER_SCATTER_ARITHMETIC(type, dev)                \
1145   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd",    \
1146                           scatter_op::UpdateOp::ADD);         \
1147   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub",    \
1148                           scatter_op::UpdateOp::SUB);         \
1149   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul",    \
1150                           scatter_op::UpdateOp::MUL);         \
1151   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv",    \
1152                           scatter_op::UpdateOp::DIV);         \
1153   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
1154                           scatter_op::UpdateOp::ASSIGN);
1155 #define REGISTER_SCATTER_MINMAX(type, dev)                 \
1156   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \
1157                           scatter_op::UpdateOp::MIN);      \
1158   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \
1159                           scatter_op::UpdateOp::MAX);
1160 
1161 // Registers CPU kernels.
1162 #define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
1163   REGISTER_SCATTER_ARITHMETIC(type, CPU);
1164 #define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);
1165 
1166 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
1167 TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
1168 
1169 REGISTER_SCATTER_KERNEL(tstring, CPU, "ResourceScatterUpdate",
1170                         scatter_op::UpdateOp::ASSIGN);
1171 REGISTER_SCATTER_KERNEL(bool, CPU, "ResourceScatterUpdate",
1172                         scatter_op::UpdateOp::ASSIGN);
1173 REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
1174                         scatter_op::UpdateOp::ASSIGN);
1175 
1176 // Registers GPU kernels.
1177 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1178 #define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
1179   REGISTER_SCATTER_ARITHMETIC(type, GPU);
1180 #define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);
1181 
1182 #define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
1183 
1184 TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_GPU);
1185 TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_GPU);
1186 
1187 REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
1188                             .Device(DEVICE_DEFAULT)
1189                             .HostMemory("resource")
1190                             .HostMemory("indices")
1191                             .TypeConstraint<Variant>("dtype")
1192                             .TypeConstraint<int32>("Tindices"),
1193                         ResourceScatterUpdateOp<CPUDevice, Variant, int32,
1194                                                 scatter_op::UpdateOp::ASSIGN>)
1195 REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
1196                             .Device(DEVICE_GPU)
1197                             .HostMemory("resource")
1198                             .TypeConstraint<bool>("dtype")
1199                             .TypeConstraint<int32>("Tindices"),
1200                         ResourceScatterUpdateOp<GPUDevice, bool, int32,
1201                                                 scatter_op::UpdateOp::ASSIGN>)
1202 REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
1203                             .Device(DEVICE_DEFAULT)
1204                             .HostMemory("resource")
1205                             .HostMemory("indices")
1206                             .TypeConstraint<Variant>("dtype")
1207                             .TypeConstraint<int64_t>("Tindices"),
1208                         ResourceScatterUpdateOp<CPUDevice, Variant, int64,
1209                                                 scatter_op::UpdateOp::ASSIGN>)
1210 REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
1211                             .Device(DEVICE_GPU)
1212                             .HostMemory("resource")
1213                             .TypeConstraint<int64_t>("dtype")
1214                             .TypeConstraint<int64_t>("Tindices"),
1215                         ResourceScatterUpdateOp<GPUDevice, int64, int64,
1216                                                 scatter_op::UpdateOp::ASSIGN>)
1217 
1218 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1219 
1220 #undef REGISTER_SCATTER_ARITHMETIC
1221 #undef REGISTER_SCATTER_ARITHMETIC_CPU
1222 #undef REGISTER_SCATTER_MINMAX
1223 #undef REGISTER_SCATTER_MINMAX_CPU
1224 #undef REGISTER_SCATTER_KERNEL
1225 #undef REGISTER_SCATTER_KERNEL_INDEX
1226 
1227 }  // namespace tensorflow
1228