# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in nn_ops.py."""

import functools
import itertools
import operator

from tensorflow.python.eager import backprop
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops


@ops.RegisterGradient("Conv2DBackpropInput")
def _Conv2DBackpropInputGrad(op, grad):
  """The derivatives for deconvolution.

  Args:
    op: the Deconvolution op.
    grad: the tensor representing the gradient w.r.t. the output

  Returns:
    the gradients w.r.t. the input and the filter
  """
  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. See _Conv2DGrad.
  return [
      None,
      gen_nn_ops.conv2d_backprop_filter(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode()),
      gen_nn_ops.conv2d(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode())
  ]


@ops.RegisterGradient("Conv2DBackpropFilter")
def _Conv2DBackpropFilterGrad(op, grad):
  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. See _Conv2DGrad.
  return [
      gen_nn_ops.conv2d_backprop_input(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode()), None,
      gen_nn_ops.conv2d(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode())
  ]


@ops.RegisterGradient("DepthwiseConv2dNativeBackpropInput")
def _DepthwiseConv2dNativeBackpropInputGrad(op, grad):
  """The derivatives for deconvolution.

  Args:
    op: the Deconvolution op.
    grad: the tensor representing the gradient w.r.t. the output

  Returns:
    the gradients w.r.t. the input and the filter
  """
  return [
      None,
      gen_nn_ops.depthwise_conv2d_native_backprop_filter(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format")),
      gen_nn_ops.depthwise_conv2d_native(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("DepthwiseConv2dNativeBackpropFilter")
def _DepthwiseConv2dNativeBackpropFilterGrad(op, grad):
  return [
      gen_nn_ops.depthwise_conv2d_native_backprop_input(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format")), None,
      gen_nn_ops.depthwise_conv2d_native(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("Conv3D")
def _Conv3DGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  return [
      nn_ops.conv3d_backprop_input_v2(
          array_ops.shape(op.inputs[0]),
          op.inputs[1],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format),
      nn_ops.conv3d_backprop_filter_v2(
          op.inputs[0],
          array_ops.shape(op.inputs[1]),
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("Conv3DBackpropInputV2")
def _Conv3DBackpropInputGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  return [
      None,
      nn_ops.conv3d_backprop_filter_v2(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format),
      nn_ops.conv3d(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("Conv3DBackpropFilterV2")
def _Conv3DBackpropFilterGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  return [
      nn_ops.conv3d_backprop_input_v2(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format), None,
      nn_ops.conv3d(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("AvgPool3D")
def _AvgPool3DGrad(op, grad):
  return gen_nn_ops.avg_pool3d_grad(
      array_ops.shape(op.inputs[0]),
      grad,
      ksize=op.get_attr("ksize"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())


@ops.RegisterGradient("AvgPool3DGrad")
def _AvgPool3DGradGrad(op, grad):
  return (array_ops.stop_gradient(op.inputs[0]),
          gen_nn_ops.avg_pool3d(
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("MaxPool3D")
def _MaxPool3DGrad(op, grad):
  return gen_nn_ops.max_pool3d_grad(
      op.inputs[0],
      op.outputs[0],
      grad,
      ksize=op.get_attr("ksize"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())


@ops.RegisterGradient("MaxPool3DGrad")
def _MaxPool3DGradGrad(op, grad):
  return (array_ops.zeros_like(op.inputs[0]),
          array_ops.zeros_like(op.inputs[1]),
          gen_nn_ops.max_pool3d_grad_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("MaxPool3DGradGrad")
def _MaxPool3DGradGradGrad(op, grad):
  return (array_ops.zeros_like(op.inputs[0]),
          array_ops.zeros_like(op.inputs[1]),
          gen_nn_ops.max_pool3d_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("Softmax")
def _SoftmaxGrad(op, grad_softmax):
  """The derivative of the softmax nonlinearity.

  We assume that probs is of shape [batch_size, dim].
  The formula for dsoftmax / dx is (diag(softmax) - softmax * softmax').
  This matrix is diagonal minus a rank-one matrix, so it is easy to implement
  as follows:

    grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax

  Args:
    op: the Softmax op.
    grad_softmax: the tensor representing the gradient w.r.t. the softmax
      output.

  Returns:
    gradient w.r.t. the input to the softmax
  """
  softmax = op.outputs[0]
  sum_channels = math_ops.reduce_sum(grad_softmax * softmax, -1, keepdims=True)
  return (grad_softmax - sum_channels) * softmax


@ops.RegisterGradient("LogSoftmax")
def _LogSoftmaxGrad(op, grad):
  """The gradient for log_softmax.

    log_softmax = input - log(sum(exp(input)))
    dlog_softmax/dinput = diag - softmax(input)

  Args:
    op: The log softmax op.
    grad: The tensor representing the gradient w.r.t. the output.

  Returns:
    The gradients w.r.t. the input.
  """
  softmax = math_ops.exp(op.outputs[0])
  return grad - math_ops.reduce_sum(grad, -1, keepdims=True) * softmax
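

# The closed-form softmax gradient above can be sanity-checked against
# autodiff of the forward op. The helper below is only an illustrative sketch
# (a hypothetical name, never called by this module); it returns the maximum
# absolute difference between the two, which should be ~0.
def _softmax_grad_formula_check():
  x = array_ops.reshape(
      math_ops.cast(math_ops.range(12), dtypes.float32), [3, 4]) / 10.0
  upstream = array_ops.reshape(
      math_ops.cast(math_ops.range(1, 13), dtypes.float32), [3, 4]) * 0.1
  with backprop.GradientTape() as tape:
    tape.watch(x)
    y = nn_ops.softmax(x)
    # Contracting with a fixed upstream tensor makes tape.gradient compute the
    # same vector-Jacobian product that _SoftmaxGrad receives.
    loss = math_ops.reduce_sum(y * upstream)
  autodiff = tape.gradient(loss, x)
  softmax = nn_ops.softmax(x)
  formula = (upstream - math_ops.reduce_sum(
      upstream * softmax, -1, keepdims=True)) * softmax
  return math_ops.reduce_max(math_ops.abs(autodiff - formula))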


@ops.RegisterGradient("BiasAdd")
def _BiasAddGrad(op, received_grad):
  """Return the gradients for the 2 inputs of bias_op.

  The first input of the bias op is the tensor t, and its gradient is
  just the gradient the bias op received.

  The second input of the bias op is the bias vector, which has one fewer
  dimension than "received_grad" (the batch dimension). Its gradient is the
  received gradient summed on the batch dimension, which is the first dimension.

  Args:
    op: The BiasOp for which we need to generate gradients.
    received_grad: Tensor. The gradients passed to the BiasOp.

  Returns:
    Two tensors, the first one for the "tensor" input of the BiasOp,
    the second one for the "bias" input of the BiasOp.
  """
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None
  return (received_grad,
          gen_nn_ops.bias_add_grad(
              out_backprop=received_grad, data_format=data_format))


@ops.RegisterGradient("BiasAddGrad")
def _BiasAddGradGrad(op, received_grad):
  """Gradient for the BiasAddGrad op.

  Args:
    op: BiasAddGrad op for which we are calculating gradients.
    received_grad: The gradients passed to the BiasAddGrad op.

  Returns:
    A single gradient Tensor for the input to BiasAddGrad (which
    is the gradient of the bias term in BiasAdd)
  """

  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None

  shape = array_ops.shape(op.inputs[0])
  bias_shape = array_ops.shape(received_grad)

  if data_format == b"NCHW":
    expanded_shape = array_ops.concat([
        array_ops.ones_like(shape[:1]), bias_shape,
        array_ops.ones_like(shape[2:])
    ], 0)
    tile_mults = array_ops.concat([shape[:1], [1], shape[2:]], 0)
  else:
    expanded_shape = array_ops.concat(
        [array_ops.ones_like(shape[:-1]), bias_shape], 0)
    tile_mults = array_ops.concat([shape[:-1], [1]], 0)

  expanded_grad = array_ops.reshape(received_grad, expanded_shape)
  return array_ops.tile(expanded_grad, tile_mults)


@ops.RegisterGradient("BiasAddV1")
def _BiasAddGradV1(unused_bias_op, received_grad):
  """Return the gradients for the 2 inputs of bias_op.

  The first input of unused_bias_op is the tensor t, and its gradient is
  just the gradient the unused_bias_op received.

  The second input of unused_bias_op is the bias vector, which has one fewer
  dimension than "received_grad" (the batch dimension). Its gradient is the
  received gradient summed on the batch dimension, which is the first dimension.

  Args:
    unused_bias_op: The BiasOp for which we need to generate gradients.
    received_grad: Tensor. The gradients passed to the BiasOp.

  Returns:
    Two tensors, the first one for the "tensor" input of the BiasOp,
    the second one for the "bias" input of the BiasOp.
  """
  reduction_dim_tensor = math_ops.range(array_ops.rank(received_grad) - 1)
  return (received_grad, math_ops.reduce_sum(received_grad,
                                             reduction_dim_tensor))


@ops.RegisterGradient("Relu")
def _ReluGrad(op, grad):
  return gen_nn_ops.relu_grad(grad, op.outputs[0])


@ops.RegisterGradient("EluGrad")
def _EluGradGrad(op, grad):
  elu_x = op.inputs[1]
  return (gen_nn_ops.elu_grad(grad, elu_x),
          array_ops.where(
              elu_x < 0, grad * op.inputs[0], array_ops.zeros_like(elu_x)))


@ops.RegisterGradient("SeluGrad")
def _SeluGradGrad(op, grad):
  selu_x = op.inputs[1]
  return (gen_nn_ops.selu_grad(grad, selu_x),
          array_ops.where(
              selu_x < 0., grad * op.inputs[0], array_ops.zeros_like(selu_x)))


@ops.RegisterGradient("Relu6")
def _Relu6Grad(op, grad):
  return gen_nn_ops.relu6_grad(grad, op.outputs[0])


@ops.RegisterGradient("Relu6Grad")
def _Relu6GradGrad(op, grad):
  x = op.inputs[1]
  return (gen_nn_ops.relu6_grad(grad, x), array_ops.zeros_like(x))


@ops.RegisterGradient("LeakyRelu")
def _LeakyReluGrad(op, grad):
  x = op.inputs[0]
  alpha = op.get_attr("alpha")
  return gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha)


@ops.RegisterGradient("LeakyReluGrad")
def _LeakyReluGradGrad(op, grad):
  x = op.inputs[1]
  alpha = op.get_attr("alpha")
  return (gen_nn_ops.leaky_relu_grad(grad, x,
                                     alpha=alpha), array_ops.zeros_like(x))


@ops.RegisterGradient("Elu")
def _EluGrad(op, grad):
  return gen_nn_ops.elu_grad(grad, op.outputs[0])


@ops.RegisterGradient("Selu")
def _SeluGrad(op, grad):
  return gen_nn_ops.selu_grad(grad, op.outputs[0])


@ops.RegisterGradient("Softplus")
def _SoftplusGrad(op, grad):
  return grad * math_ops.sigmoid(op.inputs[0])


@ops.RegisterGradient("SoftplusGrad")
def _SoftplusGradGrad(op, grad):
  # Let:
  #   y = tf.nn.softplus(x)
  #   dx = gen_nn_ops.softplus_grad(dy, x) = dy / (1 + exp(-x))
  # This op computes (ddy, d2x) from op.inputs == [dy, x] and grad == ddx.
  dy, x = op.inputs
  with ops.control_dependencies([grad]):
    ddy = gen_nn_ops.softplus_grad(grad, x)
    d2x = grad * dy / (math_ops.exp(-x) + 2.0 + math_ops.exp(x))
    return (ddy, d2x)


@ops.RegisterGradient("Softsign")
def _SoftsignGrad(op, grad):
  return gen_nn_ops.softsign_grad(grad, op.inputs[0])


@ops.RegisterGradient("ReluGrad")
def _ReluGradGrad(op, grad):
  x = op.inputs[1]
  return (gen_nn_ops.relu_grad(grad, x), array_ops.zeros_like(x))


def _BroadcastMul(vec, mat):
  """Multiply after broadcasting vec to match dimensions of mat.

  Args:
    vec: A 1-D tensor of dimension [D0]
    mat: A 2-D tensor of dimension [D0, D1]

  Returns:
    A tensor of dimension [D0, D1], the result of vec * mat
  """
  # Reshape vec to [D0, 1]
  vec = array_ops.expand_dims(vec, -1)
  return vec * mat
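

# A small illustrative note on _BroadcastMul (the hypothetical helper below is
# never called by this module): it scales every row i of `mat` by `vec[i]`.
# With the values used here, vec is [0.5, 1.0] and mat is
# [[1., 2., 3.], [4., 5., 6.]], so the result is [[0.5, 1.0, 1.5], [4., 5., 6.]].
def _broadcast_mul_example():
  vec = math_ops.cast(math_ops.range(1, 3), dtypes.float32) / 2.0
  mat = array_ops.reshape(
      math_ops.cast(math_ops.range(1, 7), dtypes.float32), [2, 3])
  return _BroadcastMul(vec, mat)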


@ops.RegisterGradient("SoftmaxCrossEntropyWithLogits")
def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
  """Gradient function for SoftmaxCrossEntropyWithLogits."""
  # grad_loss is the backprop for cost, and we multiply it with the softmax
  # gradient (which is output[1]).
  # grad_grad is the backprop for the softmax gradient.
  #
  # Second derivative is just softmax derivative w.r.t. logits.
  softmax_grad = op.outputs[1]
  grad = _BroadcastMul(grad_loss, softmax_grad)

  logits = op.inputs[0]
  if (grad_grad is not None and
      not getattr(grad_grad, "_is_zeros_tensor", False)):
    softmax = nn_ops.softmax(logits)

    grad += ((grad_grad - array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(grad_grad, 1),
            array_ops.expand_dims(softmax, 2)),
        axis=1)) * softmax)

  return grad, _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))  # pylint: disable=invalid-unary-operand-type


@ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
  """Gradient function for SparseSoftmaxCrossEntropyWithLogits."""
  # grad_loss is the backprop for cost, and we multiply it with the softmax
  # gradient (which is output[1]).
  # grad_grad is the backprop for the softmax gradient.
  # There is no gradient for the labels.
  #
  # Second derivative is just softmax derivative w.r.t. logits.
  softmax_grad = op.outputs[1]
  grad = _BroadcastMul(grad_loss, softmax_grad)

  logits = op.inputs[0]
  if (grad_grad is not None and
      not getattr(grad_grad, "_is_zeros_tensor", False)):
    softmax = nn_ops.softmax(logits)

    grad += ((grad_grad - array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(grad_grad, 1),
            array_ops.expand_dims(softmax, 2)),
        axis=1)) * softmax)

  return grad, None
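

# The softmax gradient stored in output[1] of the cross-entropy op is
# softmax(logits) - labels, so the first-order gradient above reduces to the
# familiar closed form. The helper below is only an illustrative sketch
# (a hypothetical name, never called by this module) comparing autodiff of the
# public op against that closed form; it should return ~0.
def _softmax_xent_grad_check():
  logits = array_ops.reshape(
      math_ops.cast(math_ops.range(6), dtypes.float32), [2, 3]) / 4.0
  labels = nn_ops.softmax(logits * 2.0)  # any valid probability distribution
  with backprop.GradientTape() as tape:
    tape.watch(logits)
    loss = math_ops.reduce_sum(
        nn_ops.softmax_cross_entropy_with_logits_v2(
            labels=labels, logits=logits))
  autodiff = tape.gradient(loss, logits)
  closed_form = nn_ops.softmax(logits) - labels
  return math_ops.reduce_max(math_ops.abs(autodiff - closed_form))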


@ops.RegisterGradient("Conv2D")
def _Conv2DGrad(op, grad):
  """Gradient function for Conv2D."""
  dilations = op.get_attr("dilations")
  strides = op.get_attr("strides")
  padding = op.get_attr("padding")
  explicit_paddings = op.get_attr("explicit_paddings")
  use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu")
  data_format = op.get_attr("data_format")
  shape_0, shape_1 = array_ops.shape_n([op.inputs[0], op.inputs[1]])

  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. gen_nn_ops functions take
  # an `explicit_paddings` parameter, but nn_ops functions do not. So if we
  # were to use the nn_ops functions, we would have to convert `padding` and
  # `explicit_paddings` into a single `padding` parameter, increasing overhead
  # in Eager mode.
  return [
      gen_nn_ops.conv2d_backprop_input(
          shape_0,
          op.inputs[1],
          grad,
          dilations=dilations,
          strides=strides,
          padding=padding,
          explicit_paddings=explicit_paddings,
          use_cudnn_on_gpu=use_cudnn_on_gpu,
          data_format=data_format),
      gen_nn_ops.conv2d_backprop_filter(
          op.inputs[0],
          shape_1,
          grad,
          dilations=dilations,
          strides=strides,
          padding=padding,
          explicit_paddings=explicit_paddings,
          use_cudnn_on_gpu=use_cudnn_on_gpu,
          data_format=data_format)
  ]


@ops.RegisterGradient("DepthwiseConv2dNative")
def _DepthwiseConv2dNativeGrad(op, grad):
  return [
      gen_nn_ops.depthwise_conv2d_native_backprop_input(
          array_ops.shape(op.inputs[0]),
          op.inputs[1],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format")),
      gen_nn_ops.depthwise_conv2d_native_backprop_filter(
          op.inputs[0],
          array_ops.shape(op.inputs[1]),
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("Dilation2D")
def _Dilation2DGrad(op, grad):
  return [
      nn_ops.dilation2d_backprop_input(op.inputs[0], op.inputs[1], grad,
                                       op.get_attr("strides"),
                                       op.get_attr("rates"),
                                       op.get_attr("padding")),
      nn_ops.dilation2d_backprop_filter(op.inputs[0], op.inputs[1], grad,
                                        op.get_attr("strides"),
                                        op.get_attr("rates"),
                                        op.get_attr("padding"))
  ]


@ops.RegisterGradient("LRN")
def _LRNGrad(op, grad):
  depth_radius = op.get_attr("depth_radius")
  bias = op.get_attr("bias")
  alpha = op.get_attr("alpha")
  beta = op.get_attr("beta")
  return [
      gen_nn_ops.lrn_grad(grad, op.inputs[0], op.outputs[0], depth_radius, bias,
                          alpha, beta)
  ]


@ops.RegisterGradient("AvgPool")
def _AvgPoolGrad(op, grad):
  return gen_nn_ops.avg_pool_grad(
      array_ops.shape(op.inputs[0]),
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      op.get_attr("padding"),
      data_format=op.get_attr("data_format"))


@ops.RegisterGradient("AvgPoolGrad")
def _AvgPoolGradGrad(op, grad):
  return (array_ops.stop_gradient(op.inputs[0]),
          gen_nn_ops.avg_pool(
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              op.get_attr("padding"),
              data_format=op.get_attr("data_format")))


@ops.RegisterGradient("MaxPool")
def _MaxPoolGrad(op, grad):
  return gen_nn_ops.max_pool_grad(
      op.inputs[0],
      op.outputs[0],
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      explicit_paddings=op.get_attr("explicit_paddings"),
      data_format=op.get_attr("data_format"))


@ops.RegisterGradient("MaxPoolV2")
def _MaxPoolGradV2(op, grad):
  ksize = op.inputs[1]
  strides = op.inputs[2]
  return gen_nn_ops.max_pool_grad_v2(
      op.inputs[0],
      op.outputs[0],
      grad,
      ksize,
      strides,
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format")), None, None


@ops.RegisterGradient("MaxPoolWithArgmax")
def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
  del unused_argmax_grad
  return gen_nn_ops.max_pool_grad_with_argmax(
      op.inputs[0],
      grad,
      op.outputs[1],
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      include_batch_in_index=op.get_attr("include_batch_in_index"))


@ops.RegisterGradient("MaxPoolGrad")
def _MaxPoolGradGrad(op, grad):
  return (array_ops.zeros_like(op.inputs[0]),
          array_ops.zeros_like(op.inputs[1]),
          gen_nn_ops.max_pool_grad_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")))


@ops.RegisterGradient("MaxPoolGradV2")
def _MaxPoolGradGradV2(op, grad):
  ksize = op.inputs[3]
  strides = op.inputs[4]
  return (array_ops.zeros_like(op.inputs[0]),
          array_ops.zeros_like(op.inputs[1]),
          gen_nn_ops.max_pool_grad_grad_v2(
              op.inputs[0],
              op.inputs[1],
              grad,
              ksize,
              strides,
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")), None, None)


@ops.RegisterGradient("MaxPoolGradGrad")
def _MaxPoolGradGradGrad(op, grad):
  return (array_ops.zeros_like(op.inputs[0]),
          array_ops.zeros_like(op.inputs[1]),
          gen_nn_ops.max_pool_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")))


@ops.RegisterGradient("FractionalMaxPool")
def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalMaxPool.

  Since FractionalMaxPool has three outputs, there are three gradients passed
  in, one for each of the outputs. Only the first one is useful; the other two
  gradients are empty.

  Args:
    op: The FractionalMaxPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalMaxPool op.
  """
  return gen_nn_ops.fractional_max_pool_grad(
      op.inputs[0], op.outputs[0], grad_0, op.outputs[1], op.outputs[2],
      op.get_attr("overlapping"))


@ops.RegisterGradient("FractionalAvgPool")
def _FractionalAvgPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalAvgPool.

  Since FractionalAvgPool has three outputs, there are three gradients passed
  in, one for each of the outputs. Only the first one is useful; the other two
  gradients are empty.

  Args:
    op: The FractionalAvgPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalAvgPool op.
  """
  return gen_nn_ops.fractional_avg_pool_grad(op.inputs[0].get_shape(), grad_0,
                                             op.outputs[1], op.outputs[2],
                                             op.get_attr("overlapping"))


@ops.RegisterGradient("BatchNormWithGlobalNormalization")
def _BatchNormWithGlobalNormalizationGrad(op, grad):
  """Return the gradients for the 5 inputs of BatchNormWithGlobalNormalization.

  We do not backprop anything for the mean and var intentionally as they are
  not being trained with backprop in the operation.

  Args:
    op: The BatchNormOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the BatchNormOp.

  Returns:
    dx: Backprop for input, which is (grad * (g * rsqrt(v + epsilon)))
    dm: Backprop for mean, which is
        sum_over_rest(grad * g) * (-1 / rsqrt(v + epsilon))
    dv: Backprop for variance, which is
        sum_over_rest(grad * g * (x - m)) * (-1/2) * (v + epsilon) ^ (-3/2)
    db: Backprop for beta, which is grad reduced in all except the
        last dimension.
    dg: Backprop for gamma, which is (grad * ((x - m) * rsqrt(v + epsilon)))
  """
  dx, dm, dv, db, dg = gen_nn_ops.batch_norm_with_global_normalization_grad(
      op.inputs[0], op.inputs[1], op.inputs[2], op.inputs[4], grad,
      op.get_attr("variance_epsilon"), op.get_attr("scale_after_normalization"))
  return dx, dm, dv, db, dg


def _BaseFusedBatchNormGrad(op, version, *grad):
  """Return the gradients for the 3 inputs of BatchNorm.

  Args:
    op: The BatchNormOp for which we need to compute gradients.
    version: Integer indicating which version to use of the fused batch
      norm gradient.
    *grad: An argument list for tensors of gradients wrt the outputs
      with grad[0] as grad_y.

  Returns:
    grad_x: gradient for x, which is scale * rsqrt(variance + epsilon) *
            [grad_y - mean(grad_y) - (x - mean(x)) *
            mean(grad_y * (x - mean(x))) / (variance + epsilon)]
            in training mode; grad_y * scale * rsqrt(pop_variance + epsilon)
            in freeze mode.

    grad_scale: gradient for scale, which is sum(grad_y * (x - mean(x)) *
                rsqrt(variance + epsilon)) in training mode;
                sum(grad_y * (x - pop_mean) * rsqrt(pop_variance + epsilon))
                in freeze mode.

    grad_offset: gradient for offset, which is sum(grad_y) in training mode;
                 sum(grad_y) in freeze mode.
  """
  x = op.inputs[0]
  grad_y = grad[0]
  scale = op.inputs[1]
  epsilon = op.get_attr("epsilon")
  data_format = op.get_attr("data_format")
  is_training = op.get_attr("is_training")
  if version == 2:
    grad_fun = gen_nn_ops.fused_batch_norm_grad_v3
  elif version == 1:
    grad_fun = gen_nn_ops.fused_batch_norm_grad_v2
  else:
    grad_fun = gen_nn_ops.fused_batch_norm_grad
  if is_training:
    args = {
        "y_backprop": grad_y,
        "x": x,
        "scale": scale,
        "reserve_space_1": op.outputs[3],
        "reserve_space_2": op.outputs[4],
        "epsilon": epsilon,
        "data_format": data_format,
        "is_training": is_training
    }
    if version == 2:
      args["reserve_space_3"] = op.outputs[5]
    dx, dscale, doffset, _, _ = grad_fun(**args)
  else:
    pop_mean = op.inputs[3]
    pop_var = op.inputs[4]
    if data_format == b"NCHW":
      x = array_ops.transpose(x, [0, 2, 3, 1])
      grad_y = array_ops.transpose(grad_y, [0, 2, 3, 1])
    elif data_format == b"NCDHW":
      x = array_ops.transpose(x, [0, 2, 3, 4, 1])
      grad_y = array_ops.transpose(grad_y, [0, 2, 3, 4, 1])
    target_data_format = ("NHWC" if data_format in (b"NCHW",
                                                    b"NHWC") else "NDHWC")
    args = {
        "y_backprop": grad_y,
        "x": x,
        "scale": scale,
        "reserve_space_1": pop_mean,
        "reserve_space_2": pop_var,
        "epsilon": epsilon,
        "data_format": target_data_format,
        "is_training": is_training
    }
    if version == 2:
      args["reserve_space_3"] = op.outputs[5]
    dx, dscale, doffset, _, _ = grad_fun(**args)
    if data_format == b"NCHW":
      dx = array_ops.transpose(dx, [0, 3, 1, 2])
    elif data_format == b"NCDHW":
      dx = array_ops.transpose(dx, [0, 4, 1, 2, 3])
  return dx, dscale, doffset, None, None
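

# The training-mode grad_x formula quoted in the docstring above is the
# standard batch-norm backward pass. The helper below is only an illustrative
# sketch (a hypothetical name, never called by this module): it recomputes
# batch normalization with plain ops and compares the autodiff gradient
# against the closed-form expression from the docstring.
def _batch_norm_grad_formula_check(epsilon=1e-3):
  x = array_ops.reshape(
      math_ops.cast(math_ops.range(24), dtypes.float32), [2, 2, 2, 3]) / 10.0
  grad_y = array_ops.reshape(
      math_ops.cast(math_ops.range(24, 0, -1), dtypes.float32),
      [2, 2, 2, 3]) / 100.0
  scale = math_ops.cast(math_ops.range(1, 4), dtypes.float32)  # per channel
  reduce_axis = [0, 1, 2]  # NHWC: normalize over batch and spatial dims
  with backprop.GradientTape() as tape:
    tape.watch(x)
    mean = math_ops.reduce_mean(x, reduce_axis, keepdims=True)
    var = math_ops.reduce_mean(
        math_ops.squared_difference(x, mean), reduce_axis, keepdims=True)
    y = scale * (x - mean) * math_ops.rsqrt(var + epsilon)
    # Contract with grad_y so tape.gradient returns the VJP for cotangent
    # grad_y, i.e. the same quantity the fused op's gradient produces.
    loss = math_ops.reduce_sum(y * grad_y)
  autodiff_dx = tape.gradient(loss, x)
  mean_dy = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=True)
  mean_dy_xc = math_ops.reduce_mean(
      grad_y * (x - mean), reduce_axis, keepdims=True)
  formula_dx = scale * math_ops.rsqrt(var + epsilon) * (
      grad_y - mean_dy - (x - mean) * mean_dy_xc / (var + epsilon))
  return math_ops.reduce_max(math_ops.abs(autodiff_dx - formula_dx))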


@ops.RegisterGradient("FusedBatchNorm")
def _FusedBatchNormGrad(op, *grad):
  return _BaseFusedBatchNormGrad(op, 0, *grad)


@ops.RegisterGradient("FusedBatchNormV2")
def _FusedBatchNormV2Grad(op, *grad):
  return _BaseFusedBatchNormGrad(op, 1, *grad)


@ops.RegisterGradient("FusedBatchNormV3")
def _FusedBatchNormV3Grad(op, *grad):
  return _BaseFusedBatchNormGrad(op, 2, *grad)


def _BatchNormGrad(grad_y,
                   x,
                   scale,
                   pop_mean,
                   pop_var,
                   epsilon,
                   data_format,
                   is_training=True):
  """Returns the gradients for the 3 inputs of BatchNorm.

  Args:
    grad_y: A `Tensor` of 4 or 5 dimensions for gradient for y.
    x: A `Tensor` of 4 or 5 dimensions for x.
    scale: A `Tensor` of 1 dimension for scaling.
    pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
      is_training=False.
    pop_var: A `Tensor` of 1 dimension for the population variance. Only used
      when is_training=False.
    epsilon: A small float number added to the variance of x.
    data_format: The data format for input. One of b"NHWC", b"NCHW", b"NDHWC"
      or b"NCDHW".
    is_training: A bool value to indicate the operation is for training
      (default) or inference.

  Returns:
    A tuple (grad_x, grad_scale, grad_offset), where grad_x is the gradient
    for x, grad_scale the gradient for scale, and grad_offset the gradient
    for offset.
  """
  x_dtype = x.dtype.base_dtype
  if x_dtype == dtypes.float16:
    # float16 math is too imprecise, so we do the batch norm gradient
    # computations in float32.
    x = math_ops.cast(x, dtypes.float32)
    grad_y = math_ops.cast(grad_y, dtypes.float32)
  if is_training:
    if data_format == b"NHWC":
      keepdims = False
      reduce_axis = [0, 1, 2]
    elif data_format == b"NDHWC":
      keepdims = False
      reduce_axis = [0, 1, 2, 3]
    elif data_format == b"NCHW":
      keepdims = True
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(scale), 1, 1]
      scale = array_ops.reshape(scale, shape)
    else:
      keepdims = True
      reduce_axis = [0, 2, 3, 4]
      shape = [1, array_ops.size(scale), 1, 1, 1]
      scale = array_ops.reshape(scale, shape)
    mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
    mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
    var_x = math_ops.reduce_mean(
        math_ops.squared_difference(x, array_ops.stop_gradient(mean_x)),
        reduce_axis,
        keepdims=keepdims)
    grad_y_offset = grad_y - mean_grad_y
    x_offset = x - mean_x
    mean = math_ops.reduce_mean(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    grad_x = scale * math_ops.rsqrt(var_x + epsilon) * (
        grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
    grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    if data_format == b"NCHW" or data_format == b"NCDHW":
      grad_scale = array_ops.squeeze(grad_scale)
    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
  else:
    if data_format == b"NHWC":
      reduce_axis = [0, 1, 2]
    elif data_format == b"NDHWC":
      reduce_axis = [0, 1, 2, 3]
    elif data_format == b"NCHW":
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(pop_mean), 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)
    else:
      reduce_axis = [0, 2, 3, 4]
      shape = [1, array_ops.size(pop_mean), 1, 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)

    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
    grad_scale = math_ops.reduce_sum(
        grad_y * (x - pop_mean) * var_rsqrt, axis=reduce_axis)
    grad_x = grad_y * scale * var_rsqrt
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset


@ops.RegisterGradient("FusedBatchNormGrad")
def _FusedBatchNormGradGrad(op, *grad):
  """Returns the gradients for the 3 inputs of FusedBatchNormGrad.

  Args:
    op: The FusedBatchNormGradOp for which we need to compute gradients.
    *grad: An argument list for tensors of gradients wrt the outputs with
      grad[0] as grad_grad_x, grad[1] as grad_grad_scale, grad[2] as
      grad_grad_offset.

  Returns:
    A tuple (grad_grad_y, grad_x, grad_scale, None, None), where grad_grad_y
    is the gradient for grad_y, grad_x the gradient for x, and grad_scale the
    gradient for scale.
  """
  data_format = op.get_attr("data_format")
  epsilon = op.get_attr("epsilon")
  is_training = op.get_attr("is_training")
  grad_y = op.inputs[0]
  x = op.inputs[1]
  scale = op.inputs[2]
  pop_mean = op.inputs[3]
  pop_var = op.inputs[4]
  grad_grad_x = grad[0]
  grad_grad_scale = grad[1]
  grad_grad_offset = grad[2]
  with backprop.GradientTape() as tape:
    tape.watch(grad_y)
    tape.watch(x)
    tape.watch(scale)
    grad_x, grad_scale, grad_offset = _BatchNormGrad(
        grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training)
    grad_initial = [grad_grad_x, grad_grad_scale, grad_grad_offset]
  grad_grad_y, grad_x, grad_scale = tape.gradient(
      [grad_x, grad_scale, grad_offset], [grad_y, x, scale], grad_initial)
  return grad_grad_y, grad_x, grad_scale, None, None


@ops.RegisterGradient("FusedBatchNormGradV2")
def _FusedBatchNormGradGradV2(op, *grad):
  return _FusedBatchNormGradGrad(op, *grad)


@ops.RegisterGradient("FusedBatchNormGradV3")
def _FusedBatchNormGradGradV3(op, *grad):
  grad_grad_y, grad_x, grad_scale, _, _ = _FusedBatchNormGradGrad(op, *grad)
  return grad_grad_y, grad_x, grad_scale, None, None, None


@ops.RegisterGradient("L2Loss")
def _L2LossGrad(op, grad):
  """Return the gradients for L2Loss.

  Args:
    op: The L2LossOp for which we need to generate gradients.
    grad: Tensor containing a single number.

  Returns:
    The gradient, which is (x * grad).
  """
  return op.inputs[0] * grad


@ops.RegisterGradient("TopK")
@ops.RegisterGradient("TopKV2")
def _TopKGrad(op, grad, _):
  """Return the gradients for TopK.

  Args:
    op: The TopKOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the TopKOp.

  Returns:
    A list of two tensors, the first being the gradient w.r.t. the input of
    TopK, and the second being the gradient w.r.t. the indices (all zero).
  """
  in_shape = array_ops.shape(op.inputs[0])
  ind_shape = array_ops.shape(op.outputs[1])

  # int32 is not supported on GPU, hence the up-casting.
  ind_lastdim = array_ops.gather(
      math_ops.cast(ind_shape, dtypes.int64),
      array_ops.size(ind_shape) - 1)
  # Flatten indices to 2D.
  ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack([-1, ind_lastdim]))

  in_lastdim = array_ops.gather(
      math_ops.cast(in_shape, dtypes.int64),
      array_ops.size(in_shape) - 1)
  outerdim = array_ops.shape(ind_2d)[0]
  # Compute linear indices (flattened to 1D).
  ind = array_ops.reshape(
      ind_2d + math_ops.cast(
          array_ops.expand_dims(
              math_ops.range(0,
                             math_ops.cast(outerdim, dtypes.int64) * in_lastdim,
                             in_lastdim), -1), dtypes.int32), [-1])

  # Substitute grad to appropriate locations and fill the rest with zeros,
  # finally reshaping it to the original input shape.
  return [
      array_ops.reshape(
          array_ops.scatter_nd(
              array_ops.expand_dims(ind, -1), array_ops.reshape(grad, [-1]),
              [math_ops.reduce_prod(in_shape)]), in_shape),
      array_ops.zeros([], dtype=dtypes.int32)
  ]


@ops.RegisterGradient("ApproxTopK")
def _ApproxTopKGradient(op, grad, _):
  """Return the gradients for ApproxTopK.

  Args:
    op: The ApproxTopK op for which we need to generate gradients.
    grad: The gradients for backprop.

  Returns:
    Scattered gradient based on the top-k indices.
  """
  # The code below is to generate the correct index and value mapping for
  # scatter_nd to work properly.
  #
  # We use static evaluations as much as possible to reduce the runtime cost.
  # That said, we use operation.shape instead of array_ops.shape, and
  # functools.reduce(operator.mul, ...) instead of math_ops.reduce_prod.
  idx_shape = op.outputs[1].shape
  lifted_idx_shape = idx_shape + [1]
  flat_shape_len = functools.reduce(operator.mul, idx_shape)
  rank = idx_shape.rank
  reduction_dim = op.get_attr("reduction_dimension")
  if reduction_dim < 0:
    reduction_dim = rank + reduction_dim

  def GetLiftedIdx(d):
    if d == reduction_dim:
      return array_ops.reshape(op.outputs[1], lifted_idx_shape)
    iota_len = idx_shape[d]
    iota_shape = list(itertools.repeat(1, rank + 1))
    iota_shape[d] = iota_len
    iota = array_ops.reshape(math_ops.range(iota_len), iota_shape)
    return array_ops.broadcast_to(iota, lifted_idx_shape)

  lifted_idx = array_ops.concat(
      list(GetLiftedIdx(d) for d in range(rank)), axis=rank)
  flat_idx = array_ops.reshape(lifted_idx, [flat_shape_len, rank])
  flat_grad = array_ops.reshape(grad, [flat_shape_len])
  return array_ops.scatter_nd(flat_idx, flat_grad, op.inputs[0].shape)
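

# The scatter-based routing used by _TopKGrad and _ApproxTopKGradient simply
# sends each incoming gradient entry back to the input position that produced
# the corresponding top-k value. The helper below is only an illustrative
# sketch (a hypothetical name, never called by this module): for the row
# [1, 2, 3] with k=2 and upstream gradients [10, 20] (paired with the sorted
# top values [3, 2]), the returned input gradient should be [[0., 20., 10.]].
def _top_k_grad_example():
  x = array_ops.reshape(
      math_ops.cast(math_ops.range(1, 4), dtypes.float32), [1, 3])
  upstream = array_ops.reshape(
      math_ops.cast(math_ops.range(1, 3), dtypes.float32) * 10.0, [1, 2])
  with backprop.GradientTape() as tape:
    tape.watch(x)
    values, _ = nn_ops.top_k(x, k=2)
    loss = math_ops.reduce_sum(values * upstream)
  return tape.gradient(loss, x)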


@ops.RegisterGradient("NthElement")
def _NthElementGrad(op, grad):
  """Return the gradients for NthElement.

  Args:
    op: The NthElementOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the NthElementOp.

  Returns:
    A list of two tensors, the first being the gradient w.r.t. the input,
    the second being the gradient w.r.t. the N (None).
  """
  input = op.inputs[0]  # pylint: disable=redefined-builtin
  output = op.outputs[0]

  # Compute the number of elements which equal to output in each reduction
  # dimension. If there are multiple elements then the gradient will be
  # divided between them.
  indicators = math_ops.cast(
      math_ops.equal(array_ops.expand_dims(output, -1), input), grad.dtype)

  grad = array_ops.expand_dims(grad, -1)
  num_selected = array_ops.expand_dims(math_ops.reduce_sum(indicators, -1), -1)

  return [math_ops.divide(indicators, num_selected) * grad, None]


def _MeanAggregator(inputs, segments):
  """Replaces each segment with its mean along the last axis.

  Specifically, each value in the `inputs` tensor gets replaced by the mean
  value computed from the values that belong to the same segment.

  Args:
    inputs: A 2-tensor. Aggregation is done over dimension 1.
    segments: A 2-tensor, same shape as `inputs`.

  Returns:
    The result, same shape and type as `inputs`.
  """
  result = []
  for inputs_i, segments_i in zip(
      array_ops.split(inputs, inputs.shape[0]),
      array_ops.split(segments, segments.shape[0])):
    # Note that we do not use tf.math.segment_mean, as it has no TPU support.
    means_i = math_ops.unsorted_segment_mean(
        inputs_i, segments_i, num_segments=math_ops.reduce_max(segments_i) + 1)
    result.append(
        array_ops.reshape(array_ops.gather(means_i, segments_i), [-1]))
  return array_ops.stack(result, axis=0)


# We have to register the gradients for these ops so that tensorflow will know
# how to differentiate them.
@ops.RegisterGradient("IsotonicRegression")
def _IsotonicRegressionGrad(op, grad_output, grad_segments):
  """Gradient for the isotonic regression function.

  Args:
    op: The IsotonicRegression tensorflow op.
    grad_output: Tensor of incoming gradients with respect to the output.
    grad_segments: Tensor of incoming gradients with respect to the segments.

  Returns:
    A tensor, same size as `grad_output`, with the gradient with respect to
    the input.
  """
  del grad_segments  # Discrete, non-differentiable.
  segments = op.outputs[1]
  return _MeanAggregator(grad_output, segments)
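

# A small illustrative sketch (a hypothetical name, never called by this
# module) of how _MeanAggregator spreads the incoming gradient across each
# pooled segment of the isotonic regression output. For inputs
# [[1., 2., 3., 4.]] and segments [[0, 0, 1, 1]], the result is
# [[1.5, 1.5, 3.5, 3.5]].
def _mean_aggregator_example():
  inputs = array_ops.reshape(
      math_ops.cast(math_ops.range(1, 5), dtypes.float32), [1, 4])
  segments = array_ops.reshape(math_ops.range(4) // 2, [1, 4])  # [[0, 0, 1, 1]]
  return _MeanAggregator(inputs, segments)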