# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in nn_ops.py."""

import functools
import itertools
import operator

from tensorflow.python.eager import backprop
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops


@ops.RegisterGradient("Conv2DBackpropInput")
def _Conv2DBackpropInputGrad(op, grad):
  """The derivatives for deconvolution.

  Args:
    op: the Deconvolution op.
    grad: the tensor representing the gradient w.r.t. the output

  Returns:
    the gradients w.r.t. the input and the filter
  """
  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. See _Conv2DGrad.
  return [
      None,
      gen_nn_ops.conv2d_backprop_filter(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode()),
      gen_nn_ops.conv2d(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode())
  ]


@ops.RegisterGradient("Conv2DBackpropFilter")
def _Conv2DBackpropFilterGrad(op, grad):
  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. See _Conv2DGrad.
  return [
      gen_nn_ops.conv2d_backprop_input(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode()), None,
      gen_nn_ops.conv2d(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode())
  ]


@ops.RegisterGradient("DepthwiseConv2dNativeBackpropInput")
def _DepthwiseConv2dNativeBackpropInputGrad(op, grad):
  """The derivatives for deconvolution.

  Args:
    op: the Deconvolution op.
    grad: the tensor representing the gradient w.r.t. the output

  Returns:
    the gradients w.r.t. the input and the filter
  """
  return [
      None,
      gen_nn_ops.depthwise_conv2d_native_backprop_filter(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format")),
      gen_nn_ops.depthwise_conv2d_native(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("DepthwiseConv2dNativeBackpropFilter")
def _DepthwiseConv2dNativeBackpropFilterGrad(op, grad):
  return [
      gen_nn_ops.depthwise_conv2d_native_backprop_input(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format")), None,
      gen_nn_ops.depthwise_conv2d_native(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("Conv3D")
def _Conv3DGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  return [
      nn_ops.conv3d_backprop_input_v2(
          array_ops.shape(op.inputs[0]),
          op.inputs[1],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format),
      nn_ops.conv3d_backprop_filter_v2(
          op.inputs[0],
          array_ops.shape(op.inputs[1]),
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("Conv3DBackpropInputV2")
def _Conv3DBackpropInputGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  return [
      None,
      nn_ops.conv3d_backprop_filter_v2(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format),
      nn_ops.conv3d(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("Conv3DBackpropFilterV2")
def _Conv3DBackpropFilterGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  return [
      nn_ops.conv3d_backprop_input_v2(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format), None,
      nn_ops.conv3d(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("AvgPool3D")
def _AvgPool3DGrad(op, grad):
  return gen_nn_ops.avg_pool3d_grad(
      array_ops.shape(op.inputs[0]),
      grad,
      ksize=op.get_attr("ksize"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())


@ops.RegisterGradient("AvgPool3DGrad")
def _AvgPool3DGradGrad(op, grad):
  return (array_ops.stop_gradient(op.inputs[0]),
          gen_nn_ops.avg_pool3d(
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("MaxPool3D")
def _MaxPool3DGrad(op, grad):
  return gen_nn_ops.max_pool3d_grad(
      op.inputs[0],
      op.outputs[0],
      grad,
      ksize=op.get_attr("ksize"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())


@ops.RegisterGradient("MaxPool3DGrad")
def _MaxPool3DGradGrad(op, grad):
  return (array_ops.zeros_like(op.inputs[0]),
          array_ops.zeros_like(op.inputs[1]),
          gen_nn_ops.max_pool3d_grad_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("MaxPool3DGradGrad")
def _MaxPool3DGradGradGrad(op, grad):
  return (array_ops.zeros_like(op.inputs[0]),
          array_ops.zeros_like(op.inputs[1]),
          gen_nn_ops.max_pool3d_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("Softmax")
def _SoftmaxGrad(op, grad_softmax):
  """The derivative of the softmax nonlinearity.

  We assume that probs is of shape [batch_size, dim].
  The formula for dsoftmax / dx is (diag(softmax) - softmax * softmax').
  This matrix is diagonal minus a rank one matrix, so it is easy to implement
  as follows:

    grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax

  Args:
     op: the Softmax op.
     grad_softmax:  the tensor representing the gradient w.r.t. the softmax
       output.

  Returns:
     gradient w.r.t. the input to the softmax

  """
  softmax = op.outputs[0]
  sum_channels = math_ops.reduce_sum(grad_softmax * softmax, -1, keepdims=True)
  return (grad_softmax - sum_channels) * softmax

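# Illustration only (hypothetical helper, not used anywhere in this module):
# the fused expression in _SoftmaxGrad above is the vector-Jacobian product for
# J = diag(softmax) - softmax * softmax^T. The sketch below spells out that
# same product with an explicit Jacobian for a 2-D [batch_size, dim] input.
def _softmax_grad_reference(softmax, grad_softmax):
  """Unfused reference form of the Softmax VJP, for illustration only."""
  jacobian = (array_ops.matrix_diag(softmax) -
              array_ops.expand_dims(softmax, -1) *
              array_ops.expand_dims(softmax, -2))
  # J is symmetric, so multiplying by J equals multiplying by J transposed.
  return array_ops.squeeze(
      math_ops.matmul(jacobian, array_ops.expand_dims(grad_softmax, -1)),
      axis=-1)
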

@ops.RegisterGradient("LogSoftmax")
def _LogSoftmaxGrad(op, grad):
  """The gradient for log_softmax.

      log_softmax = input - log(sum(exp(input)))
      dlog_softmax/dinput = I - softmax(input)

  Args:
    op: The log softmax op.
    grad: The tensor representing the gradient w.r.t. the output.

  Returns:
    The gradients w.r.t. the input.
  """
  softmax = math_ops.exp(op.outputs[0])
  return grad - math_ops.reduce_sum(grad, -1, keepdims=True) * softmax

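# Illustration only (hypothetical helper, not used by this module): the return
# value of _LogSoftmaxGrad above is the vector-Jacobian product for
# J = I - ones * softmax^T. An equivalent form written in terms of the op's
# input (the logits) rather than its output is:
def _log_softmax_grad_reference(logits, grad):
  """Reference form of the LogSoftmax VJP, for illustration only."""
  softmax = nn_ops.softmax(logits)
  return grad - math_ops.reduce_sum(grad, -1, keepdims=True) * softmax
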

@ops.RegisterGradient("BiasAdd")
def _BiasAddGrad(op, received_grad):
  """Return the gradients for the 2 inputs of bias_op.

  The first input of the op is the tensor t, and its gradient is
  just the gradient the op received.

  The second input of the op is the bias vector, which is 1-D. Its gradient is
  the received gradient summed over every dimension except the feature
  (channel) dimension selected by `data_format`.

  Args:
    op: The BiasOp for which we need to generate gradients.
    received_grad: Tensor.  The gradients passed to the BiasOp.

  Returns:
    Two tensors, the first one for the "tensor" input of the BiasOp,
    the second one for the "bias" input of the BiasOp.
  """
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None
  return (received_grad,
          gen_nn_ops.bias_add_grad(
              out_backprop=received_grad, data_format=data_format))


@ops.RegisterGradient("BiasAddGrad")
def _BiasAddGradGrad(op, received_grad):
  """Gradient for the BiasAddGrad op.

  Args:
    op: BiasAddGrad op for which we are calculating gradients.
    received_grad: The gradients passed to the BiasAddGrad op.

  Returns:
    A single gradient Tensor for the input to BiasAddGrad (which
    is the gradient of the bias term in BiasAdd)
  """

  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None

  shape = array_ops.shape(op.inputs[0])
  bias_shape = array_ops.shape(received_grad)

  if data_format == b"NCHW":
    expanded_shape = array_ops.concat([
        array_ops.ones_like(shape[:1]), bias_shape,
        array_ops.ones_like(shape[2:])
    ], 0)
    tile_mults = array_ops.concat([shape[:1], [1], shape[2:]], 0)
  else:
    expanded_shape = array_ops.concat(
        [array_ops.ones_like(shape[:-1]), bias_shape], 0)
    tile_mults = array_ops.concat([shape[:-1], [1]], 0)

  expanded_grad = array_ops.reshape(received_grad, expanded_shape)
  return array_ops.tile(expanded_grad, tile_mults)

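# Shape sketch for _BiasAddGradGrad above (assuming the default NHWC layout):
# for a BiasAdd input of shape [N, H, W, C], `received_grad` has shape [C], so
# expanded_shape is [1, 1, 1, C] and tile_mults is [N, H, W, 1]; the bias
# gradient is therefore broadcast back to the full [N, H, W, C] input shape.
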

@ops.RegisterGradient("BiasAddV1")
def _BiasAddGradV1(unused_bias_op, received_grad):
  """Return the gradients for the 2 inputs of bias_op.

  The first input of unused_bias_op is the tensor t, and its gradient is
  just the gradient the unused_bias_op received.

  The second input of unused_bias_op is the bias vector, which matches the
  last dimension of "received_grad".  Its gradient is the received gradient
  summed over every dimension except the last one.

  Args:
    unused_bias_op: The BiasOp for which we need to generate gradients.
    received_grad: Tensor.  The gradients passed to the BiasOp.

  Returns:
    Two tensors, the first one for the "tensor" input of the BiasOp,
    the second one for the "bias" input of the BiasOp.
  """
  reduction_dim_tensor = math_ops.range(array_ops.rank(received_grad) - 1)
  return (received_grad, math_ops.reduce_sum(received_grad,
                                             reduction_dim_tensor))


@ops.RegisterGradient("Relu")
def _ReluGrad(op, grad):
  return gen_nn_ops.relu_grad(grad, op.outputs[0])


@ops.RegisterGradient("EluGrad")
def _EluGradGrad(op, grad):
  elu_x = op.inputs[1]
  return (gen_nn_ops.elu_grad(grad, elu_x),
          array_ops.where(
              elu_x < 0, grad * op.inputs[0], array_ops.zeros_like(elu_x)))


@ops.RegisterGradient("SeluGrad")
def _SeluGradGrad(op, grad):
  selu_x = op.inputs[1]
  return (gen_nn_ops.selu_grad(grad, selu_x),
          array_ops.where(
              selu_x < 0., grad * op.inputs[0], array_ops.zeros_like(selu_x)))


@ops.RegisterGradient("Relu6")
def _Relu6Grad(op, grad):
  return gen_nn_ops.relu6_grad(grad, op.outputs[0])


@ops.RegisterGradient("Relu6Grad")
def _Relu6GradGrad(op, grad):
  x = op.inputs[1]
  return (gen_nn_ops.relu6_grad(grad, x), array_ops.zeros_like(x))


@ops.RegisterGradient("LeakyRelu")
def _LeakyReluGrad(op, grad):
  x = op.inputs[0]
  alpha = op.get_attr("alpha")
  return gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha)


@ops.RegisterGradient("LeakyReluGrad")
def _LeakyReluGradGrad(op, grad):
  x = op.inputs[1]
  alpha = op.get_attr("alpha")
  return (gen_nn_ops.leaky_relu_grad(grad, x,
                                     alpha=alpha), array_ops.zeros_like(x))


@ops.RegisterGradient("Elu")
def _EluGrad(op, grad):
  return gen_nn_ops.elu_grad(grad, op.outputs[0])


@ops.RegisterGradient("Selu")
def _SeluGrad(op, grad):
  return gen_nn_ops.selu_grad(grad, op.outputs[0])


@ops.RegisterGradient("Softplus")
def _SoftplusGrad(op, grad):
  return grad * math_ops.sigmoid(op.inputs[0])


@ops.RegisterGradient("SoftplusGrad")
def _SoftplusGradGrad(op, grad):
  # Let:
  #   y = tf.nn.softplus(x)
  #   dx = gen_nn_ops.softplus_grad(dy, x) = dy / (1 + exp(-x))
  # This op computes (ddy, d2x) from op.inputs == [dy, x] and grad == ddx.
  dy, x = op.inputs
  with ops.control_dependencies([grad]):
    ddy = gen_nn_ops.softplus_grad(grad, x)
    d2x = grad * dy / (math_ops.exp(-x) + 2.0 + math_ops.exp(x))
    return (ddy, d2x)

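# Derivation sketch for the d2x term above: softplus_grad(dy, x) = dy * sigmoid(x),
# so differentiating once more w.r.t. x gives
#   d/dx [dy * sigmoid(x)] = dy * sigmoid(x) * (1 - sigmoid(x))
#                          = dy / ((1 + exp(-x)) * (1 + exp(x)))
#                          = dy / (exp(-x) + 2 + exp(x)),
# which, multiplied by the incoming ddx (`grad`), matches the expression used
# in _SoftplusGradGrad.
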

@ops.RegisterGradient("Softsign")
def _SoftsignGrad(op, grad):
  return gen_nn_ops.softsign_grad(grad, op.inputs[0])


@ops.RegisterGradient("ReluGrad")
def _ReluGradGrad(op, grad):
  x = op.inputs[1]
  return (gen_nn_ops.relu_grad(grad, x), array_ops.zeros_like(x))


def _BroadcastMul(vec, mat):
  """Multiply after broadcasting vec to match dimensions of mat.

  Args:
    vec: A 1-D tensor of dimension [D0]
    mat: A 2-D tensor of dimension [D0, D1]

  Returns:
    A tensor of dimension [D0, D1], the result of vec * mat
  """
  # Reshape vec to [D0, 1]
  vec = array_ops.expand_dims(vec, -1)
  return vec * mat

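# Worked example for _BroadcastMul: with vec = [2., 3.] and
# mat = [[1., 1.], [1., 1.]], vec is reshaped to [[2.], [3.]] and broadcasting
# yields [[2., 2.], [3., 3.]], i.e. row i of mat is scaled by vec[i].
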

@ops.RegisterGradient("SoftmaxCrossEntropyWithLogits")
def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
  """Gradient function for SoftmaxCrossEntropyWithLogits."""
  # grad_loss is the backprop for cost, and we multiply it with the gradients
  # (which is output[1])
  # grad_grad is the backprop for softmax gradient.
  #
  # Second derivative is just softmax derivative w.r.t. logits.
  softmax_grad = op.outputs[1]
  grad = _BroadcastMul(grad_loss, softmax_grad)

  logits = op.inputs[0]
  if (grad_grad is not None and
      not getattr(grad_grad, "_is_zeros_tensor", False)):
    softmax = nn_ops.softmax(logits)

    grad += ((grad_grad - array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(grad_grad, 1),
            array_ops.expand_dims(softmax, 2)),
        axis=1)) * softmax)

  return grad, _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))  # pylint: disable=invalid-unary-operand-type

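# Note on the grad_grad term above: the correction added to `grad` is
# (grad_grad - <grad_grad, softmax>) * softmax, i.e. the softmax Jacobian
# (diag(softmax) - softmax * softmax^T) applied to grad_grad; the batched
# matmul of shapes [batch, 1, dim] x [batch, dim, 1] computes the per-example
# inner product <grad_grad, softmax>.
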

@ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
  """Gradient function for SparseSoftmaxCrossEntropyWithLogits."""
  # grad_loss is the backprop for cost, and we multiply it with the gradients
  # (which is output[1])
  # grad_grad is the backprop for softmax gradient.
  # There is no gradient for the labels
  #
  # Second derivative is just softmax derivative w.r.t. logits.
  softmax_grad = op.outputs[1]
  grad = _BroadcastMul(grad_loss, softmax_grad)

  logits = op.inputs[0]
  if (grad_grad is not None and
      not getattr(grad_grad, "_is_zeros_tensor", False)):
    softmax = nn_ops.softmax(logits)

    grad += ((grad_grad - array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(grad_grad, 1),
            array_ops.expand_dims(softmax, 2)),
        axis=1)) * softmax)

  return grad, None


@ops.RegisterGradient("Conv2D")
def _Conv2DGrad(op, grad):
  """Gradient function for Conv2D."""
  dilations = op.get_attr("dilations")
  strides = op.get_attr("strides")
  padding = op.get_attr("padding")
  explicit_paddings = op.get_attr("explicit_paddings")
  use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu")
  data_format = op.get_attr("data_format")
  shape_0, shape_1 = array_ops.shape_n([op.inputs[0], op.inputs[1]])

  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. gen_nn_ops functions take
  # an `explicit_paddings` parameter, but nn_ops functions do not. So if we were
  # to use the nn_ops functions, we would have to convert `padding` and
  # `explicit_paddings` into a single `padding` parameter, increasing overhead
  # in Eager mode.
  return [
      gen_nn_ops.conv2d_backprop_input(
          shape_0,
          op.inputs[1],
          grad,
          dilations=dilations,
          strides=strides,
          padding=padding,
          explicit_paddings=explicit_paddings,
          use_cudnn_on_gpu=use_cudnn_on_gpu,
          data_format=data_format),
      gen_nn_ops.conv2d_backprop_filter(
          op.inputs[0],
          shape_1,
          grad,
          dilations=dilations,
          strides=strides,
          padding=padding,
          explicit_paddings=explicit_paddings,
          use_cudnn_on_gpu=use_cudnn_on_gpu,
          data_format=data_format)
  ]


@ops.RegisterGradient("DepthwiseConv2dNative")
def _DepthwiseConv2dNativeGrad(op, grad):
  return [
      gen_nn_ops.depthwise_conv2d_native_backprop_input(
          array_ops.shape(op.inputs[0]),
          op.inputs[1],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format")),
      gen_nn_ops.depthwise_conv2d_native_backprop_filter(
          op.inputs[0],
          array_ops.shape(op.inputs[1]),
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("Dilation2D")
def _Dilation2DGrad(op, grad):
  return [
      nn_ops.dilation2d_backprop_input(op.inputs[0], op.inputs[1], grad,
                                       op.get_attr("strides"),
                                       op.get_attr("rates"),
                                       op.get_attr("padding")),
      nn_ops.dilation2d_backprop_filter(op.inputs[0], op.inputs[1], grad,
                                        op.get_attr("strides"),
                                        op.get_attr("rates"),
                                        op.get_attr("padding"))
  ]


@ops.RegisterGradient("LRN")
def _LRNGrad(op, grad):
  depth_radius = op.get_attr("depth_radius")
  bias = op.get_attr("bias")
  alpha = op.get_attr("alpha")
  beta = op.get_attr("beta")
  return [
      gen_nn_ops.lrn_grad(grad, op.inputs[0], op.outputs[0], depth_radius, bias,
                          alpha, beta)
  ]


@ops.RegisterGradient("AvgPool")
def _AvgPoolGrad(op, grad):
  return gen_nn_ops.avg_pool_grad(
      array_ops.shape(op.inputs[0]),
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      op.get_attr("padding"),
      data_format=op.get_attr("data_format"))


@ops.RegisterGradient("AvgPoolGrad")
def _AvgPoolGradGrad(op, grad):
  return (array_ops.stop_gradient(op.inputs[0]),
          gen_nn_ops.avg_pool(
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              op.get_attr("padding"),
              data_format=op.get_attr("data_format")))


@ops.RegisterGradient("MaxPool")
def _MaxPoolGrad(op, grad):
  return gen_nn_ops.max_pool_grad(
      op.inputs[0],
      op.outputs[0],
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      explicit_paddings=op.get_attr("explicit_paddings"),
      data_format=op.get_attr("data_format"))


@ops.RegisterGradient("MaxPoolV2")
def _MaxPoolGradV2(op, grad):
  ksize = op.inputs[1]
  strides = op.inputs[2]
  return gen_nn_ops.max_pool_grad_v2(
      op.inputs[0],
      op.outputs[0],
      grad,
      ksize,
      strides,
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format")), None, None


@ops.RegisterGradient("MaxPoolWithArgmax")
def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
  del unused_argmax_grad
  return gen_nn_ops.max_pool_grad_with_argmax(
      op.inputs[0],
      grad,
      op.outputs[1],
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      include_batch_in_index=op.get_attr("include_batch_in_index"))


@ops.RegisterGradient("MaxPoolGrad")
def _MaxPoolGradGrad(op, grad):
  return (array_ops.zeros_like(op.inputs[0]),
          array_ops.zeros_like(op.inputs[1]),
          gen_nn_ops.max_pool_grad_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")))


@ops.RegisterGradient("MaxPoolGradV2")
def _MaxPoolGradGradV2(op, grad):
  ksize = op.inputs[3]
  strides = op.inputs[4]
  return (array_ops.zeros_like(op.inputs[0]),
          array_ops.zeros_like(op.inputs[1]),
          gen_nn_ops.max_pool_grad_grad_v2(
              op.inputs[0],
              op.inputs[1],
              grad,
              ksize,
              strides,
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")), None, None)


@ops.RegisterGradient("MaxPoolGradGrad")
def _MaxPoolGradGradGrad(op, grad):
  return (array_ops.zeros_like(op.inputs[0]),
          array_ops.zeros_like(op.inputs[1]),
          gen_nn_ops.max_pool_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")))


@ops.RegisterGradient("FractionalMaxPool")
def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalMaxPool.

  Since FractionalMaxPool has three outputs, three gradients are passed in, one
  for each output. Only the first one is useful; the other two gradients are
  empty.

  Args:
    op: The FractionalMaxPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalMaxPool op.
  """
  return gen_nn_ops.fractional_max_pool_grad(
      op.inputs[0], op.outputs[0], grad_0, op.outputs[1], op.outputs[2],
      op.get_attr("overlapping"))


@ops.RegisterGradient("FractionalAvgPool")
def _FractionalAvgPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalAvgPool.

  Since FractionalAvgPool has three outputs, three gradients are passed in, one
  for each output. Only the first one is useful; the other two gradients are
  empty.

  Args:
    op: The FractionalAvgPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalAvgPool op.
  """
  return gen_nn_ops.fractional_avg_pool_grad(op.inputs[0].get_shape(), grad_0,
                                             op.outputs[1], op.outputs[2],
                                             op.get_attr("overlapping"))


@ops.RegisterGradient("BatchNormWithGlobalNormalization")
def _BatchNormWithGlobalNormalizationGrad(op, grad):
  """Return the gradients for the 5 inputs of BatchNormWithGlobalNormalization.

  We intentionally do not backprop into the mean and var inputs, as they are
  not trained with backprop in this operation.

  Args:
    op: The BatchNormOp for which we need to generate gradients.
    grad: Tensor.  The gradients passed to the BatchNormOp.

  Returns:
    dx: Backprop for input, which is (grad * (g * rsqrt(v + epsilon)))
    dm: Backprop for mean, which is
        sum_over_rest(grad * g) * (-1 / rsqrt(v + epsilon))
    dv: Backprop for variance, which is
        sum_over_rest(grad * g * (x - m)) * (-1/2) * (v + epsilon) ^ (-3/2)
    db: Backprop for beta, which is grad reduced in all except the
        last dimension.
    dg: Backprop for gamma, which is (grad * ((x - m) * rsqrt(v + epsilon)))
  """
  dx, dm, dv, db, dg = gen_nn_ops.batch_norm_with_global_normalization_grad(
      op.inputs[0], op.inputs[1], op.inputs[2], op.inputs[4], grad,
      op.get_attr("variance_epsilon"), op.get_attr("scale_after_normalization"))
  return dx, dm, dv, db, dg


def _BaseFusedBatchNormGrad(op, version, *grad):
  """Return the gradients for the 3 inputs of BatchNorm.

  Args:
    op: The BatchNormOp for which we need to compute gradients.
    version: Integer indicating which version of the fused batch norm gradient
      op to use.
    *grad: An argument list for tensors of gradients wrt the outputs
      with grad[0] as grad_y.

  Returns:
    grad_x: gradient for x, which is scale * rsqrt(variance + epsilon) *
            [grad_y - mean(grad_y) - (x - mean(x)) *
            mean(grad_y * (x - mean(x))) / (variance + epsilon)]
            in training mode; grad_y * scale * rsqrt(pop_variance + epsilon)
            in freeze mode.

    grad_scale: gradient for scale, which is sum(grad_y * (x - mean(x)) *
                rsqrt(variance + epsilon)) in training mode;
                sum(grad_y * (x - pop_mean) * rsqrt(pop_variance + epsilon))
                in freeze mode.

    grad_offset: gradient for offset, which is sum(grad_y) in both training
                 and freeze mode.
855  """
856  x = op.inputs[0]
857  grad_y = grad[0]
858  scale = op.inputs[1]
859  epsilon = op.get_attr("epsilon")
860  data_format = op.get_attr("data_format")
861  is_training = op.get_attr("is_training")
862  if version == 2:
863    grad_fun = gen_nn_ops.fused_batch_norm_grad_v3
864  elif version == 1:
865    grad_fun = gen_nn_ops.fused_batch_norm_grad_v2
866  else:
867    grad_fun = gen_nn_ops.fused_batch_norm_grad
868  if is_training:
869    args = {
870        "y_backprop": grad_y,
871        "x": x,
872        "scale": scale,
873        "reserve_space_1": op.outputs[3],
874        "reserve_space_2": op.outputs[4],
875        "epsilon": epsilon,
876        "data_format": data_format,
877        "is_training": is_training
878    }
879    if version == 2:
880      args["reserve_space_3"] = op.outputs[5]
881    dx, dscale, doffset, _, _ = grad_fun(**args)
882  else:
883    pop_mean = op.inputs[3]
884    pop_var = op.inputs[4]
885    if data_format == b"NCHW":
886      x = array_ops.transpose(x, [0, 2, 3, 1])
887      grad_y = array_ops.transpose(grad_y, [0, 2, 3, 1])
888    elif data_format == b"NCDHW":
889      x = array_ops.transpose(x, [0, 2, 3, 4, 1])
890      grad_y = array_ops.transpose(grad_y, [0, 2, 3, 4, 1])
891    target_data_format = ("NHWC" if data_format in (b"NCHW",
892                                                    b"NHWC") else "NDHWC")
893    args = {
894        "y_backprop": grad_y,
895        "x": x,
896        "scale": scale,
897        "reserve_space_1": pop_mean,
898        "reserve_space_2": pop_var,
899        "epsilon": epsilon,
900        "data_format": target_data_format,
901        "is_training": is_training
902    }
903    if version == 2:
904      args["reserve_space_3"] = op.outputs[5]
905    dx, dscale, doffset, _, _ = grad_fun(**args)
906    if data_format == b"NCHW":
907      dx = array_ops.transpose(dx, [0, 3, 1, 2])
908    elif data_format == b"NCDHW":
909      dx = array_ops.transpose(dx, [0, 4, 1, 2, 3])
910  return dx, dscale, doffset, None, None
911
912
913@ops.RegisterGradient("FusedBatchNorm")
914def _FusedBatchNormGrad(op, *grad):
915  return _BaseFusedBatchNormGrad(op, 0, *grad)
916
917
918@ops.RegisterGradient("FusedBatchNormV2")
919def _FusedBatchNormV2Grad(op, *grad):
920  return _BaseFusedBatchNormGrad(op, 1, *grad)
921
922
923@ops.RegisterGradient("FusedBatchNormV3")
924def _FusedBatchNormV3Grad(op, *grad):
925  return _BaseFusedBatchNormGrad(op, 2, *grad)
926
927
928def _BatchNormGrad(grad_y,
929                   x,
930                   scale,
931                   pop_mean,
932                   pop_var,
933                   epsilon,
934                   data_format,
935                   is_training=True):
936  """Returns the gradients for the 3 inputs of BatchNorm.
937
938  Args:
939    grad_y: A `Tensor` of 4 or 5 dimensions for gradient for y.
940    x: A `Tensor` of 4 or 5 dimensions for x.
941    scale: A `Tensor` of 1 dimension for scaling.
942    pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
943      is_training=False.
944    pop_var: A `Tensor` of 1 dimension for the population variance. Only used
945      when is_training=False.
946    epsilon: A small float number added to the variance of x.
    data_format: The data format for input. One of b"NHWC", b"NCHW", b"NDHWC"
      or b"NCDHW".
    is_training: A bool value to indicate the operation is for training
      (default) or inference.

  Returns:
    A tuple (grad_x, grad_scale, grad_offset), where grad_x is the gradient
    for x, grad_scale the gradient for scale, and grad_offset the gradient
    for offset.
  """
  x_dtype = x.dtype.base_dtype
  if x_dtype == dtypes.float16:
    # float16 math is too imprecise, so we do the batch norm gradient
    # computations in float32.
    x = math_ops.cast(x, dtypes.float32)
    grad_y = math_ops.cast(grad_y, dtypes.float32)
  if is_training:
    if data_format == b"NHWC":
      keepdims = False
      reduce_axis = [0, 1, 2]
    elif data_format == b"NDHWC":
      keepdims = False
      reduce_axis = [0, 1, 2, 3]
    elif data_format == b"NCHW":
      keepdims = True
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(scale), 1, 1]
      scale = array_ops.reshape(scale, shape)
    else:
      keepdims = True
      reduce_axis = [0, 2, 3, 4]
      shape = [1, array_ops.size(scale), 1, 1, 1]
      scale = array_ops.reshape(scale, shape)
    mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
    mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
    var_x = math_ops.reduce_mean(
        math_ops.squared_difference(x, array_ops.stop_gradient(mean_x)),
        reduce_axis,
        keepdims=keepdims)
    grad_y_offset = grad_y - mean_grad_y
    x_offset = x - mean_x
    mean = math_ops.reduce_mean(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    grad_x = scale * math_ops.rsqrt(var_x + epsilon) * (
        grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
    grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    if data_format == b"NCHW" or data_format == b"NCDHW":
      grad_scale = array_ops.squeeze(grad_scale)
    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
  else:
    if data_format == b"NHWC":
      reduce_axis = [0, 1, 2]
    elif data_format == b"NDHWC":
      reduce_axis = [0, 1, 2, 3]
    elif data_format == b"NCHW":
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(pop_mean), 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)
    else:
      reduce_axis = [0, 2, 3, 4]
      shape = [1, array_ops.size(pop_mean), 1, 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)

    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
    grad_scale = math_ops.reduce_sum(
        grad_y * (x - pop_mean) * var_rsqrt, axis=reduce_axis)
    grad_x = grad_y * scale * var_rsqrt
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset


@ops.RegisterGradient("FusedBatchNormGrad")
def _FusedBatchNormGradGrad(op, *grad):
  """Returns the gradients for the 3 inputs of FusedBatchNormGrad.

  Args:
    op: The FusedBatchNormGradOp for which we need to compute gradients.
    *grad: An argument list for tensors of gradients wrt the outputs with
      grad[0] as grad_grad_x, grad[1] as grad_grad_scale, grad[2] as
      grad_grad_offset.

  Returns:
    A tuple (grad_grad_y, grad_x, grad_scale, None, None), where grad_grad_y
    is the gradient for grad_y, grad_x the gradient for x, grad_scale the
    gradient for scale.
  """
  data_format = op.get_attr("data_format")
  epsilon = op.get_attr("epsilon")
  is_training = op.get_attr("is_training")
  grad_y = op.inputs[0]
  x = op.inputs[1]
  scale = op.inputs[2]
  pop_mean = op.inputs[3]
  pop_var = op.inputs[4]
  grad_grad_x = grad[0]
  grad_grad_scale = grad[1]
  grad_grad_offset = grad[2]
  with backprop.GradientTape() as tape:
    tape.watch(grad_y)
    tape.watch(x)
    tape.watch(scale)
    grad_x, grad_scale, grad_offset = _BatchNormGrad(
        grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training)
    grad_initial = [grad_grad_x, grad_grad_scale, grad_grad_offset]
  grad_grad_y, grad_x, grad_scale = tape.gradient(
      [grad_x, grad_scale, grad_offset], [grad_y, x, scale], grad_initial)
  return grad_grad_y, grad_x, grad_scale, None, None


@ops.RegisterGradient("FusedBatchNormGradV2")
def _FusedBatchNormGradGradV2(op, *grad):
  return _FusedBatchNormGradGrad(op, *grad)


@ops.RegisterGradient("FusedBatchNormGradV3")
def _FusedBatchNormGradGradV3(op, *grad):
  grad_grad_y, grad_x, grad_scale, _, _ = _FusedBatchNormGradGrad(op, *grad)
  return grad_grad_y, grad_x, grad_scale, None, None, None


@ops.RegisterGradient("L2Loss")
def _L2LossGrad(op, grad):
  """Return the gradients for L2Loss.

  Args:
    op: The L2LossOp for which we need to generate gradients.
    grad: Tensor containing a single number.

  Returns:
    The gradient, which is (x * grad).
  """
  return op.inputs[0] * grad

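# Why x * grad: L2Loss(x) = sum(x ** 2) / 2, so dL2Loss/dx = x, and the chain
# rule multiplies that elementwise derivative by the incoming scalar `grad`.
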

@ops.RegisterGradient("TopK")
@ops.RegisterGradient("TopKV2")
def _TopKGrad(op, grad, _):
  """Return the gradients for TopK.

  Args:
    op: The TopKOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the TopKOp.

  Returns:
    A list of two tensors, the first being the gradient w.r.t. the input of
    TopK, and the second being the gradient w.r.t. the indices (all zero).
1098  """
1099  in_shape = array_ops.shape(op.inputs[0])
1100  ind_shape = array_ops.shape(op.outputs[1])
1101
  # int32 is not supported on GPU, hence the up-cast to int64.
  ind_lastdim = array_ops.gather(
      math_ops.cast(ind_shape, dtypes.int64),
      array_ops.size(ind_shape) - 1)
  # Flatten indices to 2D.
  ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack([-1, ind_lastdim]))

  in_lastdim = array_ops.gather(
      math_ops.cast(in_shape, dtypes.int64),
      array_ops.size(in_shape) - 1)
  outerdim = array_ops.shape(ind_2d)[0]
  # Compute linear indices (flattened to 1D).
  ind = array_ops.reshape(
      ind_2d + math_ops.cast(
          array_ops.expand_dims(
              math_ops.range(0,
                             math_ops.cast(outerdim, dtypes.int64) * in_lastdim,
                             in_lastdim), -1), dtypes.int32), [-1])

  # Substitute grad to appropriate locations and fill the rest with zeros,
  # finally reshaping it to the original input shape.
  return [
      array_ops.reshape(
          array_ops.scatter_nd(
              array_ops.expand_dims(ind, -1), array_ops.reshape(grad, [-1]),
              [math_ops.reduce_prod(in_shape)]), in_shape),
      array_ops.zeros([], dtype=dtypes.int32)
  ]

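# Worked example for the index arithmetic in _TopKGrad: for an input of shape
# [2, 3] and k = 2 with top-k indices [[0, 2], [1, 0]], in_lastdim is 3 and
# outerdim is 2, so the per-row offsets are [0, 3] and the flattened linear
# indices become [0, 2, 4, 3]. scatter_nd then places the four gradient values
# into a zero vector of length 6, which is reshaped back to [2, 3].
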

@ops.RegisterGradient("ApproxTopK")
def _ApproxTopKGradient(op, grad, _):
  """Return the gradients for ApproxTopK.

  Args:
    op: The ApproxTopK for which we need to generate gradients.
    grad: The gradients for backprop.

  Returns:
    Scattered gradient based on the top-k indices.
  """
  # The code below is to generate the correct index and value mapping for
  # scatter_nd to work properly.
  #
  # We use static evaluations as much as possible to reduce the runtime cost.
  # That is, we use operation.shape instead of array_ops.shape, and
  # functools.reduce(operator.mul, ...) instead of math_ops.reduce_prod.
  idx_shape = op.outputs[1].shape
  lifted_idx_shape = idx_shape + [1]
  flat_shape_len = functools.reduce(operator.mul, idx_shape)
  rank = idx_shape.rank
  reduction_dim = op.get_attr("reduction_dimension")
  if reduction_dim < 0:
    reduction_dim = rank + reduction_dim

  def GetLiftedIdx(d):
    if d == reduction_dim:
      return array_ops.reshape(op.outputs[1], lifted_idx_shape)
    iota_len = idx_shape[d]
    iota_shape = list(itertools.repeat(1, rank + 1))
    iota_shape[d] = iota_len
    iota = array_ops.reshape(math_ops.range(iota_len), iota_shape)
    return array_ops.broadcast_to(iota, lifted_idx_shape)

  lifted_idx = array_ops.concat(
      list(GetLiftedIdx(d) for d in range(rank)), axis=rank)
  flat_idx = array_ops.reshape(lifted_idx, [flat_shape_len, rank])
  flat_grad = array_ops.reshape(grad, [flat_shape_len])
  return array_ops.scatter_nd(flat_idx, flat_grad, op.inputs[0].shape)


@ops.RegisterGradient("NthElement")
def _NthElementGrad(op, grad):
  """Return the gradients for NthElement.

  Args:
    op: The NthElementOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the NthElementOp

  Returns:
    A list of two tensors, the first being the gradient w.r.t. the input,
    the second being the gradient w.r.t. the N (None).
  """
  input = op.inputs[0]  # pylint: disable=redefined-builtin
  output = op.outputs[0]

  # Compute the number of elements that are equal to output in each reduction
  # dimension. If there are multiple such elements then the gradient will be
  # divided between them.
  indicators = math_ops.cast(
      math_ops.equal(array_ops.expand_dims(output, -1), input), grad.dtype)

  grad = array_ops.expand_dims(grad, -1)
  num_selected = array_ops.expand_dims(math_ops.reduce_sum(indicators, -1), -1)

  return [math_ops.divide(indicators, num_selected) * grad, None]

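# Tie-handling example for _NthElementGrad: for an input row [1., 3., 3.] whose
# n-th element is 3., `indicators` is [0., 1., 1.] and `num_selected` is 2., so
# each of the two tied positions receives grad / 2 and the remaining position
# receives 0.
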

def _MeanAggregator(inputs, segments):
  """Replaces each segment with its mean along the last axis.

  Specifically, each value in the `inputs` tensor gets replaced by the mean
  value computed from the values that belong to the same segment.

  Args:
   inputs: A 2-tensor. Aggregation is done over dimension 1.
   segments: A 2-tensor, same shape as `inputs`.

  Returns:
    The result, same shape and type as `inputs`.
  """
  result = []
  for inputs_i, segments_i in zip(
      array_ops.split(inputs, inputs.shape[0]),
      array_ops.split(segments, segments.shape[0])):
    # Note that we do not use tf.math.segment_mean, as it has no TPU support.
    means_i = math_ops.unsorted_segment_mean(
        inputs_i, segments_i, num_segments=math_ops.reduce_max(segments_i) + 1)
    result.append(
        array_ops.reshape(array_ops.gather(means_i, segments_i), [-1]))
  return array_ops.stack(result, axis=0)

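# Worked example for _MeanAggregator: with inputs = [[1., 2., 3., 4.]] and
# segments = [[0, 0, 1, 1]], the segment means are [1.5, 3.5], so the result is
# [[1.5, 1.5, 3.5, 3.5]].
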

# We have to register the gradients for these ops so that tensorflow will know
# how to differentiate them.
@ops.RegisterGradient("IsotonicRegression")
def _IsotonicRegressionGrad(op, grad_output, grad_segments):
  """Gradient for the isotonic regression function.

  Args:
    op: The IsotonicRegression tensorflow op.
    grad_output: Tensor of incoming gradients with respect to the output.
    grad_segments: Tensor of incoming gradients with respect to the segments.

  Returns:
    A tensor, same size as `grad_output` with the gradient with respect to
    the input.
  """
  del grad_segments  # Discrete, non-differentiable.
  segments = op.outputs[1]
  return _MeanAggregator(grad_output, segments)
