# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in linalg_ops.py.

A useful reference for derivative formulas is (Mike Giles, 2008).

Ionescu et al. (2015) provide a detailed derivation of formulas for
backpropagating through spectral layers (SVD and Eig).

References:
  An extended collection of matrix derivative results for
  forward and reverse mode automatic differentiation:
    [Mike Giles, 2008]
    (https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124)
    ([pdf](http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf))
  Matrix Backpropagation for Deep Networks with Structured Layers
    [Ionescu et al., 2015]
    (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.html)
    ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.pdf))
  Training Deep Networks with Structured Layers by Matrix Backpropagation:
    [Ionescu et al., 2015](https://arxiv.org/abs/1509.07838)
    ([pdf](https://arxiv.org/pdf/1509.07838.pdf))
"""
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as _linalg


@ops.RegisterGradient("MatrixInverse")
def _MatrixInverseGrad(op, grad):
  """Gradient for MatrixInverse."""
  ainv = op.outputs[0]
  op_adjoint = op.get_attr("adjoint")
  return -math_ops.matmul(  # pylint: disable=invalid-unary-operand-type
      ainv,
      math_ops.matmul(grad, ainv, adjoint_a=op_adjoint,
                      adjoint_b=not op_adjoint),
      adjoint_a=not op_adjoint)
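
# Illustrative sketch (comments only, not executed): the formula above follows
# from d(A^{-1}) = -A^{-1} dA A^{-1}, so for adjoint=False the VJP is
# grad_A = -A^{-H} @ grad @ A^{-H}. A minimal numerical check, assuming a
# user-level `tf` import (not part of this module):
#
#   a = tf.constant([[2., 1.], [1., 3.]])
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     loss = tf.reduce_sum(tf.linalg.inv(a))
#   # Since the upstream gradient of reduce_sum is a matrix of ones, this
#   # should match -inv(a)^T @ ones @ inv(a)^T for real-valued `a`.
#   print(tape.gradient(loss, a))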


@ops.RegisterGradient("Einsum")
def _EinsumGrad(op, grad):
  """Gradient for Einsum."""
  ellipsis = "..."

  def _GetAxisFromLabel(subscripts, label):
    """Returns the axis (possibly negative) corresponding to a label.

    Returns the axis index of the axis label if it is before an ellipsis (or
    if the ellipsis is not present), and the negative index if it occurs after
    the ellipsis. E.g. the index of `b` in `ab...cd` is `1`, but that of `c`
    is `-2`.

    For multiple occurrences, returns the leftmost one. If not found, returns
    None.

    Args:
      subscripts: A string denoting the einsum subscript (e.g. `ab...cd`)
      label: The single character axis label.
    """
    splits = subscripts.split(ellipsis)
    index = splits[0].find(label)
    if index != -1:
      return index
    if len(splits) < 2:
      return None
    index = splits[1].find(label)
    if index != -1:
      return index - len(splits[1])
    return None

  def _GetBcastSubshape(subscripts):
    """Returns a tuple denoting the slice mapping to ellipsis.

    For a given subscript, returns a tuple (start, end) denoting the start
    axis index and the (negative) end axis index respectively. For any input
    Tensor `x` described by the subscript, `x[start:end]` would be the slice
    represented by the ellipsis. E.g. for `ab...cd` returns `(2, -2)`.

    If ellipsis is not present in `subscripts`, returns `(0, 0)`.

    Args:
      subscripts: A string denoting the einsum subscript.
    """
    start = subscripts.find(ellipsis)
    if start == -1:
      return 0, 0
    remaining = len(subscripts) - (start + len(ellipsis))
    end = -remaining if remaining > 0 else None
    return start, end

  def _GetReducedSubscripts(reduced_label_set, input_shape, subscripts):
    """Returns reduced subscripts and their corresponding dimensions and axes.

    Given a set of axis labels, returns their concatenated subscript, their
    corresponding dimensions from input_shape, and their corresponding axes.
    Note that the concatenated subscript `reduced_subs` may have axis labels
    from `reduced_label_set` in any order. For example, for the reduced label
    set `{b, d}`, subscripts `aabbcd` and input shape `[2,2,5,5,3,4]`, returns
    subscripts `bd`, dimensions `[5,4]` and axes `[2,5]`.

    Args:
      reduced_label_set: Set of axis labels which appear in `subscripts`.
      input_shape: A `Tensor` representing the shape of the einsum operand
        corresponding to `subscripts`.
      subscripts: A string denoting the einsum subscript.

    Returns:
      reduced_subs: Subscripts formed by a concatenation of labels in
        `reduced_label_set`.
      reduced_dims: Dimensions from `input_shape` corresponding to each label
        in `reduced_subs`.
      reduced_axes: Axes described by `subscripts` corresponding to each label
        in `reduced_subs`. If there are multiple occurrences in `subscripts`,
        we consider only the leftmost one.
    """
    # Concatenate the sequence of reduced axis labels.
    reduced_subs = "".join(list(reduced_label_set))
    # Get the axis (may be positive, negative or zero) for each of the reduced
    # labels. If the same label appears multiple times, get the left-most axis.
    reduced_axes = [_GetAxisFromLabel(subscripts, s) for s in reduced_subs]
    # Get the corresponding dimensions for each reduced axis.
    reduced_dims = array_ops.stack([input_shape[ax] for ax in reduced_axes])
    return reduced_subs, reduced_dims, reduced_axes

  def _GetGradReduced(output_grad, output_subs, input_subs, input_shape,
                      reduced_label_set):
    """Returns the gradient wrt input for a unary einsum with reductions.

    Args:
      output_grad: The gradient wrt the output of a unary einsum operation.
      output_subs: The output subscript. (E.g. `ac` for equation `abc->ac`).
      input_subs: The input subscript. (E.g. `abc` for equation `abc->ac`).
      input_shape: A `Tensor` representing the shape of the input operand.
      reduced_label_set: The set of axis labels appearing in `input_subs` but
        not in `output_subs`.
    """
    # Let's say the einsum operation was "aabbcd->ca", where axis labels 'b' and
    # 'd' are reduced with input_shape [2,2,5,5,3,4]. Then obtain the reduced
    # subscripts "bd", corresponding dimensions [5,4] and axes [2,5].
    reduced_subs, reduced_dims, reduced_axes = _GetReducedSubscripts(
        reduced_label_set, input_shape, input_subs)
    # Whether either the input or the output subscripts have a repeated label.
    # This is true for "aabbcd->ca" or "abd->cca" but false for "abcd->ca".
    has_repeated_labels = (
        len(set(input_subs)) + len(set(output_subs)) <
        len(input_subs) + len(output_subs))
    # Compute the input subscripts without the reduced axis labels, e.g. "aac"
    # for the equation "aabbcd->ca".
    input_subs_without_reduced_labels = "".join(
        [s for s in input_subs if s not in reduced_label_set])

    # The gradient wrt the input for the equation "abc->ac" (or, equivalently
    # reduce_sum(..., axis=1)) is just the gradient of the output tiled N times
    # along axis 1, where label 'b' represents a dimension of size N.
    #
    # If we're not dealing with repeated labels, and the non-reduced labels
    # don't need to be transposed, then just tiling is enough and there is no
    # need to call another einsum. For example, tiling is sufficient for
    # "abcd->ac". But for equations like "aabbcd->ac" (generalized traces) or
    # "abc->ca" (transpose), we'd need another einsum operation after tiling.
    if (not has_repeated_labels and
        input_subs_without_reduced_labels == output_subs):
      # Obtain the shape of the output, as if keepdims=True on reduce sum. E.g.
      # for the equation "abcd->ac" with input shape [2,5,3,4], we get the
      # reduced shape [2,1,3,1].
      reduced_shape = math_ops.reduced_shape(
          input_shape, ops.convert_to_tensor(reduced_axes))
      # Reshaping the gradient (wrt "ac") to [2,1,3,1] and broadcasting it to
      # the shape [2,5,3,4] results in the gradient wrt "abcd".
      return array_ops.broadcast_to(
          array_ops.reshape(output_grad, reduced_shape), input_shape)

    # If we *do* have traces or transpose operations, then prepend the extra
    # reduced dimensions to the front. E.g. given the equation "aabbcd->ca" we'd
    # first obtain the VJP for "bdca->ca", and then the VJP for "aabbcd->bdca".
    #
    # Obtain the input shape with reduced dimensions prepended, viz. [5,4,3,2].
    # This is the shape of the intermediate "bdca".
    grad_shape_with_reduced_labels = array_ops.concat(
        [reduced_dims, array_ops.shape(output_grad)], axis=0)
    # Obtain the output shape of the reduction-only equation "bdca->ca" as if
    # keepdims=True; viz. [1,1,3,2]. Since we prepended the reduced labels, we
    # just have to prepend that many 1s to the output shape.
    reduced_shape = array_ops.concat([
        array_ops.ones(len(reduced_label_set), dtype=dtypes.int32),
        array_ops.shape(output_grad)
    ], axis=0)
    # Compute the VJP for the intermediate (viz. "bdca->ca") for which
    # broadcasting is sufficient.
    broadcasted_grad = array_ops.broadcast_to(
        array_ops.reshape(output_grad, reduced_shape),
        grad_shape_with_reduced_labels)
    # Compute the VJP for the final step (viz. "aabbcd->bdca"). We can use
    # einsum with the input and output subscripts reversed (viz. "bdca->aabbcd")
    # since the output axis labels now appear in the input subscripts.
    return gen_linalg_ops.einsum([broadcasted_grad],
                                 "{}->{}".format(reduced_subs + output_subs,
                                                 input_subs))
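
  # Illustrative sketch (comments only, not executed): for the unary equation
  # "abc->ac" with input shape [2, 5, 3], `_GetGradReduced` takes the fast
  # path and simply broadcasts, roughly:
  #
  #   g = array_ops.reshape(output_grad, [2, 1, 3])      # keepdims shape
  #   grad_wrt_input = array_ops.broadcast_to(g, [2, 5, 3])
  #
  # For "aabbcd->ca" it instead broadcasts the gradient to the intermediate
  # "bdca" shape and lets einsum("bdca->aabbcd", ...) scatter it back onto the
  # traced/transposed axes.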

  def _GetGradWrt(output_grad, other_operand, input_shape, input_subs,
                  other_subs, output_subs):
    """Returns the gradient wrt an input operand for a binary einsum.

    This function does not handle (un)broadcasting. This must be done
    separately on the returned gradient.

    Args:
      output_grad: The gradient wrt the output of a binary einsum operation.
      other_operand: The complementary `Tensor` operand, i.e. the one which is
        not the input operand.
      input_shape: A `Tensor` representing the shape of input operand.
      input_subs: The subscripts of the input operand.
      other_subs: The subscripts of the complementary operand.
      output_subs: The output subscripts.
    """
    # Claim: For the einsum operation z = einsum("{eq_x},{eq_y}->{eq_z}", x, y),
    # where the equation involves only Tensor contractions, generalized traces
    # and transposes, the input gradients are given by the vector-jacobian
    # products (VJPs):
    #
    #   grad_wrt_x = einsum("{eq_y},{eq_z}->{eq_x}", y, grad_wrt_z)
    #   grad_wrt_y = einsum("{eq_x},{eq_z}->{eq_y}", x, grad_wrt_z)
    #
    # where grad_wrt_x and grad_wrt_y are the gradients with respect to inputs
    # x and y and grad_wrt_z is the given gradient with respect to output z.
    #
    # Proof: For unary einsum equations involving only transpose ("ij->ji") and
    # traces ("ii->i"), the linear mapping's Jacobian at input x is given
    # by the function itself. We can verify that the linear maps given by the
    # VJPs are einsums with the equations "ji->ij" and "i->ii" respectively,
    # where the latter represents 'un-tracing', i.e. filling the diagonal with
    # the input axis while the non-diagonal entries are zeros.
    # Furthermore, recall that matrix multiplication, which is
    # represented by the equation "ab,bc->ac", has its VJPs given by the
    # einsum equations "ac,bc->ab" and "ab,ac->bc" (see, for example
    # https://math.stackexchange.com/a/2755680). Combined with transposes and
    # traces we can rewrite Tensor contractions as regular matrix
    # multiplication. Since each of these operations has its VJPs described
    # by einsums of the required pattern, the result follows.
    #
    # Accordingly, einsum operations, except for those with reductions, e.g.
    # "abc,cd->ad", have their VJPs defined by:
    #   "{output_subs},{other_subs}->{input_subs}".
    #
    # But if there is a reduction, this would lead to the equation "ad,cd->abc"
    # which is invalid because the reduced axis label 'b' is present in the
    # output but not in any of the inputs. Therefore, we compute the VJP in two
    # steps: first we obtain the VJP for "ac,cd->ad" and then we compute the
    # VJP of "abc->ac" or, equivalently, reduce_sum(..., axis=1).
    #
    # Compute the set of input axis labels which don't appear in either the
    # output subscripts or the other operand's subscript. E.g. the set {'b'} for
    # the equation "abc,cd->ad".
    reduced_label_set = set(input_subs).difference(
        set(output_subs + other_subs + "."))
    # Obtain the input subscripts with the reduced axis labels removed. E.g.
    # "ac" in the above example.
    left_subs = "".join(s for s in input_subs if s not in reduced_label_set)

    # Compute the gradient wrt the input, without accounting for the operation
    # "abc->ac". So, now we have the VJP of the operation "ac,cd->ad".
    grad_reduced = gen_linalg_ops.einsum([output_grad, other_operand],
                                         "{},{}->{}".format(
                                             output_subs, other_subs,
                                             left_subs))
    # If the reduced_label_set is empty, then we already have the gradient
    # wrt the input.
    if not reduced_label_set:
      return grad_reduced
    # Otherwise, we currently have the gradient wrt the output of the reduction
    # operation "abc->ac". Invoke the subroutine for the gradient for unary
    # einsum with reductions.
    return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape,
                           reduced_label_set)
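
  # Illustrative sketch (comments only, not executed): for the contraction
  # z = einsum("ij,jk->ik", x, y) (plain matmul), the rule above gives
  #
  #   grad_x = einsum("ik,jk->ij", grad_z, y)   # i.e. grad_z @ y^T
  #   grad_y = einsum("ij,ik->jk", x, grad_z)   # i.e. x^T @ grad_z
  #
  # which are the familiar matmul VJPs. For "abc,cd->ad", label 'b' is reduced,
  # so `_GetGradWrt` first forms the VJP of "ac,cd->ad" and then hands the
  # result to `_GetGradReduced` for the remaining "abc->ac" reduction.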

  equation = op.get_attr("equation")
  if isinstance(equation, bytes):
    equation = equation.decode()
  input_subs, output_subs = equation.split("->")

  if len(op.inputs) == 1:
    # For the unary einsum z = einsum("{eq_x}->{eq_z}", x), the gradient wrt
    # the input (VJP) is given by the reversed equation:
    #   grad_wrt_x = einsum("{eq_z}->{eq_x}", grad_wrt_z)
    # (See the justification in _GetGradWrt). This is valid unless there are
    # reduced axis labels; i.e. axis labels appearing in the input but not in
    # the output subscripts.
    input_shape = array_ops.shape(op.inputs[0])
    # Find the axis labels which appear only in the input.
    reduced_label_set = set(input_subs).difference(set(output_subs + ellipsis))
    if not reduced_label_set:
      # Return the einsum given by the reversed equation, since we don't have
      # reduced axes.
      return gen_linalg_ops.einsum([grad],
                                   "{}->{}".format(output_subs, input_subs))
    # We do have reduced axes, so we invoke the subroutine for reduced unary
    # einsums.
    return _GetGradReduced(grad, output_subs, input_subs, input_shape,
                           reduced_label_set)

  x_subs, y_subs = input_subs.split(",")
  # Add ellipsis for broadcasted dimensions if any operand does not have it.
  # This is because the equation "...ij,jk->ik" may be valid if the 0th input's
  # batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid
  # because only the output subscripts contain ellipsis.
  if ellipsis in output_subs:
    if ellipsis not in x_subs:
      x_subs += ellipsis
    if ellipsis not in y_subs:
      y_subs += ellipsis

  # Obtain the gradients wrt the inputs x and y, without taking into account
  # the unbroadcasting.
  x, y = op.inputs[0], op.inputs[1]
  if grad.dtype.is_complex:
    x = math_ops.conj(x)
    y = math_ops.conj(y)

  x_shape = array_ops.shape(x)
  y_shape = array_ops.shape(y)
  grad_x = _GetGradWrt(grad, y, x_shape, x_subs, y_subs, output_subs)
  grad_y = _GetGradWrt(grad, x, y_shape, y_subs, x_subs, output_subs)

  if ellipsis not in output_subs:
    # If there is no ellipsis in the output, then we don't need to unbroadcast.
    return grad_x, grad_y

  # Below we handle the case that broadcasting between x and y was necessary,
  # with x and y having possibly different batch shapes.

  # Obtain the range of axes which map to ellipsis. E.g. for subscripts 'ab...c'
  # and a shape of rank 10, the slice [2:-1] denotes the broadcasted axes.
  bx_start, bx_end = _GetBcastSubshape(x_subs)
  by_start, by_end = _GetBcastSubshape(y_subs)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  x_shape_static = x.get_shape()
  y_shape_static = y.get_shape()
  if (x_shape_static.is_fully_defined() and
      y_shape_static.is_fully_defined() and
      x_shape_static[bx_start:bx_end] == y_shape_static[by_start:by_end]):
    return grad_x, grad_y

  # Sum the gradient across the broadcasted axes.
  rx, ry = array_ops.broadcast_gradient_args(x_shape[bx_start:bx_end],
                                             y_shape[by_start:by_end])
  grad_x = array_ops.reshape(
      math_ops.reduce_sum(grad_x, bx_start + rx), x_shape)
  grad_y = array_ops.reshape(
      math_ops.reduce_sum(grad_y, by_start + ry), y_shape)
  return grad_x, grad_y


@ops.RegisterGradient("MatrixDeterminant")
def _MatrixDeterminantGrad(op, grad):
  """Gradient for MatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[0]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(grad * c,
                                  array_ops.concat([array_ops.shape(c), [1, 1]],
                                                   0))
  return multipliers * a_adj_inv
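
# Illustrative sketch (comments only, not executed): the determinant gradient
# above is Jacobi's formula, d det(A) = det(A) * trace(A^{-1} dA), whose VJP is
# grad_A = grad_det * det(A) * A^{-H}. A rough numerical check, assuming a
# user-level `tf` import (not part of this module):
#
#   a = tf.constant([[3., 1.], [2., 4.]])
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     det = tf.linalg.det(a)
#   # Expected: det(a) * inv(a)^T = 10 * [[0.4, -0.2], [-0.1, 0.3]]
#   print(tape.gradient(det, a))  # approximately [[4., -2.], [-1., 3.]]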


@ops.RegisterGradient("MatrixSquareRoot")
def _MatrixSquareRootGrad(op, grad):
  """Gradient for MatrixSquareRoot."""

  # Let A be an m x m square matrix (or batch of matrices)
  # Let R = sqrtm(A)
  # By definition, A = RR
  # Take the differential: dA = d(RR) = RdR + dRR
  # Solve the resulting Sylvester equation for dR

  # Used to find Kronecker products within the Sylvester equation
  def _KroneckerProduct(b1, b2):
    """Computes the Kronecker product of two batches of square matrices."""
    b1_shape = array_ops.shape(b1)
    b2_shape = array_ops.shape(b2)
    b1_order = b1_shape[-1]
    b2_order = b2_shape[-1]

    shape_slice_size = [math_ops.subtract(array_ops.size(b1_shape), 2)]
    shape_slice = array_ops.slice(b1_shape, [0],
                                  shape_slice_size)  # Same for both batches
    b1_reshape_shape = array_ops.concat(
        [shape_slice, [b1_order], [1], [b1_order], [1]], 0)
    b2_reshape_shape = array_ops.concat(
        [shape_slice, [1], [b2_order], [1], [b2_order]], 0)

    b1_reshape = array_ops.reshape(b1, b1_reshape_shape)
    b2_reshape = array_ops.reshape(b2, b2_reshape_shape)

    order_prod = b1_order * b2_order
    kprod_shape = array_ops.concat([shape_slice, [order_prod], [order_prod]], 0)
    return array_ops.reshape(b1_reshape * b2_reshape, kprod_shape)

  sqrtm = op.outputs[0]  # R
  shape = array_ops.shape(sqrtm)
  order = shape[-1]  # m
  matrix_count = math_ops.reduce_prod(shape[0:-2])

  # Get batch of m x m identity matrices
  eye = linalg_ops.eye(order, dtype=sqrtm.dtype)  # m x m identity matrix
  eye_flat = array_ops.reshape(eye, [-1])
  eye_tiled = array_ops.tile(eye_flat, [matrix_count])
  eye_batch = array_ops.reshape(eye_tiled, shape)

  # The transpose of R is taken in the k1 term instead of k2 in
  # order to prevent redundant transposition of R (i.e. (R')' = R)
  sqrtm_transpose = array_ops.matrix_transpose(sqrtm)
  k1 = _KroneckerProduct(eye_batch, sqrtm_transpose)
  k2 = _KroneckerProduct(sqrtm, eye_batch)
  ksum = math_ops.add(k1, k2)

  # Vectorize dA
  shape_slice_size = [math_ops.subtract(array_ops.size(shape), 2)]
  shape_slice = array_ops.slice(shape, [0], shape_slice_size)
  shape_vec_da = array_ops.concat([shape_slice, [order * order], [1]], 0)
  vec_da = array_ops.reshape(array_ops.matrix_transpose(grad), shape_vec_da)

  # Solve for vec(dR)
  vec_dsqrtm = linalg_ops.matrix_solve(ksum, vec_da)

  # Solve for dR by inverse vectorizing vec(dR)
  dsqrtm_transpose = array_ops.reshape(vec_dsqrtm, shape)
  return array_ops.matrix_transpose(dsqrtm_transpose)


@ops.RegisterGradient("LogMatrixDeterminant")
def _LogMatrixDeterminantGrad(op, _, grad_b):
  """Gradient for LogMatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[1]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(
      grad_b, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
  return multipliers * a_adj_inv


@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""

  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(l,
                                                 linalg_ops.eye(
                                                     num_rows,
                                                     batch_shape=batch_shape,
                                                     dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += _linalg.adjoint(grad_a)
  return grad_a * 0.5
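
# Illustrative sketch (comments only, not executed): a rough way to
# sanity-check the Cholesky gradient, assuming a user-level `tf` import (not
# part of this module). The input must be symmetric positive definite:
#
#   x = tf.constant([[2., 1.], [1., 2.]])
#   with tf.GradientTape() as tape:
#     tape.watch(x)
#     loss = tf.reduce_sum(tf.linalg.cholesky(x))
#   print(tape.gradient(loss, x))
#
# The final symmetrization above, (grad_a + adjoint(grad_a)) * 0.5, distributes
# the sensitivity evenly across the symmetric input, so the returned gradient
# is itself symmetric.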


@ops.RegisterGradient("Qr")
def _QrGrad(op, dq, dr):
  """Gradient for Qr."""

  # The methodology is explained in detail in https://arxiv.org/abs/2009.10071
  # QR and LQ Decomposition Matrix Backpropagation Algorithms for
  # Square, Wide, and Deep, Real and Complex, Matrices and Their Software
  # Implementation
  q, r = op.outputs
  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
      r.shape.as_list()[-1] is None):
    raise NotImplementedError("QrGrad not implemented with dynamic shapes. "
                              f"Received r.shape: {r.shape}")
  if (r.shape.dims[-2].value > r.shape.dims[-1].value and
      q.shape.dims[-2].value == q.shape.dims[-1].value):
    raise NotImplementedError("QrGrad not implemented when nrows > ncols "
                              "and full_matrices is true. Received r.shape="
                              f"{r.shape} with nrows={r.shape.dims[-2]} "
                              f"and ncols={r.shape.dims[-1]}.")

  def _TriangularSolve(x, r):
    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
    return _linalg.adjoint(
        linalg_ops.matrix_triangular_solve(
            r, _linalg.adjoint(x), lower=False, adjoint=False))

  def _QrGradSquareAndDeepMatrices(q, r, dq, dr):
    """Gradient for matrix orders num_rows >= num_cols and full_matrices=False."""
    qdq = math_ops.matmul(q, dq, adjoint_a=True)
    qdq_ = qdq - _linalg.adjoint(qdq)
    rdr = math_ops.matmul(r, dr, adjoint_b=True)
    rdr_ = rdr - _linalg.adjoint(rdr)
    tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

    grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
    grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
    ret = grad_a + grad_b

    if q.dtype.is_complex:
      # Need to add a correction to the gradient formula for the complex case.
      m = rdr - _linalg.adjoint(qdq)
      eyem = _linalg.set_diag(array_ops.zeros_like(m), _linalg.diag_part(m))
      correction = eyem - math_ops.cast(math_ops.real(eyem), q.dtype)
      ret = ret + _TriangularSolve(
          math_ops.matmul(q, _linalg.adjoint(correction)), r)

    return ret

  num_rows, num_cols = q.shape.dims[-2].value, r.shape.dims[-1]

  if num_rows >= num_cols:
    return _QrGradSquareAndDeepMatrices(q, r, dq, dr)

  # Partition a = [x, y], r = [u, v] and reduce to the square case
  a = op.inputs[0]
  y = a[..., :, num_rows:]
  u = r[..., :, :num_rows]
  dv = dr[..., :, num_rows:]
  du = dr[..., :, :num_rows]
  dy = math_ops.matmul(q, dv)
  dx = _QrGradSquareAndDeepMatrices(q, u,
                                    dq + math_ops.matmul(y, dv, adjoint_b=True),
                                    du)
  return array_ops.concat([dx, dy], axis=-1)


@ops.RegisterGradient("MatrixSolve")
def _MatrixSolveGrad(op, grad):
  """Gradient for MatrixSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  return (grad_a, grad_b)
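
# Illustrative sketch (comments only): for op = matrix_solve(A, B) with
# adjoint=False, the output is C = A^{-1} B, so
#
#   grad_B = A^{-H} @ grad    (another matrix_solve, with adjoint=True)
#   grad_A = -grad_B @ C^H    (as computed above)
#
# which mirrors the MatrixInverse gradient with the extra right-hand side B.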


@ops.RegisterGradient("MatrixSolveLs")
def _MatrixSolveLsGrad(op, grad):
  """Gradients for MatrixSolveLs."""

  # TODO(rmlarsen): The implementation could be more efficient:
  #   a) Output the Cholesky factorization from the forward op instead of
  #      recomputing it here.
  #   b) Implement a symmetric rank-k update op instead of computing
  #      x*z + transpose(x*z). This pattern occurs in other places in
  #      TensorFlow.

  def _Overdetermined(op, grad):
    """Gradients for the overdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the first
    kind:
      X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
    which solves the least squares problem
      min ||A * X - B||_F^2 + lambda ||X||_F^2.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    x = op.outputs[0]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=True)
    # pylint: enable=protected-access
    # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
    z = linalg_ops.cholesky_solve(chol, grad)
    xzt = math_ops.matmul(x, z, adjoint_b=True)
    zx_sym = xzt + array_ops.matrix_transpose(xzt)
    grad_a = -math_ops.matmul(a, zx_sym) + math_ops.matmul(b, z, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
    grad_b = math_ops.matmul(a, z)
    return (grad_a, grad_b, None)

  def _Underdetermined(op, grad):
    """Gradients for the underdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the
    second kind:
      X = F(A, B) = A * (A*A^T + lambda*I)^{-1} * B
    which (for lambda=0) solves the least squares problem
      min ||X||_F subject to A*X = B.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=False)
    # pylint: enable=protected-access
    grad_b = linalg_ops.cholesky_solve(chol, math_ops.matmul(a, grad))
    # Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
    tmp = linalg_ops.cholesky_solve(chol, b)
    a1 = math_ops.matmul(tmp, a, adjoint_a=True)
    a1 = -math_ops.matmul(grad_b, a1)  # pylint: disable=invalid-unary-operand-type
    a2 = grad - math_ops.matmul(a, grad_b, adjoint_a=True)
    a2 = math_ops.matmul(tmp, a2, adjoint_b=True)
    grad_a = a1 + a2
    return (grad_a, grad_b, None)

  fast = op.get_attr("fast")
  if fast is False:
    raise ValueError("Gradient not defined for fast=False")
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    if matrix_shape[-2] >= matrix_shape[-1]:
      return _Overdetermined(op, grad)
    else:
      return _Underdetermined(op, grad)
  else:
    # We have to defer determining the shape to runtime and use
    # conditional execution of the appropriate graph.
    matrix_shape = array_ops.shape(op.inputs[0])[-2:]
    return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1],
                                 lambda: _Overdetermined(op, grad),
                                 lambda: _Underdetermined(op, grad))


@ops.RegisterGradient("BandedTriangularSolve")
def _BandedTriangularSolveGrad(op, grad):
  """Gradient for BandedTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  num_bands = array_ops.shape(a)[-2]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.banded_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  if lower_a:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(-(num_bands - 1), 0), align="LEFT_RIGHT")
  else:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(0, num_bands - 1), align="LEFT_RIGHT")
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


@ops.RegisterGradient("MatrixTriangularSolve")
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  if lower_a:
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


# To avoid nan in cases with degenerate eigenvalues or
# degenerate/zero singular values in calculations of
# f and s_inv_mat, we introduce a Lorentz broadening.
def _SafeReciprocal(x, epsilon=1E-20):
  return x * math_ops.reciprocal(x * x + epsilon)
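
# Illustrative note: _SafeReciprocal(x) = x / (x^2 + epsilon) behaves like 1/x
# for |x| >> sqrt(epsilon), but smoothly goes to 0 as x -> 0 instead of
# blowing up. E.g. _SafeReciprocal(0.0) == 0.0 while _SafeReciprocal(2.0) is
# approximately 0.5. This is what keeps the `f` matrices below finite when
# eigenvalues or singular values coincide; the gradient is still not
# meaningful in that degenerate case, but NaNs are avoided.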


@ops.RegisterGradient("Eig")
def _EigGrad(op, grad_e, grad_v):
  """Gradient for Eig.

  Based on eq. 4.77 from the paper by
  Christoph Boeddeker et al.
  https://arxiv.org/abs/1701.00392
  See also
  "Computation of eigenvalue and eigenvector derivatives
  for a general complex-valued eigensystem" by Nico van der Aa.
  For now, only the case of distinct eigenvalues is considered.
  """
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      vt = _linalg.adjoint(v)
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          _SafeReciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      f = math_ops.conj(f)
      vgv = math_ops.matmul(vt, grad_v)
      mid = array_ops.matrix_diag(grad_e)
      diag_grad_part = array_ops.matrix_diag(
          array_ops.matrix_diag_part(
              math_ops.cast(math_ops.real(vgv), vgv.dtype)))
      mid += f * (vgv - math_ops.matmul(math_ops.matmul(vt, v), diag_grad_part))
      # vt is formally invertible as long as the original matrix is
      # diagonalizable. However, in practice, vt may be ill-conditioned when
      # the original matrix is close to being non-diagonalizable.
      grad_a = linalg_ops.matrix_solve(vt, math_ops.matmul(mid, vt))
    else:
      _, v = linalg_ops.eig(op.inputs[0])
      vt = _linalg.adjoint(v)
      # vt is formally invertible as long as the original matrix is
      # diagonalizable. However, in practice, vt may be ill-conditioned when
      # the original matrix is close to being non-diagonalizable.
      grad_a = linalg_ops.matrix_solve(
          vt, math_ops.matmul(array_ops.matrix_diag(grad_e), vt))
    return math_ops.cast(grad_a, op.inputs[0].dtype)


@ops.RegisterGradient("SelfAdjointEigV2")
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          _SafeReciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(v,
                               math_ops.matmul(
                                   array_ops.matrix_diag(grad_e),
                                   v,
                                   adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle.
    grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(grad_a,
                                       0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a


@ops.RegisterGradient("Svd")
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file). A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  # The derivation for complex valued SVD can be found in
  # https://re-ra.xyz/misc/complexsvd.pdf or
  # https://giggleliu.github.io/2019/04/02/einsumbp.html
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s = math_ops.cast(grad_s, a.dtype)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
  n = a_shape.dims[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape.dims[-2].value
  n = a_shape.dims[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]
  s = math_ops.cast(s, a.dtype)

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          f"when full_matrices is True. Received: m={m} and n={n} from "
          f"op input={a} with shape={a_shape}.")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to a (k-dimensional) subspace. In practice, this can
    # lead to numerical instability when singular values are close but not
    # exactly equal.

    s_shape = array_ops.shape(s)
    f = array_ops.matrix_set_diag(
        _SafeReciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1, conjugate=True)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if a.dtype.is_complex:
      eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
      l = eye * v_gv
      term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l) - l)
      term3 = 1 / 2. * math_ops.matmul(
          u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))

      grad_a_before_transpose += term3

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(
          grad_a_before_transpose, conjugate=True)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a


def _LeftShift(x):
  """Shifts next-to-last dimension to the left, adding zero on the right."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[0, 1], [0, 0]])], axis=0)
  return array_ops.pad(x[..., 1:, :], pad)


def _RightShift(x):
  """Shifts next-to-last dimension to the right, adding zero on the left."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[1, 0], [0, 0]])], axis=0)
  return array_ops.pad(x[..., :-1, :], pad)
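
# Illustrative note: for a single matrix x = [[1, 2], [3, 4]],
#   _LeftShift(x)  == [[3, 4], [0, 0]]   (rows shifted up, zero row appended)
#   _RightShift(x) == [[0, 0], [1, 2]]   (rows shifted down, zero row prepended)
# These shifts move values between neighbouring rows, which is what the
# tridiagonal gradients below need in order to pair superdiagonal/subdiagonal
# entries with the correct rows of the right-hand side.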


@ops.RegisterGradient("TridiagonalMatMul")
def _TridiagonalMatMulGrad(op, grad):
  """Gradient for TridiagonalMatMul."""
  superdiag_conj = array_ops.matrix_transpose(op.inputs[0], conjugate=True)
  maindiag_conj = array_ops.matrix_transpose(op.inputs[1], conjugate=True)
  subdiag_conj = array_ops.matrix_transpose(op.inputs[2], conjugate=True)
  rhs_conj = math_ops.conj(op.inputs[3])

  superdiag_grad = math_ops.reduce_sum(_LeftShift(rhs_conj) * grad, axis=-1)
  maindiag_grad = math_ops.reduce_sum(rhs_conj * grad, axis=-1)
  subdiag_grad = math_ops.reduce_sum(_RightShift(rhs_conj) * grad, axis=-1)
  rhs_grad = _RightShift(superdiag_conj * grad) + \
      maindiag_conj * grad + _LeftShift(subdiag_conj * grad)

  superdiag_grad = array_ops.expand_dims(superdiag_grad, -2)
  maindiag_grad = array_ops.expand_dims(maindiag_grad, -2)
  subdiag_grad = array_ops.expand_dims(subdiag_grad, -2)

  return superdiag_grad, maindiag_grad, subdiag_grad, rhs_grad


@ops.RegisterGradient("TridiagonalSolve")
def _TridiagonalSolveGrad(op, grad):
  """Gradient for TridiagonalSolve."""
  diags = op.inputs[0]
  x = op.outputs[0]
  partial_pivoting = op.get_attr("partial_pivoting")
  perturb_singular = op.get_attr("perturb_singular")

  # Transposing the matrix within the tridiagonal_solve kernel by interchanging
  # the superdiagonal and subdiagonal wouldn't work on GPU due to a mismatch
  # with the paddings required by the cusparse*gtsv routines, so we construct
  # the transposed matrix in Python instead.
  diags_transposed = _TransposeTridiagonalMatrix(diags)

  grad_rhs = linalg_ops.tridiagonal_solve(
      diags_transposed,
      grad,
      partial_pivoting=partial_pivoting,
      perturb_singular=perturb_singular)
  grad_diags = -_MatmulExtractingThreeDiagonals(grad_rhs, x)  # pylint: disable=invalid-unary-operand-type
  return grad_diags, grad_rhs


def _TransposeTridiagonalMatrix(diags):
  """Transposes a tridiagonal matrix.

  Args:
    diags: the diagonals of the input matrix in the compact form (see
      linalg_ops.tridiagonal_solve).

  Returns:
    Diagonals of the transposed matrix in the compact form.
  """

  diag = diags[..., 1, :]

  if diags.shape.is_fully_defined():
    # For a fully defined tensor we can concat with a tensor of zeros, which is
    # faster than using array_ops.pad().
    zeros = array_ops.zeros(list(diags.shape[:-2]) + [1], dtype=diags.dtype)
    superdiag = array_ops.concat((diags[..., 2, 1:], zeros), axis=-1)
    subdiag = array_ops.concat((zeros, diags[..., 0, :-1]), axis=-1)
  else:
    rank = array_ops.rank(diags)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat((zeros, array_ops.constant([[0, 1]])),
                                     axis=0)
    superdiag = array_ops.pad(diags[..., 2, 1:], superdiag_pad)
    subdiag_pad = array_ops.concat((zeros, array_ops.constant([[1, 0]])),
                                   axis=0)
    subdiag = array_ops.pad(diags[..., 0, :-1], subdiag_pad)
  return array_ops.stack([superdiag, diag, subdiag], axis=-2)
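
# Illustrative note: in the compact form used above, diags has shape
# [..., 3, n] with rows [superdiag, maindiag, subdiag]; the last superdiagonal
# element and the first subdiagonal element are ignored (conventionally zero).
# For example, the matrix
#   [[1, 2, 0],
#    [3, 4, 5],
#    [0, 6, 7]]
# is stored as [[2, 5, 0], [1, 4, 7], [0, 3, 6]], and its transpose comes out
# as [[3, 6, 0], [1, 4, 7], [0, 2, 5]], which is what
# _TransposeTridiagonalMatrix computes by swapping and re-padding the two
# outer diagonals.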


def _MatmulExtractingThreeDiagonals(x, y_tr):
  """Multiplies matrices and extracts three diagonals from the product.

  With sizes M x K and K x M, this function takes O(MK) time and O(M) space,
  whereas using math_ops.matmul and then extracting the diagonals would take
  O(M^2 K) time and O(M^2) space.

  Args:
    x: first matrix
    y_tr: second matrix, transposed

  Returns:
    Diagonals of the product in compact format (see
    linalg_ops.tridiagonal_solve)
  """
  diag = math_ops.reduce_sum(x * y_tr, axis=-1)

  if y_tr.shape.is_fully_defined():
    zeros = array_ops.zeros(
        list(x.shape[:-2]) + [1, x.shape[-1]], dtype=x.dtype)
    superdiag = math_ops.reduce_sum(
        x * array_ops.concat((y_tr[..., 1:, :], zeros), axis=-2), axis=-1)
    subdiag = math_ops.reduce_sum(
        x * array_ops.concat((zeros, y_tr[..., :-1, :]), axis=-2), axis=-1)
  else:
    rank = array_ops.rank(y_tr)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[0, 1], [0, 0]])), axis=0)
    superdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., 1:, :], superdiag_pad), axis=-1)
    subdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[1, 0], [0, 0]])), axis=0)
    subdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., :-1, :], subdiag_pad), axis=-1)
  return array_ops.stack([superdiag, diag, subdiag], axis=-2)
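
# Illustrative note: up to the ignored padding entries, the result above should
# agree with forming the full product and slicing out the three diagonals,
# roughly
#
#   product = math_ops.matmul(x, y_tr, transpose_b=True)
#   diags = array_ops.matrix_diag_part(product, k=(-1, 1), align="LEFT_RIGHT")
#
# but without ever materializing the M x M product, which is the point of the
# O(MK) implementation.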