1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Our general strategy for preventing conflicts between concurrent
17 // reads and writes of resource variables is to:
18 // * For read operations, we:
19 // - acquire the variable's mutex (in "shared" mode);
20 // - make a (shallow) copy of the Tensor object, which increments
21 // the reference count on the variable's TensorBuffer;
22 // - release the variable's mutex;
23 // - use the copy of the Tensor object to do the read.
24 // * For write operations, we:
25 // - acquire the variable's mutex (in "exclusive" mode);
26 // - check the reference count of variable's TensorBuffer and
27 // if it is >1, make a deep copy of the variable's Tensor;
28 // - mutate the variable's Tensor;
29 // - and release the variable's mutex.
30 // This allows several read operations to all use the same
31 // TensorBuffer without needing to copy. When it comes time to write,
32 // it will only make a copy if there is an outstanding read using the
33 // buffer. Write operations are serialized by the variable's mutex.
34 //
35 // For sparse operations (scatter, gather, sparse optimizer updates),
36 // we need to avoid copies, since there may not be enough memory for
37 // two copies of the whole tensor. To support this, we make two
38 // modifications to the above strategy:
39 // * For sparse reads (gather), we hold the variable's mutex (still in
40 // "shared" mode) for the duration of the whole read. This means
41 // that as long as you only do sparse read operations no write will
42 // see the reference count >1.
43 // * For sparse write operations where the user explicitly specifies
44 // that they want to perform the write without locks held
45 // (use_locking=false), we never copy even if the variable's
46 // reference count is >1.
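//
// A minimal sketch of the dense read/write pattern described above
// (illustrative only; the real code paths are ReadVariableOp::Compute and
// the Assign*VariableOp kernels below, and `var` stands for a Var* that is
// assumed to be already looked up):
//
//   // Dense read: shallow copy under a shared lock, then read lock-free.
//   Tensor snapshot;
//   {
//     tf_shared_lock l(*var->mu());
//     snapshot = *var->tensor();  // bumps the TensorBuffer refcount
//   }
//   // ... read from `snapshot` without holding the lock ...
//
//   // Dense write: copy-on-write under an exclusive lock.
//   {
//     mutex_lock l(*var->mu());
//     if (!var->tensor()->RefCountIsOne()) {
//       // Deep-copy the buffer first so concurrent readers keep the old one.
//     }
//     // ... mutate var->tensor() in place ...
//   }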
47
48 #define EIGEN_USE_THREADS
49
50 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
51 #define EIGEN_USE_GPU
52 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
53 #include "tensorflow/core/platform/stream_executor.h"
54 #endif
55
56 #include <memory>
57 #include <vector>
58
59 #include "absl/strings/str_join.h"
60 #include "tensorflow/core/common_runtime/device.h"
61 #include "tensorflow/core/framework/bounds_check.h"
62 #include "tensorflow/core/framework/op_kernel.h"
63 #include "tensorflow/core/framework/register_types.h"
64 #include "tensorflow/core/framework/resource_mgr.h"
65 #include "tensorflow/core/framework/tensor_shape.h"
66 #include "tensorflow/core/framework/tensor_types.h"
67 #include "tensorflow/core/framework/variant_op_registry.h"
68 #include "tensorflow/core/kernels/dense_update_functor.h"
69 #include "tensorflow/core/kernels/gather_functor.h"
70 #include "tensorflow/core/kernels/gather_nd_op.h"
71 #include "tensorflow/core/kernels/resource_variable_ops.h"
72 #include "tensorflow/core/kernels/resource_variable_util.h"
73 #include "tensorflow/core/kernels/scatter_functor.h"
74 #include "tensorflow/core/kernels/training_op_helpers.h"
75 #include "tensorflow/core/kernels/variable_ops.h"
76 #include "tensorflow/core/lib/core/errors.h"
77 #include "tensorflow/core/lib/core/refcount.h"
78 #include "tensorflow/core/platform/casts.h"
79 #include "tensorflow/core/platform/mem.h"
80 #include "tensorflow/core/platform/mutex.h"
81 #include "tensorflow/core/platform/types.h"
82 #include "tensorflow/core/util/determinism.h"
83 #include "tensorflow/core/util/util.h"
84
85 namespace tensorflow {
86
87 REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU),
88 ResourceHandlesOp<Var>);
89
90 ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
91 OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
92 }
93
94 namespace {
95
96 Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
97 Tensor* output;
98 Notification n;
99 Status status;
100 AllocatorAttributes attr;
101 if (t->dtype() == DT_VARIANT) {
102 attr.set_on_host(true);
103 }
104 TF_RETURN_IF_ERROR(
105 ctx->allocate_output(output_idx, t->shape(), &output, attr));
106 if (t->dtype() == DT_VARIANT) {
107 output->flat<Variant>() = t->flat<Variant>();
108 } else if (ctx->op_device_context() != nullptr) {
109 // TODO(apassos): remove the down_cast by just returning Device* from
110 // OpKernelContext
111 Device* device = down_cast<Device*>(ctx->device());
112 ctx->op_device_context()->CopyTensorInSameDevice(
113 t, device, output, [&n, &status](const Status& s) {
114 status = s;
115 n.Notify();
116 });
117 n.WaitForNotification();
118 return status;
119 } else {
120 switch (t->dtype()) {
121 #define HANDLER(type) \
122 case DataTypeToEnum<type>::value: \
123 output->flat<type>() = t->flat<type>(); \
124 break;
125 TF_CALL_ALL_TYPES(HANDLER);
126 #undef HANDLER
127 default:
128 return errors::Internal("Unsupported dtype", t->dtype());
129 }
130 }
131 return OkStatus();
132 }
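// Note: CopyVariable is only invoked by the read ops below when the variable
// is in copy-on-read mode, so the returned output never aliases the variable's
// buffer. For DT_VARIANT the copy is element-wise on the host; otherwise it
// goes through the device's CopyTensorInSameDevice path (or a flat assignment
// on CPU when there is no DeviceContext).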
133
134 } // namespace
135
136 void ReadVariableOp::Compute(OpKernelContext* ctx) {
137 core::RefCountPtr<Var> variable;
138 const ResourceHandle& handle = HandleFromInput(ctx, 0);
139 const auto status = LookupResource(ctx, handle, &variable);
140 OP_REQUIRES(ctx, status.ok(),
141 errors::FailedPrecondition(
142 "Could not find variable ", handle.name(), ". ",
143 "This could mean that the variable has been deleted. ",
144 "In TF1, it can also mean the variable is uninitialized. ",
145 "Debug info: container=", handle.container(),
146 ", status error message=", status.error_message()));
147
148 tf_shared_lock ml(*variable->mu());
149 // We're acquiring a reference to the underlying buffer while
150 // holding a shared lock to guarantee ordering of reads and
151 // writes when in copy-on-write mode.
152 const Tensor* t = variable->tensor();
153 if (!variable->copy_on_read_mode.load()) {
154 OP_REQUIRES(
155 ctx, dtype_ == t->dtype(),
156 errors::InvalidArgument(
157 "Trying to read variable with wrong dtype. Expected ",
158 DataTypeString(dtype_), " got ", DataTypeString(t->dtype())));
159 ctx->set_output(0, *t);
160 } else {
161 OP_REQUIRES_OK(ctx, CopyVariable(0, ctx, t));
162 }
163 }
164
165 ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
166 int n;
167 OP_REQUIRES_OK(c, c->GetAttr("N", &n));
168 OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_));
169 OP_REQUIRES(c, n == dtypes_.size(),
170 errors::InvalidArgument(
171 "Mismatched number of arguments to ReadVariablesOp (", n,
172 " vs. ", dtypes_.size(), ")"));
173 }
174
175 void ReadVariablesOp::Compute(OpKernelContext* ctx) {
176 std::vector<core::RefCountPtr<Var>> variables(dtypes_.size());
177 std::vector<const ResourceHandle*> handles(dtypes_.size());
178 for (size_t i = 0; i < dtypes_.size(); ++i) {
179 handles[i] = &HandleFromInput(ctx, i);
180 }
181
182 OP_REQUIRES_OK(ctx, LookupResources(ctx, handles, &variables));
183
184 std::vector<string> uninitialized_vars;
185 for (int64_t i = 0; i < variables.size(); i++) {
186 if (variables[i] == nullptr) {
187 uninitialized_vars.push_back(handles[i]->name());
188 }
189 }
190
191 OP_REQUIRES(ctx, uninitialized_vars.empty(),
192 errors::FailedPrecondition(
193 "In ReadVariablesOp the following variables were "
194 "found uninitialized: ",
195 absl::StrJoin(uninitialized_vars, ", ")));
196
197 for (size_t i = 0; i < dtypes_.size(); ++i) {
198 // We're acquiring a reference to the underlying buffer while
199 // holding a shared lock to guarantee ordering of reads and
200 // writes.
201 tf_shared_lock ml(*variables[i]->mu());
202 OP_REQUIRES(ctx, dtypes_[i] == variables[i]->tensor()->dtype(),
203 errors::InvalidArgument(
204 "Trying to read variable ", handles[i]->name(),
205 " from Container: ", handles[i]->container(),
206 " with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
207 " got ", DataTypeString(variables[i]->tensor()->dtype())));
208 if (variables[i]->copy_on_read_mode.load()) {
209 OP_REQUIRES_OK(ctx, CopyVariable(i, ctx, variables[i]->tensor()));
210 } else {
211 const Tensor& t = *variables[i]->tensor();
212 ctx->set_output(i, t);
213 }
214 }
215 }
216
217 REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
218 ReadVariableOp);
219 REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU),
220 ReadVariablesOp);
221
222 REGISTER_KERNEL_BUILDER(
223 Name("ReadVariableOp").Device(DEVICE_DEFAULT).HostMemory("resource"),
224 ReadVariableOp);
225 REGISTER_KERNEL_BUILDER(
226 Name("_ReadVariablesOp").Device(DEVICE_DEFAULT).HostMemory("resources"),
227 ReadVariablesOp);
228
229 VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
230 OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
231 OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));
232
233 OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype));
234 OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape));
235
236 is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME;
237
238 // Use const_tensor_ if the variable is non-anonymous.
239 if (!is_anonymous_) {
240 AllocatorAttributes attr;
241 attr.set_on_host(true);
242 OP_REQUIRES_OK(context, context->allocate_temp(DT_RESOURCE, TensorShape({}),
243 &const_tensor_, attr));
244 const_tensor_.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
245 context, container_, name_,
246 std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
247 }
248 }
249
250 void VarHandleOp::Compute(OpKernelContext* ctx) {
251 if (is_anonymous_) {
252 Var* resource = new Var(dtype_and_shape_.dtype);
253 ResourceMgr* mgr = ctx->resource_manager();
254 ResourceHandle handle = ResourceHandle::MakeRefCountingHandle<Var>(
255 resource, ctx->device()->name(),
256 std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
257 ctx->stack_trace());
258 // TODO(b/203901837): See if we can abolish all code paths that lookup
259 // anonymous variables and then stop publishing them to the manager.
260 OP_REQUIRES_OK(ctx, mgr->CreateUnowned<Var>(handle.container(),
261 handle.name(), resource));
262
263 AllocatorAttributes attr;
264 attr.set_on_host(true);
265 Tensor tensor;
266 OP_REQUIRES_OK(
267 ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &tensor, attr));
268
269 tensor.scalar<ResourceHandle>()() = std::move(handle);
270
271 ctx->set_output(0, tensor);
272 } else {
273 ctx->set_output(0, const_tensor_);
274 }
275 }
276
277 REGISTER_KERNEL_BUILDER(Name("VarHandleOp").Device(DEVICE_CPU), VarHandleOp);
278
279 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
280 #define REGISTER_GPU_KERNELS(type) \
281 namespace functor { \
282 template <> \
283 void DenseUpdate<GPUDevice, type, ASSIGN>::operator()( \
284 const GPUDevice& d, typename TTypes<type>::Flat lhs, \
285 typename TTypes<type>::ConstFlat rhs); \
286 extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
287 }
288
289 TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
290 TF_CALL_bfloat16(REGISTER_GPU_KERNELS);
291 TF_CALL_int64(REGISTER_GPU_KERNELS);
292 TF_CALL_variant(REGISTER_GPU_KERNELS);
293 TF_CALL_uint32(REGISTER_GPU_KERNELS);
294 #undef REGISTER_GPU_KERNELS
295
296 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
297
298 #define REGISTER_DEFAULT_KERNELS(type) \
299 REGISTER_KERNEL_BUILDER(Name("VarHandleOp") \
300 .Device(DEVICE_DEFAULT) \
301 .HostMemory("resource") \
302 .TypeConstraint<type>("dtype"), \
303 VarHandleOp)
304 TF_CALL_GPU_ALL_TYPES(REGISTER_DEFAULT_KERNELS);
305 TF_CALL_bfloat16(REGISTER_DEFAULT_KERNELS);
306 TF_CALL_int64(REGISTER_DEFAULT_KERNELS);
307 TF_CALL_variant(REGISTER_DEFAULT_KERNELS);
308 TF_CALL_uint32(REGISTER_DEFAULT_KERNELS);
309 #undef REGISTER_DEFAULT_KERNELS
310
311 REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
312 .Device(DEVICE_DEFAULT)
313 .HostMemory("resources")
314 .TypeConstraint("dtypes",
315 {DT_INT64, DT_COMPLEX64,
316 DT_COMPLEX128, DT_HALF, DT_FLOAT,
317 DT_DOUBLE, DT_BOOL, DT_VARIANT}),
318 ResourceHandlesOp<Var>);
319
320 REGISTER_KERNEL_BUILDER(
321 Name("VariableShape").Device(DEVICE_CPU).TypeConstraint<int32>("out_type"),
322 VariableShapeOp<int32>);
323 REGISTER_KERNEL_BUILDER(Name("VariableShape")
324 .Device(DEVICE_CPU)
325 .TypeConstraint<int64_t>("out_type"),
326 VariableShapeOp<int64_t>);
327
328 REGISTER_KERNEL_BUILDER(Name("VariableShape")
329 .Device(DEVICE_DEFAULT)
330 .TypeConstraint<int32>("out_type")
331 .HostMemory("output")
332 .HostMemory("input"),
333 VariableShapeOp<int32>);
334 REGISTER_KERNEL_BUILDER(Name("VariableShape")
335 .Device(DEVICE_DEFAULT)
336 .TypeConstraint<int64_t>("out_type")
337 .HostMemory("output")
338 .HostMemory("input"),
339 VariableShapeOp<int64_t>);
340
341 DestroyResourceOp::DestroyResourceOp(OpKernelConstruction* ctx)
342 : OpKernel(ctx) {
343 OP_REQUIRES_OK(ctx,
344 ctx->GetAttr("ignore_lookup_error", &ignore_lookup_error_));
345 }
346
347 void DestroyResourceOp::Compute(OpKernelContext* ctx) {
348 const ResourceHandle& p = HandleFromInput(ctx, 0);
349 Status status = DeleteResource(ctx, p);
350 if (ignore_lookup_error_ && errors::IsNotFound(status)) {
351 return;
352 }
353 OP_REQUIRES_OK(ctx, status);
354 }
355
356 REGISTER_KERNEL_BUILDER(Name("DestroyResourceOp").Device(DEVICE_CPU),
357 DestroyResourceOp);
358 REGISTER_KERNEL_BUILDER(
359 Name("DestroyResourceOp").Device(DEVICE_DEFAULT).HostMemory("resource"),
360 DestroyResourceOp);
361
362 void DisableCopyOnReadOp::Compute(OpKernelContext* ctx) {
363 core::RefCountPtr<Var> variable;
364 const ResourceHandle& handle = HandleFromInput(ctx, 0);
365 const auto status = LookupResource(ctx, handle, &variable);
366 OP_REQUIRES(ctx, status.ok(),
367 errors::FailedPrecondition(
368 "Could not find variable ", handle.name(), ". ",
369 "This could mean that the variable has been deleted. ",
370 "In TF1, it can also mean the variable is uninitialized. ",
371 "Debug info: container=", handle.container(),
372 ", status error message=", status.error_message()));
373 // If the variable is currently in copy-on-read mode, its refcount is 1
374 if (variable->copy_on_read_mode.load()) {
375 // Obtain an exclusive lock on the variable and change the access mode
376 mutex_lock ml(*variable->mu());
377 variable->copy_on_read_mode.store(false);
378 }
379 }
380
381 REGISTER_KERNEL_BUILDER(Name("DisableCopyOnRead").Device(DEVICE_CPU),
382 DisableCopyOnReadOp);
383 REGISTER_KERNEL_BUILDER(
384 Name("DisableCopyOnRead").Device(DEVICE_DEFAULT).HostMemory("resource"),
385 DisableCopyOnReadOp);
386
387 template <typename Device, typename T>
388 class AssignVariableOp : public OpKernel {
389 public:
390 explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
391 OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
392 if (!c->GetAttr("_grappler_relax_allocator_constraints",
393 &relax_constraints_)
394 .ok()) {
395 relax_constraints_ = false;
396 }
397 if (c->HasAttr("validate_shape")) {
398 OP_REQUIRES_OK(c, c->GetAttr("validate_shape", &validate_shape_));
399 }
400 }
401
402 void Compute(OpKernelContext* context) override {
403 OP_REQUIRES(context, dtype_ == context->input(1).dtype(),
404 errors::InvalidArgument(
405 "Variable and value dtypes don't match; respectively, ",
406 DataTypeString(dtype_), " and ",
407 DataTypeString(context->input(1).dtype())));
408 core::RefCountPtr<Var> variable;
409 const Tensor& value = context->input(1);
410 // Note: every resource-variable-manipulating op assumes copy-on-write
411 // semantics, and creates a copy of the variable's Tensor if its refcount is
412 // bigger than 1 when we try to modify it. This means we never need to copy
413 // the original tensor for AssignVariableOp; even if there are other live
414 // users of it we know none can modify it so this is always safe (even in
415 // esoteric cases where the same tensor is used to initialize multiple
416 // variables or the tensor is a constant this is safe, as future writes will
417 // trigger copies).
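    // Illustrative consequence (not an extra code path here): if `value` also
    // backs another resource variable, this op leaves the shared buffer
    // untouched; a later mutating op (e.g. AssignAddVariableOp) then sees
    // RefCountIsOne() == false and copies the buffer before writing to it.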
418 OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
419 context, HandleFromInput(context, 0), &variable,
420 [this, &value](Var** ptr) {
421 *ptr = new Var(dtype_);
422 *(*ptr)->tensor() = value;
423 (*ptr)->is_initialized = true;
424 return OkStatus();
425 }));
426 mutex_lock ml(*variable->mu());
427 // (variable->tensor()->dtype() == DT_INVALID && !variable->is_initialized)
428 // check below is to allow an XLA specific situation wherein update can
429 // happen first by the AssignVariableOp,
430 // in which case the variable is still uninitialized.
431 // When using TF-XLA, this scenario is possible when the execution uses the
432 // 'fallback' path (which essentially invokes Tensorflow ops via
433 // partitioned_call).
434 OP_REQUIRES(context,
435 (variable->tensor()->dtype() == DT_INVALID &&
436 !variable->is_initialized) ||
437 variable->tensor()->dtype() == dtype_,
438 errors::InvalidArgument(
439 "Trying to assign variable with wrong dtype. Expected ",
440 DataTypeString(variable->tensor()->dtype()), " got ",
441 DataTypeString(dtype_)));
442 if (validate_shape_) {
443 OP_REQUIRES(
444 context,
445 (!variable->is_initialized ||
446 variable->tensor()->shape().IsSameSize(value.shape())),
447 errors::InvalidArgument(
448 "Trying to assign to variable with tensor with wrong shape."
449 " Expected ",
450 variable->tensor()->shape().DebugString(), " got ",
451 value.shape().DebugString()));
452 }
453 if (variable->copy_on_read_mode.load()) {
454 AllocatorAttributes attr;
455 attr.set_gpu_compatible(true);
456 attr.set_nic_compatible(true);
457 OP_REQUIRES_OK(context,
458 context->allocate_temp(value.dtype(), value.shape(),
459 variable->tensor(), attr));
460 functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
461 copy_functor(context->eigen_device<Device>(),
462 variable->tensor()->flat<T>(), value.flat<T>());
463 } else {
464 *variable->tensor() = value;
465 }
466 variable->is_initialized = true;
467 }
468
469 private:
470 DataType dtype_;
471 bool relax_constraints_;
472 bool validate_shape_ = false;
473 };
474
475 template <typename Device>
476 class AssignVariableOp<Device, Variant> : public OpKernel {
477 public:
478 explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
479 OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
480 OP_REQUIRES(c, dtype_ == DT_VARIANT,
481 errors::Internal("Variant kernel called with dtype: ",
482 DataTypeString(dtype_)));
483 }
484
485 void Compute(OpKernelContext* context) override {
486 const Tensor& value = context->input(1);
487 core::RefCountPtr<Var> variable;
488 OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
489 context, HandleFromInput(context, 0), &variable,
490 [](Var** ptr) {
491 // Created on host.
492 *ptr = new Var(DT_VARIANT);
493 return OkStatus();
494 }));
495
496 // For purposes of forwarding DT_VARIANT, we want the least
497 // restrictive attr; we already know the input is on host.
498 AllocatorAttributes attr;
499
500 // Copying is unnecessary if we are the last user of the value
501 // tensor; we can just adopt the input tensor's buffer instead.
502 // Note that Variant objects themselves always reside on host.
503 //
504 // We nevertheless want to signal to the runtime that the tensor
505 // should reside in memory of the associated device, as Variant
506 // tensors may be marked as sitting on either CPU or GPU. This
507 // helps to elide one or more copies.
508 std::unique_ptr<Tensor> input_alias = context->forward_input(
509 1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT,
510 value.shape(),
511 DEVICE_MEMORY /* HOST_MEMORY is only reserved for special cases */,
512 attr);
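    // forward_input returns nullptr when the runtime cannot hand us exclusive
    // ownership of `value`'s buffer (for example, other kernels still hold a
    // reference to it); in that case we fall through to the element-wise copy
    // below instead of adopting the buffer.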
513
514 mutex_lock ml(*variable->mu());
515 OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
516 errors::InvalidArgument(
517 "Trying to assign variable with wrong dtype. Expected ",
518 DataTypeString(variable->tensor()->dtype()), " got ",
519 DataTypeString(DT_VARIANT)));
520 variable->is_initialized = true;
521 *variable->tensor() = Tensor(DT_VARIANT, value.shape());
522
523 if (input_alias) {
524 *variable->tensor() = *input_alias;
525 return;
526 }
527
528 // Need to copy, but maybe we can re-use variable's buffer?
529 if (!variable->tensor()->RefCountIsOne() ||
530 !variable->tensor()->shape().IsSameSize(value.shape())) {
531 // Allocation of DT_VARIANT is always on host.
532 attr.set_on_host(true);
533 OP_REQUIRES_OK(context, context->allocate_temp(DT_VARIANT, value.shape(),
534 variable->tensor(), attr));
535 }
536
537 const auto elements_in = value.flat<Variant>();
538 auto elements_out = variable->tensor()->flat<Variant>();
539 for (int64_t i = 0; i < elements_in.size(); ++i) {
540 elements_out(i) = elements_in(i);
541 }
542 }
543
544 private:
545 DataType dtype_;
546 };
547
548 #define REGISTER_KERNELS(type) \
549 REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \
550 .Device(DEVICE_CPU) \
551 .TypeConstraint<type>("dtype"), \
552 AssignVariableOp<Eigen::ThreadPoolDevice, type>);
553
554 TF_CALL_ALL_TYPES(REGISTER_KERNELS);
555 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
556 #undef REGISTER_KERNELS
557
558 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
559 #define REGISTER_GPU_KERNELS(type) \
560 REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \
561 .Device(DEVICE_GPU) \
562 .TypeConstraint<type>("dtype") \
563 .HostMemory("resource"), \
564 AssignVariableOp<GPUDevice, type>);
565
566 TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
567 TF_CALL_bfloat16(REGISTER_GPU_KERNELS);
568 TF_CALL_int64(REGISTER_GPU_KERNELS);
569 TF_CALL_uint32(REGISTER_GPU_KERNELS);
570 #undef REGISTER_GPU_KERNELS
571 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
572
573 #define REGISTER_KERNELS(type) \
574 REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \
575 .Device(DEVICE_DEFAULT) \
576 .TypeConstraint<type>("dtype") \
577 .HostMemory("resource"), \
578 AssignVariableOp<CPUDevice, type>);
579
580 TF_CALL_ALL_TYPES(REGISTER_KERNELS);
581 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
582 #undef REGISTER_KERNELS
583
584 template <typename Device, typename T, DenseUpdateType Op>
585 class AssignUpdateVariableOp : public OpKernel {
586 public:
587 explicit AssignUpdateVariableOp(OpKernelConstruction* c) : OpKernel(c) {}
588
589 void Compute(OpKernelContext* context) override {
590 core::RefCountPtr<Var> variable;
591 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
592 &variable));
593
594 const Tensor& value = context->input(1);
595 // TODO(apassos): We could possibly avoid the copy done by
596 // PrepareToUpdateVariable() for commutative operations like Op ==
597 // ADD if value's refcount was 1.
598 mutex_lock ml(*variable->mu());
599 Tensor* var_tensor = variable->tensor();
600 OP_REQUIRES_OK(context, ValidateAssignUpdateVariableOpShapes(
601 var_tensor->shape(), value.shape()));
602 OP_REQUIRES_OK(
603 context, PrepareToUpdateVariable<Device, T>(
604 context, var_tensor, variable->copy_on_read_mode.load()));
605 functor::DenseUpdate<Device, T, Op> update_functor;
606 update_functor(context->eigen_device<Device>(), var_tensor->flat<T>(),
607 value.flat<T>());
608 }
609 };
610
611 #define REGISTER_KERNELS(type) \
612 REGISTER_KERNEL_BUILDER( \
613 Name("AssignAddVariableOp") \
614 .Device(DEVICE_CPU) \
615 .TypeConstraint<type>("dtype"), \
616 AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, ADD>); \
617 REGISTER_KERNEL_BUILDER( \
618 Name("AssignSubVariableOp") \
619 .Device(DEVICE_CPU) \
620 .TypeConstraint<type>("dtype"), \
621 AssignUpdateVariableOp<Eigen::ThreadPoolDevice, type, SUB>);
622
623 TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
624 #undef REGISTER_KERNELS
625
626 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
627 #define REGISTER_GPU_KERNELS(type) \
628 REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp") \
629 .Device(DEVICE_GPU) \
630 .HostMemory("resource") \
631 .TypeConstraint<type>("dtype"), \
632 AssignUpdateVariableOp<GPUDevice, type, ADD>); \
633 REGISTER_KERNEL_BUILDER(Name("AssignSubVariableOp") \
634 .Device(DEVICE_GPU) \
635 .HostMemory("resource") \
636 .TypeConstraint<type>("dtype"), \
637 AssignUpdateVariableOp<GPUDevice, type, SUB>);
638
639 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
640 TF_CALL_int64(REGISTER_GPU_KERNELS);
641 #undef REGISTER_GPU_KERNELS
642 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
643
644 class VarIsInitializedOp : public OpKernel {
645 public:
646 explicit VarIsInitializedOp(OpKernelConstruction* c) : OpKernel(c) {}
647
648 void Compute(OpKernelContext* context) override {
649 Tensor* output = nullptr;
650 OP_REQUIRES_OK(context,
651 context->allocate_output(0, TensorShape({}), &output));
652 auto output_tensor = output->tensor<bool, 0>();
653 core::RefCountPtr<Var> variable;
654 Status s = LookupResource(context, HandleFromInput(context, 0), &variable);
655 if (!s.ok()) {
656 output_tensor() = false;
657 return;
658 }
659 mutex_lock ml(*variable->mu());
660 output_tensor() = variable->is_initialized;
661 }
662 };
663
664 REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp").Device(DEVICE_CPU),
665 VarIsInitializedOp);
666
667 REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp")
668 .Device(DEVICE_DEFAULT)
669 .HostMemory("resource")
670 .HostMemory("is_initialized"),
671 VarIsInitializedOp);
672
673 template <typename Device, typename T, typename Index>
674 class ResourceGatherOp : public OpKernel {
675 public:
676 explicit ResourceGatherOp(OpKernelConstruction* c) : OpKernel(c) {
677 OP_REQUIRES_OK(c, c->GetAttr("batch_dims", &batch_dims_));
678 }
679
680 void Compute(OpKernelContext* c) override {
681 core::RefCountPtr<Var> v;
682 OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
683 OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
684 // NOTE: We hold the lock for the whole gather operation instead
685 // of increasing the reference count of v->tensor() to avoid a
686 // situation where a write to the same variable will see a
687 // reference count greater than one and make a copy of the
688 // (potentially very large) tensor buffer.
689 tf_shared_lock ml(*v->mu());
690 const Tensor& params = *v->tensor();
691 const Tensor& indices = c->input(1);
692 OP_REQUIRES(
693 c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
694 errors::InvalidArgument("params must be at least 1 dimensional"));
695 OP_REQUIRES(
696 c, params.shape().dims() >= batch_dims_,
697 errors::InvalidArgument("params must have at least ", batch_dims_,
698 " (batch_dims) dimensions but it has shape ",
699 params.shape().DebugString()));
700
701 // Check that we have enough index space
702 const int64_t N = indices.NumElements();
703 OP_REQUIRES(
704 c, params.dim_size(0) <= std::numeric_limits<Index>::max(),
705 errors::InvalidArgument("params.shape[0] too large for ",
706 DataTypeString(DataTypeToEnum<Index>::v()),
707 " indexing: ", params.dim_size(0), " > ",
708 std::numeric_limits<Index>::max()));
709
710 // The result shape is params.shape[:batch_dims] +
711 // indices.shape[batch_dims:] + params.shape[batch_dims+1:].
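// Worked example (illustrative): with batch_dims = 1,
// params.shape = [B, P0, P1] and indices.shape = [B, N], the loops below
// produce result_shape = [B, N, P1]: dims [0, batch_dims) come from params,
// dims [batch_dims, indices.dims()) come from indices, and the remaining
// dims come from params past the gathered axis.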
712 TensorShape result_shape;
713 for (int i = 0; i < batch_dims_; ++i) {
714 result_shape.AddDim(params.dim_size(i));
715 }
716 for (int i = batch_dims_; i < indices.dims(); ++i) {
717 result_shape.AddDim(indices.dim_size(i));
718 }
719 for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
720 result_shape.AddDim(params.dim_size(i));
721 }
722
723 Tensor* out = nullptr;
724 Tensor tmp;
725 if (params.dtype() == DT_VARIANT) {
726 tmp = Tensor(DT_VARIANT, result_shape);
727 c->set_output(0, tmp);
728 out = &tmp;
729 } else {
730 OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
731 }
732
733 if (N > 0) {
734 Tensor tmp_indices;
735
736 // Points to the original or updated (if batch_dims is set) indices.
737 const Tensor* op_indices = &indices;
738 if (batch_dims_ > 0) {
739 OP_REQUIRES_OK(c, c->allocate_temp(indices.dtype(), indices.shape(),
740 &tmp_indices));
741 functor::DenseUpdate<Device, Index, ASSIGN> copy_functor;
742 copy_functor(c->eigen_device<Device>(), tmp_indices.flat<Index>(),
743 indices.flat<Index>());
744
745 AddBatchOffsets(c, &tmp_indices, params);
746 if (!c->status().ok()) return;
747 op_indices = &tmp_indices;
748 }
749
750 int64_t gather_dim_size = 1;
751 for (int idx = 0; idx <= batch_dims_; ++idx) {
752 gather_dim_size *= params.dim_size(idx);
753 }
754 int64_t inner_size = 1;
755 for (int i = batch_dims_ + 1; i < params.dims(); ++i) {
756 inner_size *= params.dim_size(i);
757 }
758 auto params_flat = params.shaped<T, 3>({1, gather_dim_size, inner_size});
759 const auto indices_flat = op_indices->flat<Index>();
760 auto out_flat = out->shaped<T, 3>({1, N, out->NumElements() / N});
761
762 functor::GatherFunctor<Device, T, Index> functor;
763 int64_t bad_i = functor(c, params_flat, indices_flat, out_flat);
764
765 OP_REQUIRES(
766 c, bad_i < 0,
767 errors::InvalidArgument(
768 "indices", SliceDebugString(indices.shape(), bad_i), " = ",
769 indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
770 }
771 }
772
773 private:
774 // Add the batch offset derived from params to each batch of indices.
775 // Example: batch_dims = 1, indices = [[0, 1, 2], [0, 1, 2]].
776 // If the indexed params dimension (params.dim_size(1)) has size 4, the
777 // indices become [0, 1, 2, 4, 5, 6].
778 void AddBatchOffsets(OpKernelContext* ctx, Tensor* indices,
779 const Tensor& params) {
780 int64_t batch_size = 1; // The size of all batch dimensions.
781 for (int idx = 0; idx < batch_dims_; ++idx) {
782 batch_size *= params.dim_size(idx);
783 }
784 OP_REQUIRES(
785 ctx, batch_size != 0,
786 errors::InvalidArgument(
787 "Inner size of indices would result in batch_size of 0 and a ",
788 "division by 0 in the implementation. This is illegal"));
789
790 auto indices_flat = indices->flat<Index>();
791 int64_t const index_inner_size = indices->NumElements() / batch_size;
792 int64_t const batch_offset = params.dim_size(batch_dims_);
793 for (int64_t batch_idx = 0, dest_idx = 0; batch_idx < batch_size;
794 ++batch_idx) {
795 for (int64_t idx = 0; idx < index_inner_size; ++idx) {
796 indices_flat(dest_idx++) += batch_offset * batch_idx;
797 }
798 }
799 }
800
801 int32 batch_dims_ = 0;
802 };
803
804 #define REGISTER_GATHER_FULL(dev, type, index_type) \
805 REGISTER_KERNEL_BUILDER(Name("ResourceGather") \
806 .Device(DEVICE_##dev) \
807 .HostMemory("resource") \
808 .TypeConstraint<type>("dtype") \
809 .TypeConstraint<index_type>("Tindices"), \
810 ResourceGatherOp<dev##Device, type, index_type>)
811
812 #define REGISTER_GATHER_ALL_INDICES(dev, type) \
813 REGISTER_GATHER_FULL(dev, type, int32); \
814 REGISTER_GATHER_FULL(dev, type, int64_t)
815
816 #define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)
817
818 // Registration of the CPU implementations.
819 TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
820 TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
821
822 // Registers GPU kernels.
823 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
824 #define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type)
825
826 TF_CALL_int64(REGISTER_GATHER_GPU);
827 TF_CALL_GPU_ALL_TYPES(REGISTER_GATHER_GPU);
828
829 // Variant objects themselves sit on CPU, even if they contain data
830 // pointing to a device.
831 REGISTER_KERNEL_BUILDER(Name("ResourceGather")
832 .Device(DEVICE_DEFAULT)
833 .HostMemory("resource")
834 .HostMemory("indices")
835 .TypeConstraint<Variant>("dtype")
836 .TypeConstraint<int32>("Tindices"),
837 ResourceGatherOp<CPUDevice, Variant, int32>)
838 REGISTER_KERNEL_BUILDER(Name("ResourceGather")
839 .Device(DEVICE_DEFAULT)
840 .HostMemory("resource")
841 .HostMemory("indices")
842 .TypeConstraint<Variant>("dtype")
843 .TypeConstraint<int64_t>("Tindices"),
844 ResourceGatherOp<CPUDevice, Variant, int64>)
845
846 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
847
848 #undef REGISTER_GATHER_CPU
849 #undef REGISTER_GATHER_GPU
850 #undef REGISTER_GATHER_ALL_INDICES
851 #undef REGISTER_GATHER_FULL
852
853 template <typename Device, typename T, typename Index>
854 class ResourceGatherNdOp : public OpKernel {
855 public:
856 explicit ResourceGatherNdOp(OpKernelConstruction* c) : OpKernel(c) {}
857
858 void Compute(OpKernelContext* c) override {
859 core::RefCountPtr<Var> v;
860 OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
861 OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
862 // NOTE: We hold the lock for the whole gather operation instead
863 // of increasing the reference count of v->tensor() to avoid a
864 // situation where a write to the same variable will see a
865 // reference count greater than one and make a copy of the
866 // (potentially very large) tensor buffer.
867 tf_shared_lock ml(*v->mu());
868 const Tensor& params = *v->tensor();
869 const Tensor& indices = c->input(1);
870
871 Tensor out;
872 OP_REQUIRES_OK(
873 c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out));
874 c->set_output(0, out);
875 }
876 };
877
878 #define REGISTER_GATHER_ND_FULL(dev, type, index_type) \
879 REGISTER_KERNEL_BUILDER(Name("ResourceGatherNd") \
880 .Device(DEVICE_##dev) \
881 .HostMemory("resource") \
882 .TypeConstraint<type>("dtype") \
883 .TypeConstraint<index_type>("Tindices"), \
884 ResourceGatherNdOp<dev##Device, type, index_type>)
885
886 #define REGISTER_GATHER_ND_ALL_INDICES(dev, type) \
887 REGISTER_GATHER_ND_FULL(dev, type, int32); \
888 REGISTER_GATHER_ND_FULL(dev, type, int64_t)
889
890 #define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type)
891
892 // Registration of the CPU implementations.
893 TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
894
895 // Registers GPU kernels.
896 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
897 #define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type)
898
899 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);
900
901 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
902
903 #undef REGISTER_GATHER_ND_CPU
904 #undef REGISTER_GATHER_ND_GPU
905 #undef REGISTER_GATHER_ND_ALL_INDICES
906 #undef REGISTER_GATHER_ND_FULL
907
908 namespace {
909
910 template <typename Device>
911 bool isCPUDevice() {
912 return false;
913 }
914
915 template <>
916 bool isCPUDevice<CPUDevice>() {
917 return true;
918 }
919
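// Returns false if any element of `updates` equals T{} (i.e. zero). Used by
// ResourceScatterUpdateOp below to reject division by zero for
// ResourceScatterDiv on CPU; the Variant specialization always passes.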
920 template <typename T>
921 bool ValidateInput(const Tensor& updates) {
922 const auto updates_flat = updates.flat<T>();
923 for (int i = 0; i < updates.NumElements(); ++i) {
924 if (updates_flat(i) == T{}) return false;
925 }
926 return true;
927 }
928
929 template <>
930 bool ValidateInput<Variant>(const Tensor& updates) {
931 return true;
932 }
933
934 template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
935 Status DoScatter(OpKernelContext* c, Tensor* params, const Tensor& indices,
936 const Tensor& updates, Index num_indices);
937
938 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
939
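// Copies `device_tensor` (resident on the accelerator) into a freshly
// allocated, GPU-compatible host tensor via the op's stream. The memcpy is
// enqueued asynchronously; callers must synchronize on the stream (as
// DoScatterOnCpu does with BlockHostUntilDone) before touching the host data.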
940 template <typename T>
941 Status CopyTensorToHost(OpKernelContext* c, const Tensor& device_tensor,
942 Tensor* host_tensor) {
943 AllocatorAttributes alloc_attr;
944 alloc_attr.set_on_host(true);
945 alloc_attr.set_gpu_compatible(true);
946 auto stream = c->op_device_context()->stream();
947 TF_RETURN_IF_ERROR(c->allocate_temp(
948 device_tensor.dtype(), device_tensor.shape(), host_tensor, alloc_attr));
949 se::DeviceMemoryBase device_ptr(
950 const_cast<Tensor&>(device_tensor).flat<T>().data(),
951 device_tensor.flat<T>().size() * sizeof(T));
952 stream->ThenMemcpy(host_tensor->flat<T>().data(), device_ptr,
953 device_tensor.NumElements() * sizeof(T));
954 if (!stream) {
955 return errors::Internal("Failed to copy indices to host");
956 }
957 return OkStatus();
958 }
959
960 // Copies inputs to the CPU, runs DoScatter on the CPU, then copies output
961 // back to GPU. This is useful because the CPU implementation is deterministic
962 // and the GPU implementation is not. Tensor inputs to this function must be on
963 // the GPU.
964 template <typename T, typename Index, scatter_op::UpdateOp Op>
965 Status DoScatterOnCpu(OpKernelContext* c, Tensor* params, const Tensor& indices,
966 const Tensor& updates, Index num_indices) {
967 auto stream = c->op_device_context()->stream();
968
969 Tensor host_indices;
970 TF_RETURN_IF_ERROR(CopyTensorToHost<Index>(c, indices, &host_indices));
971 Tensor host_updates;
972 TF_RETURN_IF_ERROR(CopyTensorToHost<T>(c, updates, &host_updates));
973 Tensor host_params;
974 TF_RETURN_IF_ERROR(CopyTensorToHost<T>(c, *params, &host_params));
975
976 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
977 TF_RETURN_IF_ERROR(DoScatter<CPUDevice, T, Index, Op>(
978 c, &host_params, host_indices, host_updates, num_indices));
979
980 // Copy 'host_params' to device.
981 se::DeviceMemoryBase params_ptr(params->flat<T>().data(),
982 params->flat<T>().size() * sizeof(T));
983 stream->ThenMemcpy(¶ms_ptr, host_params.flat<T>().data(),
984 host_params.NumElements() * sizeof(T));
985 if (!stream) {
986 return errors::Internal("Failed to copy params to device");
987 }
988 // Deallocate host_params' buffer once the host-to-device copy is complete.
989 // host_params is captured by value in the lambda so that its buffer is only
990 // destructed once the lambda is destructed.
991 c->device()->tensorflow_accelerator_device_info()->event_mgr->ThenExecute(
992 stream, [host_params] {});
993 return OkStatus();
994 }
995
996 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
997
998 template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
999 Status DoScatter(OpKernelContext* c, Tensor* params, const Tensor& indices,
1000 const Tensor& updates, Index num_indices) {
1001 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1002 if (std::is_same<Device, GPUDevice>::value &&
1003 tensorflow::OpDeterminismRequired()) {
1004 if (!DataTypeCanUseMemcpy(params->dtype())) {
1005 return errors::Unimplemented(
1006 "GPU Scatter ops for dtype ", DataTypeString(params->dtype()),
1007 " do not yet have a deterministic implementation");
1008 }
1009 return DoScatterOnCpu<T, Index, op>(c, params, indices, updates,
1010 num_indices);
1011 }
1012 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1013 auto indices_flat = indices.flat<Index>();
1014 auto params_flat = params->flat_outer_dims<T>();
1015 int64_t num_updates = updates.NumElements();
1016 auto updates_flat =
1017 updates.shaped<T, 2>({num_indices, num_updates / num_indices});
1018 functor::ScatterFunctor<Device, T, Index, op> functor;
1019 const Index bad_i = functor(c, c->template eigen_device<Device>(),
1020 params_flat, updates_flat, indices_flat);
1021 if (bad_i >= 0) {
1022 return errors::InvalidArgument(
1023 "indices", SliceDebugString(indices.shape(), bad_i), " = ",
1024 indices_flat(bad_i), " is not in [0, ", params->dim_size(0), ")");
1025 }
1026 return OkStatus();
1027 }
1028
1029 } // namespace
1030
1031 template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
1032 class ResourceScatterUpdateOp : public OpKernel {
1033 public:
1034 explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
1035 // We use the same kernel for many operations.
1036 // Each operation has a different set of attributes defined in its nodes.
1037 Status s = c->GetAttr("use_locking", &use_exclusive_lock_);
1038 if (!s.ok()) {
1039 use_exclusive_lock_ = false;
1040 }
1041 }
1042
1043 void Compute(OpKernelContext* c) override {
1044 core::RefCountPtr<Var> v;
1045 OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
1046 OP_REQUIRES_OK(c, EnsureSparseVariableAccess<Device, T>(c, v.get()));
1047 const bool is_non_pod_dtype = c->input_dtype(0) == DT_RESOURCE ||
1048 c->input_dtype(0) == DT_STRING ||
1049 c->input_dtype(0) == DT_VARIANT;
1050 if (is_non_pod_dtype || use_exclusive_lock_) {
1051 mutex_lock ml(*v->mu());
1052 DoCompute(c);
1053 } else {
1054 // For POD dtypes, we can safely run the update without the mutex.
1055 tf_shared_lock ml(*v->mu());
1056 DoCompute(c);
1057 }
1058 }
1059
1060 private:
1061 bool use_exclusive_lock_;
1062
1063 void DoCompute(OpKernelContext* c) {
1064 core::RefCountPtr<Var> v;
1065 OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
1066 Tensor* params = v->tensor();
1067 const Tensor& indices = c->input(1);
1068 const Tensor& updates = c->input(2);
1069
1070 // Check that rank(updates.shape) = rank(indices.shape + params.shape[1:])
1071 OP_REQUIRES(c,
1072 updates.dims() == 0 ||
1073 updates.dims() == indices.dims() + params->dims() - 1,
1074 errors::InvalidArgument(
1075 "Must have updates.shape = indices.shape + "
1076 "params.shape[1:] or updates.shape = [], got ",
1077 "updates.shape ", updates.shape().DebugString(),
1078 ", indices.shape ", indices.shape().DebugString(),
1079 ", params.shape ", params->shape().DebugString()));
1080
1081 // Check that we have enough index space
1082 const int64_t N_big = indices.NumElements();
1083 OP_REQUIRES(
1084 c, N_big <= std::numeric_limits<Index>::max(),
1085 errors::InvalidArgument("indices has too many elements for ",
1086 DataTypeString(DataTypeToEnum<Index>::v()),
1087 " indexing: ", N_big, " > ",
1088 std::numeric_limits<Index>::max()));
1089 const Index N = static_cast<Index>(N_big);
1090 OP_REQUIRES(
1091 c, params->dim_size(0) <= std::numeric_limits<Index>::max(),
1092 errors::InvalidArgument("params.shape[0] too large for ",
1093 DataTypeString(DataTypeToEnum<Index>::v()),
1094 " indexing: ", params->dim_size(0), " > ",
1095 std::numeric_limits<Index>::max()));
1096
1097 // Prevent division by 0
1098 if (isCPUDevice<Device>() && op == tensorflow::scatter_op::UpdateOp::DIV) {
1099 OP_REQUIRES(c, ValidateInput<T>(updates),
1100 errors::InvalidArgument("updates must not contain 0"));
1101 }
1102
1103 if (N > 0) {
1104 auto indices_flat = indices.flat<Index>();
1105 auto params_flat = params->flat_outer_dims<T>();
1106 if (TensorShapeUtils::IsScalar(updates.shape())) {
1107 const auto update = updates.scalar<T>();
1108
1109 functor::ScatterScalarFunctor<Device, T, Index, op> functor;
1110 const Index bad_i = functor(c, c->template eigen_device<Device>(),
1111 params_flat, update, indices_flat);
1112 OP_REQUIRES(c, bad_i < 0,
1113 errors::InvalidArgument(
1114 "indices", SliceDebugString(indices.shape(), bad_i),
1115 " = ", indices_flat(bad_i), " is not in [0, ",
1116 params->dim_size(0), ")"));
1117 } else {
1118 OP_REQUIRES(
1119 c, TensorShapeUtils::StartsWith(updates.shape(), indices.shape()),
1120 errors::InvalidArgument(
1121 "The shape of indices (", indices.shape().DebugString(),
1122 ") must be a prefix of the shape of updates (",
1123 updates.shape().DebugString(), ")"));
1124 OP_REQUIRES_OK(
1125 c, DoScatter<Device, T, Index, op>(c, params, indices, updates, N));
1126 }
1127 }
1128 }
1129 };
1130
1131 #define REGISTER_SCATTER_KERNEL_INDEX(type, index_type, dev, name, op) \
1132 REGISTER_KERNEL_BUILDER( \
1133 Name(name) \
1134 .Device(DEVICE_##dev) \
1135 .HostMemory("resource") \
1136 .TypeConstraint<type>("dtype") \
1137 .TypeConstraint<index_type>("Tindices"), \
1138 ResourceScatterUpdateOp<dev##Device, type, index_type, op>)
1139
1140 #define REGISTER_SCATTER_KERNEL(type, dev, name, op) \
1141 REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
1142 REGISTER_SCATTER_KERNEL_INDEX(type, int64_t, dev, name, op);
1143
1144 #define REGISTER_SCATTER_ARITHMETIC(type, dev) \
1145 REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd", \
1146 scatter_op::UpdateOp::ADD); \
1147 REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub", \
1148 scatter_op::UpdateOp::SUB); \
1149 REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul", \
1150 scatter_op::UpdateOp::MUL); \
1151 REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv", \
1152 scatter_op::UpdateOp::DIV); \
1153 REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
1154 scatter_op::UpdateOp::ASSIGN);
1155 #define REGISTER_SCATTER_MINMAX(type, dev) \
1156 REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \
1157 scatter_op::UpdateOp::MIN); \
1158 REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \
1159 scatter_op::UpdateOp::MAX);
1160
1161 // Registers CPU kernels.
1162 #define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
1163 REGISTER_SCATTER_ARITHMETIC(type, CPU);
1164 #define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);
1165
1166 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
1167 TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
1168
1169 REGISTER_SCATTER_KERNEL(tstring, CPU, "ResourceScatterUpdate",
1170 scatter_op::UpdateOp::ASSIGN);
1171 REGISTER_SCATTER_KERNEL(bool, CPU, "ResourceScatterUpdate",
1172 scatter_op::UpdateOp::ASSIGN);
1173 REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
1174 scatter_op::UpdateOp::ASSIGN);
1175
1176 // Registers GPU kernels.
1177 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1178 #define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
1179 REGISTER_SCATTER_ARITHMETIC(type, GPU);
1180 #define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);
1181
1182 #define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
1183
1184 TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_GPU);
1185 TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_GPU);
1186
1187 REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
1188 .Device(DEVICE_DEFAULT)
1189 .HostMemory("resource")
1190 .HostMemory("indices")
1191 .TypeConstraint<Variant>("dtype")
1192 .TypeConstraint<int32>("Tindices"),
1193 ResourceScatterUpdateOp<CPUDevice, Variant, int32,
1194 scatter_op::UpdateOp::ASSIGN>)
1195 REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
1196 .Device(DEVICE_GPU)
1197 .HostMemory("resource")
1198 .TypeConstraint<bool>("dtype")
1199 .TypeConstraint<int32>("Tindices"),
1200 ResourceScatterUpdateOp<GPUDevice, bool, int32,
1201 scatter_op::UpdateOp::ASSIGN>)
1202 REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
1203 .Device(DEVICE_DEFAULT)
1204 .HostMemory("resource")
1205 .HostMemory("indices")
1206 .TypeConstraint<Variant>("dtype")
1207 .TypeConstraint<int64_t>("Tindices"),
1208 ResourceScatterUpdateOp<CPUDevice, Variant, int64,
1209 scatter_op::UpdateOp::ASSIGN>)
1210 REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
1211 .Device(DEVICE_GPU)
1212 .HostMemory("resource")
1213 .TypeConstraint<int64_t>("dtype")
1214 .TypeConstraint<int64_t>("Tindices"),
1215 ResourceScatterUpdateOp<GPUDevice, int64, int64,
1216 scatter_op::UpdateOp::ASSIGN>)
1217
1218 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1219
1220 #undef REGISTER_SCATTER_ARITHMETIC
1221 #undef REGISTER_SCATTER_ARITHMETIC_CPU
1222 #undef REGISTER_SCATTER_MINMAX
1223 #undef REGISTER_SCATTER_MINMAX_CPU
1224 #undef REGISTER_SCATTER_KERNEL
1225 #undef REGISTER_SCATTER_KERNEL_INDEX
1226
1227 } // namespace tensorflow
1228