/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_H_
#define TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_H_

#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/typed_conditional_accumulator_base.h"

namespace tensorflow {

/**
 * An aggregation object for adding dense gradients.
 *
 * The two main methods of this class are TryApplyGrad and TryTakeGrad.
 *
 * TryApplyGrad tries to add a gradient to the accumulator. The attempt is
 * successful if local_step >= global_step, i.e., if the gradient is not
 * stale, having been computed using up-to-date information. Otherwise, the
 * gradient is silently dropped.
 *
 * TryTakeGrad logs an attempt to read the average gradient. The attempt is
 * blocked until the number of gradients accumulated (via TryApplyGrad) equals
 * or exceeds the number requested by TryTakeGrad.
 * Once this condition is satisfied, the following actions are taken:
 * (1) the value of the average gradient is returned
 * (2) the count of accumulated gradients is reset to 0
 * (3) the internal global_step value (current_global_step_) is incremented
 *     by 1
 *
 * ConditionalAccumulator is the datatype-dependent templated subclass of
 * ConditionalAccumulatorBase. It implements the virtual arithmetic methods
 * used for aggregating, averaging, allocating, and returning dense Tensors.
 */
template <typename Device, typename T>
class ConditionalAccumulator
    : public TypedConditionalAccumulatorBase<const Tensor> {
 public:
  // Args:
  //   dtype: The datatype of the gradients to be accumulated.
  //   shape: The shape of the accumulated gradients.
  //   name: A name to use for the ConditionalAccumulator.
  //   reduction_type: The reduction type, i.e., MEAN or SUM.
  ConditionalAccumulator(const DataType& dtype,
                         const PartialTensorShape& shape, const string& name,
                         const string& reduction_type)
      : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name,
                                                      reduction_type) {}
  ~ConditionalAccumulator() override {}

 protected:
  // accum_grad_ is the tensor that holds the aggregate gradient.
  // It is initialized the first time ApplyGrad is called.
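  // Its shape is fixed by the first gradient applied; ValidateShape() below
  // rejects any later gradient whose shape does not match it.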
  Tensor accum_grad_;

  functor::SetZeroFunctor<Device, T> set_zero_functor_;

  Status ValidateShape(const Tensor* tensor)
      TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
    // Must be compatible with accumulated gradient if available
    if (counter_ > 0) {
      if (!accum_grad_.shape().IsSameSize(tensor->shape())) {
        return errors::InvalidArgument("Shape mismatch: expected ",
                                       accum_grad_.shape().DebugString(),
                                       ", got ",
                                       tensor->shape().DebugString());
      }
    }
    // Must also be compatible with given shape
    if (!shape_.IsCompatibleWith(tensor->shape())) {
      return errors::InvalidArgument("Shape mismatch: expected ",
                                     shape_.DebugString(), ", got ",
                                     tensor->shape().DebugString());
    }
    return OkStatus();
  }

  void AllocateAndAssignToAccumGradFunction(OpKernelContext* ctx,
                                            const Tensor* grad) override {
    // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
    ctx->allocate_temp(dtype_, grad->shape(), &accum_grad_).IgnoreError();
    accum_grad_.flat<T>().device(ctx->template eigen_device<Device>()) =
        grad->flat<T>();
  }

  void AddToAccumGradFunction(OpKernelContext* ctx,
                              const Tensor* grad) override {
    accum_grad_.flat<T>().device(ctx->template eigen_device<Device>()) +=
        grad->flat<T>();
  }

  void DivideAccumGradByCounter(OpKernelContext* ctx) override
      TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
    Tensor c(DataTypeToEnum<T>::value, {});
    c.scalar<T>()() = TypeConverter<T, int>::ConvertUToT(this->counter_);
    this->accum_grad_.template flat<T>().device(
        ctx->template eigen_device<Device>()) =
        this->accum_grad_.template flat<T>() / c.scalar<T>()();
  }

  bool SetOutput(OpKernelContext* ctx) override {
    ctx->set_output(0, accum_grad_);
    return true;
  }

  bool GetAndValidateTensorInputForApplyGrad(OpKernelContext* ctx,
                                             const Tensor** tensor) override
      TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
    // Get input gradient tensor
    const Tensor* grad_tensor;
    OP_REQUIRES_OK_BOOLEAN(ctx, ctx->input("gradient", &grad_tensor));
    *tensor = grad_tensor;
    OP_REQUIRES_OK_BOOLEAN(ctx, this->ValidateShape(*tensor));
    return true;
  }

  void CleanUpGradTensor(const Tensor* tensor) override {
    // do nothing
  }

  TF_DISALLOW_COPY_AND_ASSIGN(ConditionalAccumulator);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_H_
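// ----------------------------------------------------------------------------
// Illustrative sketch only, kept inside a comment so the header itself is
// unchanged for the build: a minimal, single-threaded analogue of the
// TryApplyGrad / TryTakeGrad contract documented above, with
// std::vector<float> standing in for Tensor and MEAN reduction hard-coded.
// MiniAccumulator and all of its members are hypothetical names, not
// TensorFlow API; a real caller of TryTakeGrad blocks rather than polling.
//
//   #include <cstddef>
//   #include <cstdint>
//   #include <vector>
//
//   class MiniAccumulator {
//    public:
//     // Analogue of TryApplyGrad: stale gradients (local_step behind the
//     // internal global step) are silently dropped; fresh ones are summed
//     // into the running aggregate.
//     void ApplyGrad(int64_t local_step, const std::vector<float>& grad) {
//       if (local_step < global_step_) return;  // stale: drop silently
//       if (counter_ == 0) accum_.assign(grad.size(), 0.0f);
//       for (std::size_t i = 0; i < grad.size(); ++i) accum_[i] += grad[i];
//       ++counter_;
//     }
//
//     // Analogue of TryTakeGrad once num_required gradients have arrived:
//     // (1) return the average, (2) reset the counter to 0, and
//     // (3) increment the internal global step by 1.
//     bool TakeGrad(int num_required, std::vector<float>* out) {
//       if (counter_ < num_required) return false;  // caller would block
//       *out = accum_;
//       for (float& v : *out) v /= static_cast<float>(counter_);
//       counter_ = 0;
//       ++global_step_;
//       return true;
//     }
//
//    private:
//     std::vector<float> accum_;
//     int counter_ = 0;
//     int64_t global_step_ = 0;
//   };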