/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_H_
#define TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_H_

#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/typed_conditional_accumulator_base.h"

namespace tensorflow {

/**
 * An aggregation object for adding dense gradients.
 *
 * The two main methods of this class are TryApplyGrad and TryTakeGrad.
 *
 * TryApplyGrad tries to add a gradient to the accumulator. The attempt is
 * successful if local_step >= global_step, i.e., if the gradient is not stale,
 * having been computed using up-to-date information. Otherwise, the gradient
 * is silently dropped.
 *
 * TryTakeGrad logs an attempt to read the average gradient. The attempt is
 * blocked until the number of gradients accumulated (via TryApplyGrad) equals
 * or exceeds the number requested by TryTakeGrad.
 * Once this condition is satisfied, the following actions are taken:
 * (1) the value of the average gradient is returned
 * (2) the count of accumulated gradients is reset to 0
 * (3) the internal global_step value (current_global_step_) is incremented by 1
 *
 * ConditionalAccumulator is the datatype-dependent templated subclass of
 * ConditionalAccumulatorBase. It implements the virtual arithmetic methods
 * used for aggregating, averaging, allocating, and returning dense Tensors.
 */
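//
// Illustrative usage sketch only: the names `ctx`, `shape`, `local_step`,
// `num_required`, and `callback` are assumptions for this example, CPUDevice
// is the usual Eigen device alias used in TF kernels, and the
// TryApplyGrad/TryTakeGrad entry points are inherited from
// ConditionalAccumulatorBase; the real wiring lives in the accumulator op
// kernels.
//
//   ConditionalAccumulator<CPUDevice, float> acc(DT_FLOAT, shape, "acc",
//                                                "MEAN");
//   acc.TryApplyGrad(local_step, ctx);  // dropped silently if stale
//   acc.TryTakeGrad(num_required, ctx,
//                   callback);          // callback runs once num_required
//                                       // gradients have been applied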
template <typename Device, typename T>
class ConditionalAccumulator
    : public TypedConditionalAccumulatorBase<const Tensor> {
 public:
  // Args:
  //   dtype: The datatype of the gradients to be accumulated.
  //   shape: The shape of the accumulated gradients.
  //   name:  A name to use for the ConditionalAccumulator.
  //   reduction_type: The reduction type, i.e., MEAN or SUM
  ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape,
                         const string& name, const string& reduction_type)
      : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name,
                                                      reduction_type) {}
  ~ConditionalAccumulator() override {}

 protected:
  // accum_grad is the tensor that holds the aggregate gradient.
  // It is initialized the first time ApplyGrad is called.
  Tensor accum_grad_;

  functor::SetZeroFunctor<Device, T> set_zero_functor_;

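  // Checks that an incoming gradient exactly matches the shape of the
  // gradient already accumulated (if any) and is compatible with the
  // (possibly partial) shape supplied at construction.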
  Status ValidateShape(const Tensor* tensor)
      TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
    // Must be compatible with accumulated gradient if available
    if (counter_ > 0) {
      if (!accum_grad_.shape().IsSameSize(tensor->shape())) {
        return errors::InvalidArgument("Shape mismatch: expected ",
                                       accum_grad_.shape().DebugString(),
                                       ", got ", tensor->shape().DebugString());
      }
    }
    // Must also be compatible with given shape
    if (!shape_.IsCompatibleWith(tensor->shape())) {
      return errors::InvalidArgument("Shape mismatch: expected ",
                                     shape_.DebugString(), ", got ",
                                     tensor->shape().DebugString());
    }
    return OkStatus();
  }

  void AllocateAndAssignToAccumGradFunction(OpKernelContext* ctx,
                                            const Tensor* grad) override {
    // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
    ctx->allocate_temp(dtype_, grad->shape(), &accum_grad_).IgnoreError();
    accum_grad_.flat<T>().device(ctx->template eigen_device<Device>()) =
        grad->flat<T>();
  }

  void AddToAccumGradFunction(OpKernelContext* ctx,
                              const Tensor* grad) override {
    accum_grad_.flat<T>().device(ctx->template eigen_device<Device>()) +=
        grad->flat<T>();
  }

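  // Implements the MEAN reduction step: divides the accumulated sum by the
  // number of gradients applied so far. For example (illustrative), with
  // counter_ == 4 and an accumulated sum of [8, 12], the result is [2, 3].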
  void DivideAccumGradByCounter(OpKernelContext* ctx) override
      TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
    Tensor c(DataTypeToEnum<T>::value, {});
    c.scalar<T>()() = TypeConverter<T, int>::ConvertUToT(this->counter_);
    this->accum_grad_.template flat<T>().device(
        ctx->template eigen_device<Device>()) =
        this->accum_grad_.template flat<T>() / c.scalar<T>()();
  }

  bool SetOutput(OpKernelContext* ctx) override {
    ctx->set_output(0, accum_grad_);
    return true;
  }

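  // Fetches the "gradient" input and validates its shape; on failure,
  // OP_REQUIRES_OK_BOOLEAN records the error status on ctx and returns false
  // from this function.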
  bool GetAndValidateTensorInputForApplyGrad(OpKernelContext* ctx,
                                             const Tensor** tensor) override
      TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
    // Get input gradient tensor
    const Tensor* grad_tensor;
    OP_REQUIRES_OK_BOOLEAN(ctx, ctx->input("gradient", &grad_tensor));
    *tensor = grad_tensor;
    OP_REQUIRES_OK_BOOLEAN(ctx, this->ValidateShape(*tensor));
    return true;
  }

  void CleanUpGradTensor(const Tensor* tensor) override {
    // do nothing
  }

  TF_DISALLOW_COPY_AND_ASSIGN(ConditionalAccumulator);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_H_