/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include <vector>
17
18 #include "tensorflow/cc/framework/grad_op_registry.h"
19 #include "tensorflow/cc/framework/gradients.h"
20 #include "tensorflow/cc/ops/array_ops_internal.h"
21 #include "tensorflow/cc/ops/standard_ops.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23
24 namespace tensorflow {
25 namespace ops {
26 namespace {
27
// Ops that are not differentiable (constants, shape/metadata queries, and
// index-manipulation helpers): register them as having no gradient so the
// gradient builder stops at them instead of failing.
REGISTER_NO_GRADIENT_OP("Const");
REGISTER_NO_GRADIENT_OP("StopGradient");
REGISTER_NO_GRADIENT_OP("ConcatOffset");
REGISTER_NO_GRADIENT_OP("EditDistance");
REGISTER_NO_GRADIENT_OP("ZerosLike");
REGISTER_NO_GRADIENT_OP("InvertPermutation");
REGISTER_NO_GRADIENT_OP("Shape");
REGISTER_NO_GRADIENT_OP("ShapeN");
REGISTER_NO_GRADIENT_OP("Rank");
REGISTER_NO_GRADIENT_OP("Size");
REGISTER_NO_GRADIENT_OP("BroadcastGradientArgs");
REGISTER_NO_GRADIENT_OP("OneHot");
40
PackGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)41 Status PackGrad(const Scope& scope, const Operation& op,
42 const std::vector<Output>& grad_inputs,
43 std::vector<Output>* grad_outputs) {
44 int N;
45 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N));
46 int axis;
47 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
48
49 grad_outputs->reserve(N);
50 auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis));
51 for (const Output& o : grad_op.output) {
52 grad_outputs->emplace_back(o);
53 }
54 return scope.status();
55 }
56 REGISTER_GRADIENT_OP("Pack", PackGrad);
57
UnpackGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)58 Status UnpackGrad(const Scope& scope, const Operation& op,
59 const std::vector<Output>& grad_inputs,
60 std::vector<Output>* grad_outputs) {
61 int axis;
62 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
63 grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis)));
64 return scope.status();
65 }
66 REGISTER_GRADIENT_OP("Unpack", UnpackGrad);
67
IdentityGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)68 Status IdentityGrad(const Scope& scope, const Operation& op,
69 const std::vector<Output>& grad_inputs,
70 std::vector<Output>* grad_outputs) {
71 grad_outputs->push_back(Identity(scope, grad_inputs[0]));
72 return scope.status();
73 }
74 REGISTER_GRADIENT_OP("Identity", IdentityGrad);
75
RefIdentityGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)76 Status RefIdentityGrad(const Scope& scope, const Operation& op,
77 const std::vector<Output>& grad_inputs,
78 std::vector<Output>* grad_outputs) {
79 grad_outputs->push_back(Identity(scope, grad_inputs[0]));
80 return scope.status();
81 }
82 REGISTER_GRADIENT_OP("RefIdentity", RefIdentityGrad);
83
QuantizeAndDequantizeGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)84 Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
85 const std::vector<Output>& grad_inputs,
86 std::vector<Output>* grad_outputs) {
87 grad_outputs->push_back(Identity(scope, grad_inputs[0]));
88 return scope.status();
89 }
90 REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);
91
QuantizeAndDequantizeV4GradHelper(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)92 Status QuantizeAndDequantizeV4GradHelper(const Scope& scope,
93 const Operation& op,
94 const std::vector<Output>& grad_inputs,
95 std::vector<Output>* grad_outputs) {
96 Input input = Shape(scope, op.input(0));
97 Input input_min = op.input(1);
98 Input input_max = op.input(2);
99 int64_t axis;
100 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
101 auto qdq_v4_grad = QuantizeAndDequantizeV4Grad(
102 scope, grad_inputs[0], input, input_min, input_max,
103 QuantizeAndDequantizeV4Grad::Axis(axis));
104 grad_outputs->push_back(qdq_v4_grad.input_backprop);
105 grad_outputs->push_back(qdq_v4_grad.input_min_backprop);
106 grad_outputs->push_back(qdq_v4_grad.input_max_backprop);
107 return scope.status();
108 }
109 REGISTER_GRADIENT_OP("QuantizeAndDequantizeV4",
110 QuantizeAndDequantizeV4GradHelper);
111
QuantizeAndDequantizeV3Grad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)112 Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
113 const std::vector<Output>& grad_inputs,
114 std::vector<Output>* grad_outputs) {
115 grad_outputs->push_back(Identity(scope, grad_inputs[0]));
116 grad_outputs->push_back(NoGradient());
117 grad_outputs->push_back(NoGradient());
118 grad_outputs->push_back(NoGradient());
119 return scope.status();
120 }
121 REGISTER_GRADIENT_OP("QuantizeAndDequantizeV3", QuantizeAndDequantizeV3Grad);
122
SplitGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)123 Status SplitGrad(const Scope& scope, const Operation& op,
124 const std::vector<Output>& grad_inputs,
125 std::vector<Output>* grad_outputs) {
126 grad_outputs->push_back(NoGradient());
127 grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0)));
128 return scope.status();
129 }
130 REGISTER_GRADIENT_OP("Split", SplitGrad);
131
SplitVGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)132 Status SplitVGrad(const Scope& scope, const Operation& op,
133 const std::vector<Output>& grad_inputs,
134 std::vector<Output>* grad_outputs) {
135 if (op.num_inputs() < 3) {
136 return errors::InvalidArgument("SplitV requires 3 arguments");
137 }
138 grad_outputs->push_back(Concat(scope, grad_inputs, op.input(2)));
139 for (int i = 0; i < op.num_inputs() - 1; ++i) {
140 grad_outputs->push_back(NoGradient());
141 }
142 return scope.status();
143 }
144 REGISTER_GRADIENT_OP("SplitV", SplitVGrad);
145
FillGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)146 Status FillGrad(const Scope& scope, const Operation& op,
147 const std::vector<Output>& grad_inputs,
148 std::vector<Output>* grad_outputs) {
149 // y = fill(fill_shape, x)
150 // No gradient returned for the fill_shape argument.
151 grad_outputs->push_back(NoGradient());
152 // The gradient for x (which must be a scalar) is just the sum of
153 // all the gradients from the shape it fills.
154 // We use ReduceSum to implement this, which needs an argument providing
155 // the indices of all the dimensions of the incoming gradient.
156 // grad(x) = reduce_sum(grad(y), [0..rank(grad(y))])
157 auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]),
158 Const(scope, 1));
159 grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims));
160 return scope.status();
161 }
162 REGISTER_GRADIENT_OP("Fill", FillGrad);
163
DiagGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)164 Status DiagGrad(const Scope& scope, const Operation& op,
165 const std::vector<Output>& grad_inputs,
166 std::vector<Output>* grad_outputs) {
167 grad_outputs->push_back(DiagPart(scope, grad_inputs[0]));
168 return scope.status();
169 }
170 REGISTER_GRADIENT_OP("Diag", DiagGrad);
171
DiagPartGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)172 Status DiagPartGrad(const Scope& scope, const Operation& op,
173 const std::vector<Output>& grad_inputs,
174 std::vector<Output>* grad_outputs) {
175 grad_outputs->push_back(Diag(scope, grad_inputs[0]));
176 return scope.status();
177 }
178 REGISTER_GRADIENT_OP("DiagPart", DiagPartGrad);
179
MatrixDiagGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)180 Status MatrixDiagGrad(const Scope& scope, const Operation& op,
181 const std::vector<Output>& grad_inputs,
182 std::vector<Output>* grad_outputs) {
183 grad_outputs->push_back(MatrixDiagPart(scope, grad_inputs[0]));
184 return scope.status();
185 }
186 REGISTER_GRADIENT_OP("MatrixDiag", MatrixDiagGrad);
187
MatrixBandPartGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)188 Status MatrixBandPartGrad(const Scope& scope, const Operation& op,
189 const std::vector<Output>& grad_inputs,
190 std::vector<Output>* grad_outputs) {
191 auto num_lower = op.input(1);
192 auto num_upper = op.input(2);
193 grad_outputs->push_back(
194 MatrixBandPart(scope, grad_inputs[0], num_lower, num_upper));
195 grad_outputs->push_back(NoGradient());
196 grad_outputs->push_back(NoGradient());
197 return scope.status();
198 }
199 REGISTER_GRADIENT_OP("MatrixBandPart", MatrixBandPartGrad);
200
GatherNdGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)201 Status GatherNdGrad(const Scope& scope, const Operation& op,
202 const std::vector<Output>& grad_inputs,
203 std::vector<Output>* grad_outputs) {
204 auto ref = op.input(0);
205 auto ref_shape = Shape(scope, ref);
206 auto indices = op.input(1);
207 grad_outputs->push_back(ScatterNd(scope, indices, grad_inputs[0], ref_shape));
208 grad_outputs->push_back(NoGradient());
209 return scope.status();
210 }
211 REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad);
212
CheckNumericsGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)213 Status CheckNumericsGrad(const Scope& scope, const Operation& op,
214 const std::vector<Output>& grad_inputs,
215 std::vector<Output>* grad_outputs) {
216 string message;
217 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message));
218 string err_msg = strings::StrCat(
219 "Not a number (NaN) or infinity (Inf) values detected in gradient. ",
220 message);
221 grad_outputs->push_back(CheckNumerics(scope, grad_inputs[0], err_msg));
222 return scope.status();
223 }
224 REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad);
225
ReshapeGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)226 Status ReshapeGrad(const Scope& scope, const Operation& op,
227 const std::vector<Output>& grad_inputs,
228 std::vector<Output>* grad_outputs) {
229 auto input_shape = Shape(scope, op.input(0));
230 grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
231 grad_outputs->push_back(NoGradient());
232 return scope.status();
233 }
234 REGISTER_GRADIENT_OP("Reshape", ReshapeGrad);
235
ExpandDimsGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)236 Status ExpandDimsGrad(const Scope& scope, const Operation& op,
237 const std::vector<Output>& grad_inputs,
238 std::vector<Output>* grad_outputs) {
239 auto input_shape = Shape(scope, op.input(0));
240 grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
241 grad_outputs->push_back(NoGradient());
242 return scope.status();
243 }
244 REGISTER_GRADIENT_OP("ExpandDims", ExpandDimsGrad);
245
SqueezeGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)246 Status SqueezeGrad(const Scope& scope, const Operation& op,
247 const std::vector<Output>& grad_inputs,
248 std::vector<Output>* grad_outputs) {
249 auto input_shape = Shape(scope, op.input(0));
250 grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
251 return scope.status();
252 }
253 REGISTER_GRADIENT_OP("Squeeze", SqueezeGrad);
254
TransposeGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)255 Status TransposeGrad(const Scope& scope, const Operation& op,
256 const std::vector<Output>& grad_inputs,
257 std::vector<Output>* grad_outputs) {
258 auto inverted_perm = InvertPermutation(scope, op.input(1));
259 grad_outputs->push_back(Transpose(scope, grad_inputs[0], inverted_perm));
260 grad_outputs->push_back(NoGradient());
261 return scope.status();
262 }
263 REGISTER_GRADIENT_OP("Transpose", TransposeGrad);
264
ReverseSequenceGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)265 Status ReverseSequenceGrad(const Scope& scope, const Operation& op,
266 const std::vector<Output>& grad_inputs,
267 std::vector<Output>* grad_outputs) {
268 auto seq_lengths = op.input(1);
269 int batch_dim;
270 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim));
271 int seq_dim;
272 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim));
273 grad_outputs->push_back(
274 ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim,
275 ReverseSequence::BatchDim(batch_dim)));
276 grad_outputs->push_back(NoGradient());
277 return scope.status();
278 }
279 REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad);
280
ReverseGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)281 Status ReverseGrad(const Scope& scope, const Operation& op,
282 const std::vector<Output>& grad_inputs,
283 std::vector<Output>* grad_outputs) {
284 auto reverse_dims = op.input(1);
285 grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims));
286 grad_outputs->push_back(NoGradient());
287 return scope.status();
288 }
289 REGISTER_GRADIENT_OP("ReverseV2", ReverseGrad);
290
ScatterNdGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)291 Status ScatterNdGrad(const Scope& scope, const Operation& op,
292 const std::vector<Output>& grad_inputs,
293 std::vector<Output>* grad_outputs) {
294 auto indices = op.input(0);
295 grad_outputs->push_back(NoGradient());
296 grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
297 grad_outputs->push_back(NoGradient());
298 return scope.status();
299 }
300 REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad);
301
ScatterNdNonAliasingAddGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)302 Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op,
303 const std::vector<Output>& grad_inputs,
304 std::vector<Output>* grad_outputs) {
305 auto indices = op.input(1);
306 grad_outputs->push_back(Identity(scope, grad_inputs[0]));
307 grad_outputs->push_back(NoGradient());
308 grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
309 return scope.status();
310 }
311 REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad);
312
313 template <bool IsPadV2>
PadGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)314 Status PadGrad(const Scope& scope, const Operation& op,
315 const std::vector<Output>& grad_inputs,
316 std::vector<Output>* grad_outputs) {
317 auto x = op.input(0);
318 auto a = op.input(1); // [Rank(x), 2]
319 // Takes a slice of a. The 1st column. [Rank(x), 1].
320 auto size = Stack(scope, {Rank(scope, x), 1});
321 auto pad_before = Slice(scope, a, {0, 0}, size);
322 // Make it a 1-D tensor.
323 auto begin = Reshape(scope, pad_before, {-1});
324 grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x)));
325 grad_outputs->push_back(NoGradient());
326 // PadV2 adds a "constant_values" input.
327 if (IsPadV2) {
328 grad_outputs->push_back(NoGradient());
329 }
330 return scope.status();
331 }
332 REGISTER_GRADIENT_OP("Pad", PadGrad<false>);
333 REGISTER_GRADIENT_OP("PadV2", PadGrad<true>);
334
SpaceToBatchGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)335 Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
336 const std::vector<Output>& grad_inputs,
337 std::vector<Output>* grad_outputs) {
338 int block_size;
339 TF_RETURN_IF_ERROR(
340 GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
341 grad_outputs->push_back(
342 BatchToSpace(scope, grad_inputs[0], op.input(1), block_size));
343 grad_outputs->push_back(NoGradient());
344 return scope.status();
345 }
346 REGISTER_GRADIENT_OP("SpaceToBatch", SpaceToBatchGrad);
347
SpaceToBatchNDGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)348 Status SpaceToBatchNDGrad(const Scope& scope, const Operation& op,
349 const std::vector<Output>& grad_inputs,
350 std::vector<Output>* grad_outputs) {
351 grad_outputs->push_back(
352 BatchToSpaceND(scope, grad_inputs[0], op.input(1), op.input(2)));
353 grad_outputs->push_back(NoGradient());
354 grad_outputs->push_back(NoGradient());
355 return scope.status();
356 }
357 REGISTER_GRADIENT_OP("SpaceToBatchND", SpaceToBatchNDGrad);
358
BatchToSpaceGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)359 Status BatchToSpaceGrad(const Scope& scope, const Operation& op,
360 const std::vector<Output>& grad_inputs,
361 std::vector<Output>* grad_outputs) {
362 int block_size;
363 TF_RETURN_IF_ERROR(
364 GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
365 grad_outputs->push_back(
366 SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size));
367 grad_outputs->push_back(NoGradient());
368 return scope.status();
369 }
370 REGISTER_GRADIENT_OP("BatchToSpace", BatchToSpaceGrad);
371
BatchToSpaceNDGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)372 Status BatchToSpaceNDGrad(const Scope& scope, const Operation& op,
373 const std::vector<Output>& grad_inputs,
374 std::vector<Output>* grad_outputs) {
375 grad_outputs->push_back(
376 SpaceToBatchND(scope, grad_inputs[0], op.input(1), op.input(2)));
377 grad_outputs->push_back(NoGradient());
378 grad_outputs->push_back(NoGradient());
379 return scope.status();
380 }
381 REGISTER_GRADIENT_OP("BatchToSpaceND", BatchToSpaceNDGrad);
382
SpaceToDepthGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)383 Status SpaceToDepthGrad(const Scope& scope, const Operation& op,
384 const std::vector<Output>& grad_inputs,
385 std::vector<Output>* grad_outputs) {
386 int block_size;
387 TF_RETURN_IF_ERROR(
388 GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
389 grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size));
390 return scope.status();
391 }
392 REGISTER_GRADIENT_OP("SpaceToDepth", SpaceToDepthGrad);
393
DepthToSpaceGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)394 Status DepthToSpaceGrad(const Scope& scope, const Operation& op,
395 const std::vector<Output>& grad_inputs,
396 std::vector<Output>* grad_outputs) {
397 int block_size;
398 TF_RETURN_IF_ERROR(
399 GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
400 grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size));
401 return scope.status();
402 }
403 REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad);
404
MirrorPadGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)405 Status MirrorPadGrad(const Scope& scope, const Operation& op,
406 const std::vector<Output>& grad_inputs,
407 std::vector<Output>* grad_outputs) {
408 string mode;
409 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
410 grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad(
411 scope, grad_inputs[0], op.input(1), mode));
412 grad_outputs->push_back(NoGradient());
413 return scope.status();
414 }
415 REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad);
416
417 // TODO(suharshs): b/34770860. This gradient was within 1e-3 but not 1e-4.
MirrorPadGradGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)418 Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
419 const std::vector<Output>& grad_inputs,
420 std::vector<Output>* grad_outputs) {
421 string mode;
422 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
423 grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode));
424 grad_outputs->push_back(NoGradient());
425 return scope.status();
426 }
427 REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);
428
StridedSliceGradHelper(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)429 Status StridedSliceGradHelper(const Scope& scope, const Operation& op,
430 const std::vector<Output>& grad_inputs,
431 std::vector<Output>* grad_outputs) {
432 Input x = Shape(scope, op.input(0));
433 Input begin = op.input(1);
434 Input end = op.input(2);
435 Input strides = op.input(3);
436 int64_t begin_mask;
437 int64_t end_mask;
438 int64_t ellipsis_mask;
439 int64_t new_axis_mask;
440 int64_t shrink_axis_mask;
441 TF_RETURN_IF_ERROR(
442 GetNodeAttr(op.node()->attrs(), "begin_mask", &begin_mask));
443 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "end_mask", &end_mask));
444 TF_RETURN_IF_ERROR(
445 GetNodeAttr(op.node()->attrs(), "ellipsis_mask", &ellipsis_mask));
446 TF_RETURN_IF_ERROR(
447 GetNodeAttr(op.node()->attrs(), "new_axis_mask", &new_axis_mask));
448 TF_RETURN_IF_ERROR(
449 GetNodeAttr(op.node()->attrs(), "shrink_axis_mask", &shrink_axis_mask));
450 grad_outputs->push_back(
451 StridedSliceGrad(scope, x, begin, end, strides, grad_inputs[0],
452 StridedSliceGrad::BeginMask(begin_mask)
453 .EndMask(end_mask)
454 .EllipsisMask(ellipsis_mask)
455 .NewAxisMask(new_axis_mask)
456 .ShrinkAxisMask(shrink_axis_mask)));
457 // No gradients returned for begin, end and strides
458 grad_outputs->push_back(NoGradient());
459 grad_outputs->push_back(NoGradient());
460 grad_outputs->push_back(NoGradient());
461 return scope.status();
462 }
463 REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper);
464
SliceGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)465 Status SliceGrad(const Scope& scope, const Operation& op,
466 const std::vector<Output>& grad_inputs,
467 std::vector<Output>* grad_outputs) {
468 // Propagate the incoming gradient along all the selected values,
469 // and zero everywhere else. Use the Pad operator for this.
470 //
471 // First create an Nx2 padding where N is the number of input
472 // dimensions. The first column is the number of prepended zeros
473 // for each dimension, and the second column is the number of
474 // appended zeros.
475 //
476 // The first column is just the begin vector.
477 // The second column is the shape of the input element-wise
478 // subtracted by begin+size
479
480 // Running example:
481 // input.shape = [3, 5, 3]
482 // begin = [1, 2, 1], size = [1, 3, 2]
483 Input input = op.input(0);
484 Input begin = op.input(1);
485 // input_rank = 3
486 auto input_rank = Rank(scope, input);
487 // slice_size = [1, 3, 2]
488 auto slice_size = Shape(scope, op.output(0));
489 // padding_shape = [3, 1]
490 auto padding_shape = Stack(scope, {input_rank, 1});
491 // before_padding = [[1]
492 // [2]
493 // [1]]
494 Input before_padding = Reshape(scope, begin, padding_shape);
495 // after_padding_sizes = shape(input) - slice_size - begin
496 // = [3, 5, 3] - [1, 3, 2] - [1, 2, 1]
497 // = [1, 0, 0]
498 auto after_padding_sizes =
499 Sub(scope, Sub(scope, Shape(scope, input), slice_size), begin);
500 // after_padding = [[1]
501 // [0]
502 // [0]]
503 Input after_padding = Reshape(scope, after_padding_sizes, padding_shape);
504 // paddings = [[1 1]
505 // [2 0]
506 // [1 0]]
507 auto paddings =
508 Concat(scope, {before_padding, after_padding}, Const(scope, 1));
509 grad_outputs->push_back(Pad(scope, grad_inputs[0], paddings));
510 // Nothing propagated for "begin" and "size" inputs
511 grad_outputs->push_back(NoGradient());
512 grad_outputs->push_back(NoGradient());
513 return scope.status();
514 }
515 REGISTER_GRADIENT_OP("Slice", SliceGrad);
516
ConcatGradHelper(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs,int start_value_index,int end_value_index,int dim_index)517 Status ConcatGradHelper(const Scope& scope, const Operation& op,
518 const std::vector<Output>& grad_inputs,
519 std::vector<Output>* grad_outputs,
520 int start_value_index, int end_value_index,
521 int dim_index) {
522 if (end_value_index >= op.num_inputs()) {
523 return errors::Internal("Invalid input index");
524 }
525 std::vector<Output> inputs;
526 inputs.reserve(end_value_index - start_value_index);
527 for (int i = start_value_index; i < end_value_index; ++i) {
528 inputs.push_back(op.input(i));
529 }
530
531 auto shapes = ShapeN(scope, inputs);
532 const auto unique_name = scope.GetUniqueNameForOp("ConcatOffset");
533 auto builder =
534 ::tensorflow::NodeBuilder(unique_name, "ConcatOffset")
535 .Input(::tensorflow::ops::AsNodeOut(scope, op.input(dim_index)))
536 .Input(::tensorflow::ops::AsNodeOutList(scope, shapes.output));
537 scope.UpdateBuilder(&builder);
538 ::tensorflow::Node* concat_offset_node;
539 scope.UpdateStatus(builder.Finalize(scope.graph(), &concat_offset_node));
540 scope.UpdateStatus(scope.DoShapeInference(concat_offset_node));
541 if (concat_offset_node->num_outputs() != inputs.size()) {
542 return errors::Internal("ConcatOffset has invalid output count");
543 }
544 if (grad_inputs.size() != 1) {
545 return errors::InvalidArgument("Concat grad should have 1 input");
546 }
547
548 // For each dx[i], we take a slice of dy. The offset and size of the
549 // slice is given by offset[i] and shape[i].
550 const Output& dy = grad_inputs[0];
551 for (int i = 0; i < inputs.size(); ++i) {
552 grad_outputs->push_back(
553 Slice(scope, dy, Output(concat_offset_node, i), shapes.output[i]));
554 }
555
556 // Insert a NoGradient for the axis.
557 grad_outputs->insert(grad_outputs->begin() + dim_index, NoGradient());
558 return scope.status();
559 }
560
ConcatV2Grad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)561 Status ConcatV2Grad(const Scope& scope, const Operation& op,
562 const std::vector<Output>& grad_inputs,
563 std::vector<Output>* grad_outputs) {
564 return ConcatGradHelper(scope, op, grad_inputs, grad_outputs,
565 /*start_value_index=*/0,
566 /*end_value_index=*/op.num_inputs() - 1,
567 /*dim+index=*/op.num_inputs() - 1);
568 }
569
570 REGISTER_GRADIENT_OP("ConcatV2", ConcatV2Grad);
571
BroadcastToGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)572 Status BroadcastToGrad(const Scope& scope, const Operation& op,
573 const std::vector<Output>& grad_inputs,
574 std::vector<Output>* grad_outputs) {
575 if (grad_inputs.size() != 1) {
576 return errors::InvalidArgument("BroadcastTo grad should have 1 grad input");
577 }
578 if (op.num_inputs() != 2) {
579 return errors::InvalidArgument("BroadcastTo requires 2 inputs");
580 }
581
582 auto x_shape = Shape(scope, op.input(0));
583 auto args = internal::BroadcastGradientArgs(scope, x_shape, op.input(1));
584 auto sum_gx = Sum(scope, grad_inputs[0], args.r0);
585 grad_outputs->push_back(Reshape(scope, sum_gx, x_shape));
586 grad_outputs->push_back(NoGradient());
587 return scope.status();
588 }
589
590 REGISTER_GRADIENT_OP("BroadcastTo", BroadcastToGrad);
591
TileGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)592 Status TileGrad(const Scope& scope, const Operation& op,
593 const std::vector<Output>& grad_inputs,
594 std::vector<Output>* grad_outputs) {
595 if (op.num_inputs() != 2) {
596 return errors::InvalidArgument("Tile requires 2 inputs");
597 }
598 if (grad_inputs.size() != 1) {
599 return errors::InvalidArgument("Tile grad requires 1 grad input");
600 }
601
602 Shape::Attrs shape_attrs;
603 shape_attrs.out_type_ = op.input_type(1);
604 auto input_shape = Shape(scope, op.input(0), shape_attrs);
605 // We interleave multiples and input_shape to get split_shape,
606 // reshape grad to split_shape, and reduce along all even
607 // dimensions (the tiled dimensions) to get the result
608 // with shape input_shape. For example
609 // input_shape = [20, 30, 40]
610 // multiples = [2, 3, 4]
611 // split_shape = [2, 20, 3, 30, 4, 40]
612 // axes = [0, 2, 4]
613 auto stack = Stack(scope, {op.input(1), input_shape.output});
614 auto perm = Range(scope, Sub(scope, Rank(scope, stack), 1), -1, -1);
615 auto split_shape = Reshape(scope, Transpose(scope, stack, perm), {-1});
616 auto axes = Range(scope, Const(scope, 0), Size(scope, split_shape.output), 2);
617 auto input_grad = ReduceSum(
618 scope, Reshape(scope, grad_inputs[0], split_shape.output), axes.output);
619 grad_outputs->push_back(input_grad.output);
620 grad_outputs->push_back(NoGradient());
621 return scope.status();
622 }
623 REGISTER_GRADIENT_OP("Tile", TileGrad);
624
625 // Create a constant of the provided d_type;
ConstHelper(const Scope & scope,int value,DataType d_type)626 Output ConstHelper(const Scope& scope, int value, DataType d_type) {
627 return Cast(scope, Const(scope, value), d_type);
628 }
629
// Adds the batch offsets to the given indices and returns the results.
//
// Walks the batch dimensions from innermost (dim = batch_dims) to outermost,
// accumulating the stride of each batch dimension of `params_shape` and
// adding, per batch position, the offset that flattening the batch
// dimensions introduces. The returned tensor has the same shape as
// `indices` but addresses positions in the batch-flattened params.
Output GetBatchIndices(const Scope& scope, const Output& params_shape,
                       const Output& indices, int batch_dims) {
  Output batch_indices = indices;
  auto indices_ndims = Rank(scope, indices);
  // Cast so arithmetic below happens in the indices' dtype.
  auto casted_params_shape = Cast(scope, params_shape, indices.type());
  Output accum_dim_value = ConstHelper(scope, 1, indices.type());
  for (int dim = batch_dims; dim > 0; dim--) {
    // Size of batch dimension `dim - 1`.
    Output dim_value = Slice(scope, casted_params_shape, {dim - 1}, {1});
    // Running product of the sizes of all dimensions inside `dim - 1`.
    accum_dim_value = Multiply(scope, accum_dim_value,
                               Slice(scope, casted_params_shape, {dim}, {1}));
    auto start = ConstHelper(scope, 0, indices.type());
    auto step = ConstHelper(scope, 1, indices.type());
    // Offsets [0, stride, 2*stride, ...] for each position along this dim.
    Output dim_indices = Range(scope, start, Squeeze(scope, dim_value), step);
    dim_indices = Multiply(scope, dim_indices, accum_dim_value);
    auto one = Cast(scope, Const(scope, {1}), indices.type());
    // Shape [1, ..., 1, dim_value, 1, ..., 1] so dim_indices broadcasts
    // against `indices` along exactly dimension `dim - 1`.
    auto dim_shape = Concat(
        scope,
        {Output(Tile(scope, one, Const(scope, {dim - 1}))), dim_value,
         Output(Tile(scope, one,
                     ExpandDims(scope, Sub(scope, indices_ndims, dim), 0)))},
        /*axis=*/0);
    batch_indices =
        Add(scope, batch_indices, Reshape(scope, dim_indices, dim_shape));
  }

  return batch_indices;
}
658
// Scatters `values` (the incoming gradient of a batched gather) back into a
// tensor of shape `params_shape`.
//
// `gather_dim_size` is the size of the gathered dimension of params. When
// `batch_dims` > 0, the batch dimensions of `values`/`indices` are folded
// into the gather dimension first (adding per-batch offsets to the indices
// via GetBatchIndices), so one UnsortedSegmentSum can accumulate duplicate
// indices; the batch dimensions are restored at the end.
Output BatchGatherGrad(const Scope& scope, Output params_shape, Output values,
                       Output indices, int batch_dims, Output gather_dim_size) {
  // Axis is the first non-batch dimension.
  auto indices_size = ExpandDims(scope, Size(scope, indices), 0);
  Output outer_shape, flat_values_shape;
  if (batch_dims != 0) {
    auto values_shape = Shape(scope, values);
    // Add the batch offsets to indices and flatten the batch dimensions.
    outer_shape = Slice(scope, values_shape, {0}, {batch_dims});
    auto inner_shape =
        Slice(scope, Slice(scope, values_shape, {batch_dims}, {-1}), {1}, {-1});
    auto batch_size = Prod(scope, outer_shape, /*axis=*/0);
    flat_values_shape = Concat(scope, {{-1}, inner_shape}, /*axis=*/0);
    gather_dim_size = Multiply(scope, gather_dim_size, batch_size);
    indices = GetBatchIndices(scope, params_shape, indices, batch_dims);
    values = Reshape(scope, values, flat_values_shape);
  }

  // Flatten indices and accumulate values into one segment per index.
  indices = Reshape(scope, indices, indices_size);
  Output params_grad =
      UnsortedSegmentSum(scope, values, indices, gather_dim_size);

  if (batch_dims != 0) {
    // Put back the batch dimensions.
    params_grad = Reshape(scope, params_grad, params_shape);
  }
  return params_grad;
}
687
GatherV2Grad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)688 Status GatherV2Grad(const Scope& scope, const Operation& op,
689 const std::vector<Output>& grad_inputs,
690 std::vector<Output>* grad_outputs) {
691 if (op.num_inputs() != 3) {
692 return errors::InvalidArgument("Gather requires 3 inputs");
693 }
694 if (grad_inputs.size() != 1) {
695 return errors::InvalidArgument("Gather grad requires 1 grad input");
696 }
697
698 // params can be large, so colocate the shape calculation with it.
699 // params can be very large for sparse model, array_ops.shape raises
700 // exception on the Windows platform when any dimension is larger than
701 // int32. params_shape is not used in optimizer apply_sparse gradients,
702 // so it's fine to convert it back to int32 regardless of truncation.
703 auto params = op.input(0);
704 auto colocate_scope = scope.ColocateWith(params);
705 Shape::Attrs shape_attrs;
706 shape_attrs.out_type_ = DT_INT64;
707 auto params_shape64 = Shape(colocate_scope, params, shape_attrs);
708 Output params_shape = Cast(colocate_scope, params_shape64, DT_INT32);
709
710 auto indices = op.input(1);
711 auto indices_size = ExpandDims(scope, Size(scope, indices), 0);
712 auto axis = op.input(2);
713 auto axis_expand = ExpandDims(scope, axis, 0);
714
715 int batch_dims;
716 TF_RETURN_IF_ERROR(
717 GetNodeAttr(op.node()->attrs(), "batch_dims", &batch_dims));
718 if (batch_dims < 0) {
719 // TODO(bdodson): Figure out if we can find the param rank here, like the
720 // python implementation does.
721 return errors::InvalidArgument(
722 "C++ GatherV2 gradient does not support negative batch_dims.");
723 }
724
725 // Handle axis by transposing the axis dimension to be the first non-batch
726 // dimension, compute the gradient and transpose the result back.
727 auto outer_shape = Slice(scope, params_shape, {0}, axis_expand);
728 auto inner_shape =
729 Slice(scope, Slice(scope, params_shape, axis_expand, {-1}), {1}, {-1});
730 auto values_shape = Concat(scope, {outer_shape, {-1}, inner_shape}, 0);
731 auto values_dims = Size(scope, values_shape);
732 auto axis_dims = Size(scope, outer_shape);
733
734 Output outer_batches_indices = Range(scope, 0, batch_dims, /*delta=*/1);
735 Output batch_axis_indices = Range(scope, batch_dims, axis_dims, /*delta=*/1);
736 Output inner_axes_indices =
737 Range(scope, Add(scope, axis_dims, 1), values_dims, /*delta=*/1);
738 Output axis_dims_expand = ExpandDims(scope, axis_dims, 0);
739
740 auto values = Reshape(scope, grad_inputs[0], values_shape);
741
742 // Move values[axis] up to values[batch_dims]
743 Output transpose_dims = Concat(scope,
744 {outer_batches_indices, axis_dims_expand,
745 batch_axis_indices, inner_axes_indices},
746 0);
747 auto values_transpose = Transpose(scope, values, transpose_dims);
748 Output gather_dim_size =
749 Squeeze(scope, Slice(scope, params_shape, axis_expand, {1}));
750 params_shape = Gather(scope, params_shape, transpose_dims);
751
752 auto params_grad = BatchGatherGrad(scope, params_shape, values_transpose,
753 indices, batch_dims, gather_dim_size);
754
755 // Inverts the above transpose by moving dimension batch_dims back to its
756 // original position.
757 Output invert_transpose_dims = Concat(scope,
758 {outer_batches_indices,
759 Add(scope, batch_axis_indices, 1),
760 {batch_dims},
761 inner_axes_indices},
762 0);
763
764 params_grad = Transpose(scope, params_grad, invert_transpose_dims);
765
766 grad_outputs->push_back(params_grad);
767 grad_outputs->push_back(NoGradient());
768 grad_outputs->push_back(NoGradient());
769 return scope.status();
770 }
771
772 REGISTER_GRADIENT_OP("GatherV2", GatherV2Grad);
773
774 } // anonymous namespace
775 } // namespace ops
776 } // namespace tensorflow
777