/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace ops {
namespace {

REGISTER_NO_GRADIENT_OP("Const");
REGISTER_NO_GRADIENT_OP("StopGradient");
REGISTER_NO_GRADIENT_OP("ConcatOffset");
REGISTER_NO_GRADIENT_OP("EditDistance");
REGISTER_NO_GRADIENT_OP("ZerosLike");
REGISTER_NO_GRADIENT_OP("InvertPermutation");
REGISTER_NO_GRADIENT_OP("Shape");
REGISTER_NO_GRADIENT_OP("ShapeN");
REGISTER_NO_GRADIENT_OP("Rank");
REGISTER_NO_GRADIENT_OP("Size");
REGISTER_NO_GRADIENT_OP("BroadcastGradientArgs");
REGISTER_NO_GRADIENT_OP("OneHot");

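// The functions below are looked up through the gradient registry when a
// caller builds symbolic gradients with the C++ API. A minimal, illustrative
// sketch of how they get exercised (not part of this file's API):
//
//   Scope scope = Scope::NewRootScope();
//   auto x = Placeholder(scope, DT_FLOAT);
//   auto y = Identity(scope, x);
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(scope, {y}, {x}, &grads));
//
// AddSymbolicGradients (declared in tensorflow/cc/framework/gradients.h) walks
// the graph backwards and invokes the registered gradient function for each op
// it encounters.

// Pack (tf.stack) joins N tensors along a new axis, so its gradient simply
// unstacks the incoming gradient back into N pieces along that axis. For
// example, packing N tensors of shape [3] on axis 0 yields [N, 3]; the
// incoming gradient of shape [N, 3] is unstacked into N gradients of shape
// [3].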
Status PackGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  int N;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N));
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));

  grad_outputs->reserve(N);
  auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis));
  for (const Output& o : grad_op.output) {
    grad_outputs->emplace_back(o);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pack", PackGrad);

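// Unpack (tf.unstack) is the inverse of Pack: the per-output gradients are
// stacked back together along the same axis to form the input gradient.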
Status UnpackGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
  grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Unpack", UnpackGrad);

Status IdentityGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Identity", IdentityGrad);

Status RefIdentityGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("RefIdentity", RefIdentityGrad);

Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
                                 const std::vector<Output>& grad_inputs,
                                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);

Status QuantizeAndDequantizeV4GradHelper(const Scope& scope,
                                         const Operation& op,
                                         const std::vector<Output>& grad_inputs,
                                         std::vector<Output>* grad_outputs) {
  Input input = Shape(scope, op.input(0));
  Input input_min = op.input(1);
  Input input_max = op.input(2);
  int64_t axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
  auto qdq_v4_grad = QuantizeAndDequantizeV4Grad(
      scope, grad_inputs[0], input, input_min, input_max,
      QuantizeAndDequantizeV4Grad::Axis(axis));
  grad_outputs->push_back(qdq_v4_grad.input_backprop);
  grad_outputs->push_back(qdq_v4_grad.input_min_backprop);
  grad_outputs->push_back(qdq_v4_grad.input_max_backprop);
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV4",
                     QuantizeAndDequantizeV4GradHelper);

Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV3", QuantizeAndDequantizeV3Grad);

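// Split's first input is split_dim, which gets no gradient. The gradient of
// the value input is the concatenation of the per-output gradients along that
// same dimension. For example, splitting a [4, 6] tensor into three [4, 2]
// pieces along axis 1 concatenates the three [4, 2] gradients back to [4, 6].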
Status SplitGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Split", SplitGrad);

Status SplitVGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  if (op.num_inputs() < 3) {
    return errors::InvalidArgument("SplitV requires 3 arguments");
  }
  grad_outputs->push_back(Concat(scope, grad_inputs, op.input(2)));
  for (int i = 0; i < op.num_inputs() - 1; ++i) {
    grad_outputs->push_back(NoGradient());
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("SplitV", SplitVGrad);

Status FillGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = fill(fill_shape, x)
  // No gradient returned for the fill_shape argument.
  grad_outputs->push_back(NoGradient());
  // The gradient for x (which must be a scalar) is just the sum of
  // all the gradients from the shape it fills.
  // We use ReduceSum to implement this, which needs an argument providing
  // the indices of all the dimensions of the incoming gradient.
  // grad(x) = reduce_sum(grad(y), [0..rank(grad(y))])
  auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]),
                        Const(scope, 1));
  grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims));
  return scope.status();
}
REGISTER_GRADIENT_OP("Fill", FillGrad);

Status DiagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(DiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Diag", DiagGrad);

Status DiagPartGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Diag(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("DiagPart", DiagPartGrad);

Status MatrixDiagGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(MatrixDiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixDiag", MatrixDiagGrad);

Status MatrixBandPartGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  auto num_lower = op.input(1);
  auto num_upper = op.input(2);
  grad_outputs->push_back(
      MatrixBandPart(scope, grad_inputs[0], num_lower, num_upper));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixBandPart", MatrixBandPartGrad);

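// The gradient of GatherNd scatters the incoming gradient into a zero tensor
// with the shape of params, at the gathered indices (ScatterNd is the adjoint
// of GatherNd). The indices input receives no gradient.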
Status GatherNdGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  auto ref = op.input(0);
  auto ref_shape = Shape(scope, ref);
  auto indices = op.input(1);
  grad_outputs->push_back(ScatterNd(scope, indices, grad_inputs[0], ref_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad);

Status CheckNumericsGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string message;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message));
  string err_msg = strings::StrCat(
      "Not a number (NaN) or infinity (Inf) values detected in gradient. ",
      message);
  grad_outputs->push_back(CheckNumerics(scope, grad_inputs[0], err_msg));
  return scope.status();
}
REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad);

Status ReshapeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Reshape", ReshapeGrad);

Status ExpandDimsGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ExpandDims", ExpandDimsGrad);

Status SqueezeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  return scope.status();
}
REGISTER_GRADIENT_OP("Squeeze", SqueezeGrad);

Status TransposeGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto inverted_perm = InvertPermutation(scope, op.input(1));
  grad_outputs->push_back(Transpose(scope, grad_inputs[0], inverted_perm));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Transpose", TransposeGrad);

Status ReverseSequenceGrad(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  auto seq_lengths = op.input(1);
  int batch_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim));
  int seq_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim));
  grad_outputs->push_back(
      ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim,
                      ReverseSequence::BatchDim(batch_dim)));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad);

Status ReverseGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto reverse_dims = op.input(1);
  grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseV2", ReverseGrad);

Status ScatterNdGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto indices = op.input(0);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad);

Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  auto indices = op.input(1);
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
  return scope.status();
}
REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad);

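// The gradient of Pad is the slice of the incoming gradient corresponding to
// the original (unpadded) input: begin is the first column of paddings and the
// size is the input's shape. For example, x of shape [2, 3] with paddings
// [[1, 1], [2, 2]] produces y of shape [4, 7]; the gradient slices dy starting
// at [1, 2] with size [2, 3].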
template <bool IsPadV2>
Status PadGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x = op.input(0);
  auto a = op.input(1);  // [Rank(x), 2]
  // Take a slice of a: the first column, with shape [Rank(x), 1].
  auto size = Stack(scope, {Rank(scope, x), 1});
  auto pad_before = Slice(scope, a, {0, 0}, size);
  // Make it a 1-D tensor.
  auto begin = Reshape(scope, pad_before, {-1});
  grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x)));
  grad_outputs->push_back(NoGradient());
  // PadV2 adds a "constant_values" input.
  if (IsPadV2) {
    grad_outputs->push_back(NoGradient());
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pad", PadGrad<false>);
REGISTER_GRADIENT_OP("PadV2", PadGrad<true>);

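// SpaceToBatch/BatchToSpace (and their ND variants) are mutual inverses, so
// each op's gradient just applies the opposite op to the incoming gradient
// with the same block size and paddings/crops; those auxiliary inputs receive
// no gradient.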
Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(
      BatchToSpace(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatch", SpaceToBatchGrad);

Status SpaceToBatchNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      BatchToSpaceND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatchND", SpaceToBatchNDGrad);

Status BatchToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(
      SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpace", BatchToSpaceGrad);

Status BatchToSpaceNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      SpaceToBatchND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpaceND", BatchToSpaceNDGrad);

Status SpaceToDepthGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToDepth", SpaceToDepthGrad);

Status DepthToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad);

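// MirrorPad's gradient uses the internal MirrorPadGrad op, which folds the
// gradient contributions of the mirrored border regions back onto the
// corresponding interior elements; the paddings input receives no gradient.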
Status MirrorPadGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
  grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad(
      scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad);

// TODO(suharshs): b/34770860. This gradient was within 1e-3 but not 1e-4.
Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
  grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);

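// The gradient of StridedSlice is StridedSliceGrad, which scatters the
// incoming gradient back into a zero tensor of the original input shape,
// reusing the forward op's begin/end/strides inputs and all slicing masks.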
Status StridedSliceGradHelper(const Scope& scope, const Operation& op,
                              const std::vector<Output>& grad_inputs,
                              std::vector<Output>* grad_outputs) {
  Input x = Shape(scope, op.input(0));
  Input begin = op.input(1);
  Input end = op.input(2);
  Input strides = op.input(3);
  int64_t begin_mask;
  int64_t end_mask;
  int64_t ellipsis_mask;
  int64_t new_axis_mask;
  int64_t shrink_axis_mask;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "end_mask", &end_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "shrink_axis_mask", &shrink_axis_mask));
  grad_outputs->push_back(
      StridedSliceGrad(scope, x, begin, end, strides, grad_inputs[0],
                       StridedSliceGrad::BeginMask(begin_mask)
                           .EndMask(end_mask)
                           .EllipsisMask(ellipsis_mask)
                           .NewAxisMask(new_axis_mask)
                           .ShrinkAxisMask(shrink_axis_mask)));
  // No gradients returned for begin, end and strides
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper);

Status SliceGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Propagate the incoming gradient along all the selected values,
  // and zero everywhere else. Use the Pad operator for this.
  //
  // First create an Nx2 padding where N is the number of input
  // dimensions. The first column is the number of prepended zeros
  // for each dimension, and the second column is the number of
  // appended zeros.
  //
  // The first column is just the begin vector.
  // The second column is the input's shape minus (begin + size), element-wise.

  // Running example:
  // input.shape = [3, 5, 3]
  // begin = [1, 2, 1], size = [1, 3, 2]
  Input input = op.input(0);
  Input begin = op.input(1);
  // input_rank = 3
  auto input_rank = Rank(scope, input);
  // slice_size = [1, 3, 2]
  auto slice_size = Shape(scope, op.output(0));
  // padding_shape = [3, 1]
  auto padding_shape = Stack(scope, {input_rank, 1});
  // before_padding = [[1]
  //                   [2]
  //                   [1]]
  Input before_padding = Reshape(scope, begin, padding_shape);
  // after_padding_sizes = shape(input) - slice_size - begin
  //                     = [3, 5, 3] - [1, 3, 2] - [1, 2, 1]
  //                     = [1, 0, 0]
  auto after_padding_sizes =
      Sub(scope, Sub(scope, Shape(scope, input), slice_size), begin);
  // after_padding = [[1]
  //                  [0]
  //                  [0]]
  Input after_padding = Reshape(scope, after_padding_sizes, padding_shape);
  // paddings = [[1 1]
  //             [2 0]
  //             [1 0]]
  auto paddings =
      Concat(scope, {before_padding, after_padding}, Const(scope, 1));
  grad_outputs->push_back(Pad(scope, grad_inputs[0], paddings));
  // Nothing propagated for the "begin" and "size" inputs.
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Slice", SliceGrad);

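// Shared helper for the gradient of concatenation: ConcatOffset computes where
// each input starts inside the concatenated output, and each input's gradient
// is the corresponding slice of dy. For example, concatenating shapes [2, 3]
// and [2, 5] along axis 1 gives offsets [0, 0] and [0, 3]; dy (shape [2, 8])
// is sliced at those offsets with sizes [2, 3] and [2, 5].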
Status ConcatGradHelper(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs,
                        int start_value_index, int end_value_index,
                        int dim_index) {
  if (end_value_index >= op.num_inputs()) {
    return errors::Internal("Invalid input index");
  }
  std::vector<Output> inputs;
  inputs.reserve(end_value_index - start_value_index);
  for (int i = start_value_index; i < end_value_index; ++i) {
    inputs.push_back(op.input(i));
  }

  auto shapes = ShapeN(scope, inputs);
  const auto unique_name = scope.GetUniqueNameForOp("ConcatOffset");
  auto builder =
      ::tensorflow::NodeBuilder(unique_name, "ConcatOffset")
          .Input(::tensorflow::ops::AsNodeOut(scope, op.input(dim_index)))
          .Input(::tensorflow::ops::AsNodeOutList(scope, shapes.output));
  scope.UpdateBuilder(&builder);
  ::tensorflow::Node* concat_offset_node;
  scope.UpdateStatus(builder.Finalize(scope.graph(), &concat_offset_node));
  scope.UpdateStatus(scope.DoShapeInference(concat_offset_node));
  if (concat_offset_node->num_outputs() != inputs.size()) {
    return errors::Internal("ConcatOffset has invalid output count");
  }
  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("Concat grad should have 1 input");
  }

  // For each dx[i], we take a slice of dy. The offset and size of the
  // slice is given by offset[i] and shape[i].
  const Output& dy = grad_inputs[0];
  for (int i = 0; i < inputs.size(); ++i) {
    grad_outputs->push_back(
        Slice(scope, dy, Output(concat_offset_node, i), shapes.output[i]));
  }

  // Insert a NoGradient for the axis.
  grad_outputs->insert(grad_outputs->begin() + dim_index, NoGradient());
  return scope.status();
}

Status ConcatV2Grad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  return ConcatGradHelper(scope, op, grad_inputs, grad_outputs,
                          /*start_value_index=*/0,
                          /*end_value_index=*/op.num_inputs() - 1,
                          /*dim_index=*/op.num_inputs() - 1);
}

REGISTER_GRADIENT_OP("ConcatV2", ConcatV2Grad);

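// The gradient of BroadcastTo sums the incoming gradient over the broadcast
// dimensions (reported by BroadcastGradientArgs) and reshapes the result back
// to the original input shape. For example, broadcasting [1, 3] to [4, 3]
// sums dy over dimension 0.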
Status BroadcastToGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("BroadcastTo grad should have 1 grad input");
  }
  if (op.num_inputs() != 2) {
    return errors::InvalidArgument("BroadcastTo requires 2 inputs");
  }

  auto x_shape = Shape(scope, op.input(0));
  auto args = internal::BroadcastGradientArgs(scope, x_shape, op.input(1));
  auto sum_gx = Sum(scope, grad_inputs[0], args.r0);
  grad_outputs->push_back(Reshape(scope, sum_gx, x_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}

REGISTER_GRADIENT_OP("BroadcastTo", BroadcastToGrad);

Status TileGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  if (op.num_inputs() != 2) {
    return errors::InvalidArgument("Tile requires 2 inputs");
  }
  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("Tile grad requires 1 grad input");
  }

  Shape::Attrs shape_attrs;
  shape_attrs.out_type_ = op.input_type(1);
  auto input_shape = Shape(scope, op.input(0), shape_attrs);
  // We interleave multiples and input_shape to get split_shape,
  // reshape grad to split_shape, and reduce along all even
  // dimensions (the tiled dimensions) to get the result
  // with shape input_shape.  For example
  //   input_shape = [20, 30, 40]
  //   multiples = [2, 3, 4]
  //   split_shape = [2, 20, 3, 30, 4, 40]
  //   axes = [0, 2, 4]
  auto stack = Stack(scope, {op.input(1), input_shape.output});
  auto perm = Range(scope, Sub(scope, Rank(scope, stack), 1), -1, -1);
  auto split_shape = Reshape(scope, Transpose(scope, stack, perm), {-1});
  auto axes = Range(scope, Const(scope, 0), Size(scope, split_shape.output), 2);
  auto input_grad = ReduceSum(
      scope, Reshape(scope, grad_inputs[0], split_shape.output), axes.output);
  grad_outputs->push_back(input_grad.output);
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Tile", TileGrad);

// Creates a constant of the provided d_type.
Output ConstHelper(const Scope& scope, int value, DataType d_type) {
  return Cast(scope, Const(scope, value), d_type);
}

// Adds the batch offsets to the given indices and returns the results.
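// For example, with batch_dims = 1 and a gather dimension of size N
// (params_shape[1] == N), the index row for batch b is offset by b * N, so the
// flattened indices address disjoint per-batch segments in the subsequent
// UnsortedSegmentSum.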
Output GetBatchIndices(const Scope& scope, const Output& params_shape,
                       const Output& indices, int batch_dims) {
  Output batch_indices = indices;
  auto indices_ndims = Rank(scope, indices);
  auto casted_params_shape = Cast(scope, params_shape, indices.type());
  Output accum_dim_value = ConstHelper(scope, 1, indices.type());
  for (int dim = batch_dims; dim > 0; dim--) {
    Output dim_value = Slice(scope, casted_params_shape, {dim - 1}, {1});
    accum_dim_value = Multiply(scope, accum_dim_value,
                               Slice(scope, casted_params_shape, {dim}, {1}));
    auto start = ConstHelper(scope, 0, indices.type());
    auto step = ConstHelper(scope, 1, indices.type());
    Output dim_indices = Range(scope, start, Squeeze(scope, dim_value), step);
    dim_indices = Multiply(scope, dim_indices, accum_dim_value);
    auto one = Cast(scope, Const(scope, {1}), indices.type());
    auto dim_shape = Concat(
        scope,
        {Output(Tile(scope, one, Const(scope, {dim - 1}))), dim_value,
         Output(Tile(scope, one,
                     ExpandDims(scope, Sub(scope, indices_ndims, dim), 0)))},
        /*axis=*/0);
    batch_indices =
        Add(scope, batch_indices, Reshape(scope, dim_indices, dim_shape));
  }

  return batch_indices;
}

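// Computes the gradient of a batched gather: values and (batch-offset) indices
// are flattened, UnsortedSegmentSum accumulates entries for repeated indices
// into the params gradient, and the batch dimensions are then restored.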
Output BatchGatherGrad(const Scope& scope, Output params_shape, Output values,
                       Output indices, int batch_dims, Output gather_dim_size) {
  // Axis is the first non-batch dimension.
  auto indices_size = ExpandDims(scope, Size(scope, indices), 0);
  Output outer_shape, flat_values_shape;
  if (batch_dims != 0) {
    auto values_shape = Shape(scope, values);
    // Add the batch offsets to indices and flatten the batch dimensions.
    outer_shape = Slice(scope, values_shape, {0}, {batch_dims});
    auto inner_shape =
        Slice(scope, Slice(scope, values_shape, {batch_dims}, {-1}), {1}, {-1});
    auto batch_size = Prod(scope, outer_shape, /*axis=*/0);
    flat_values_shape = Concat(scope, {{-1}, inner_shape}, /*axis=*/0);
    gather_dim_size = Multiply(scope, gather_dim_size, batch_size);
    indices = GetBatchIndices(scope, params_shape, indices, batch_dims);
    values = Reshape(scope, values, flat_values_shape);
  }

  indices = Reshape(scope, indices, indices_size);
  Output params_grad =
      UnsortedSegmentSum(scope, values, indices, gather_dim_size);

  if (batch_dims != 0) {
    // Put back the batch dimensions.
    params_grad = Reshape(scope, params_grad, params_shape);
  }
  return params_grad;
}

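// The gradient of GatherV2 with respect to params: transpose the gather axis
// to the first non-batch position, scatter the gradient with BatchGatherGrad,
// then transpose back. The indices and axis inputs receive no gradient.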
Status GatherV2Grad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  if (op.num_inputs() != 3) {
    return errors::InvalidArgument("Gather requires 3 inputs");
  }
  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("Gather grad requires 1 grad input");
  }

  // params can be large, so colocate the shape calculation with it.
  // params can be very large for sparse models, and array_ops.shape raises
  // an exception on the Windows platform when any dimension is larger than
  // int32. params_shape is not used in optimizer apply_sparse gradients,
  // so it's fine to convert it back to int32 regardless of truncation.
  auto params = op.input(0);
  auto colocate_scope = scope.ColocateWith(params);
  Shape::Attrs shape_attrs;
  shape_attrs.out_type_ = DT_INT64;
  auto params_shape64 = Shape(colocate_scope, params, shape_attrs);
  Output params_shape = Cast(colocate_scope, params_shape64, DT_INT32);

  auto indices = op.input(1);
  auto indices_size = ExpandDims(scope, Size(scope, indices), 0);
  auto axis = op.input(2);
  auto axis_expand = ExpandDims(scope, axis, 0);

  int batch_dims;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "batch_dims", &batch_dims));
  if (batch_dims < 0) {
    // TODO(bdodson): Figure out if we can find the param rank here, like the
    // python implementation does.
    return errors::InvalidArgument(
        "C++ GatherV2 gradient does not support negative batch_dims.");
  }

  // Handle axis by transposing the axis dimension to be the first non-batch
  // dimension, compute the gradient and transpose the result back.
  auto outer_shape = Slice(scope, params_shape, {0}, axis_expand);
  auto inner_shape =
      Slice(scope, Slice(scope, params_shape, axis_expand, {-1}), {1}, {-1});
  auto values_shape = Concat(scope, {outer_shape, {-1}, inner_shape}, 0);
  auto values_dims = Size(scope, values_shape);
  auto axis_dims = Size(scope, outer_shape);

  Output outer_batches_indices = Range(scope, 0, batch_dims, /*delta=*/1);
  Output batch_axis_indices = Range(scope, batch_dims, axis_dims, /*delta=*/1);
  Output inner_axes_indices =
      Range(scope, Add(scope, axis_dims, 1), values_dims, /*delta=*/1);
  Output axis_dims_expand = ExpandDims(scope, axis_dims, 0);

  auto values = Reshape(scope, grad_inputs[0], values_shape);

  // Move values[axis] up to values[batch_dims].
  Output transpose_dims = Concat(scope,
                                 {outer_batches_indices, axis_dims_expand,
                                  batch_axis_indices, inner_axes_indices},
                                 0);
  auto values_transpose = Transpose(scope, values, transpose_dims);
  Output gather_dim_size =
      Squeeze(scope, Slice(scope, params_shape, axis_expand, {1}));
  params_shape = Gather(scope, params_shape, transpose_dims);

  auto params_grad = BatchGatherGrad(scope, params_shape, values_transpose,
                                     indices, batch_dims, gather_dim_size);

  // Inverts the above transpose by moving dimension batch_dims back to its
  // original position.
  Output invert_transpose_dims = Concat(scope,
                                        {outer_batches_indices,
                                         Add(scope, batch_axis_indices, 1),
                                         {batch_dims},
                                         inner_axes_indices},
                                        0);

  params_grad = Transpose(scope, params_grad, invert_transpose_dims);

  grad_outputs->push_back(params_grad);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}

REGISTER_GRADIENT_OP("GatherV2", GatherV2Grad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow