# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Gradients for operators defined in linalg_ops.py.
16
17Useful reference for derivative formulas is (Mike Giles, 2008).
18
19Ionescu et al. (2015) provide a detailed derivation of formulas for
20backpropagating through spectral layers (SVD and Eig).
21
22References:
23  An extended collection of matrix derivative results for
24  forward and reverse mode automatic differentiation:
25    [Mike Giles, 2008]
26    (https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124)
27    ([pdf](http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf))
28  Matrix Backpropagation for Deep Networks with Structured Layers
29    [Ionescu et al., 2015]
30    (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.html)
31    ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Ionescu_Matrix_Backpropagation_for_ICCV_2015_paper.pdf))
32  Training Deep Networks with Structured Layers by Matrix Backpropagation:
33    [Ionescu et al., 2015](https://arxiv.org/abs/1509.07838)
34    ([pdf](https://arxiv.org/pdf/1509.07838.pdf))
35"""
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as _linalg


@ops.RegisterGradient("MatrixInverse")
def _MatrixInverseGrad(op, grad):
  """Gradient for MatrixInverse."""
  ainv = op.outputs[0]
  op_adjoint = op.get_attr("adjoint")
  return -math_ops.matmul(  # pylint: disable=invalid-unary-operand-type
      ainv,
      math_ops.matmul(grad, ainv, adjoint_a=op_adjoint,
                      adjoint_b=not op_adjoint),
      adjoint_a=not op_adjoint)

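# Editorial note (illustrative sketch, not part of the original module): the
# rule above follows from the identity d(A^{-1}) = -A^{-1} dA A^{-1}. Assuming
# TF2 eager mode, it can be checked numerically against autodiff:
#
#   import tensorflow as tf
#   a = tf.constant([[2.0, 1.0], [1.0, 3.0]])
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     loss = tf.reduce_sum(tf.linalg.inv(a))
#   print(tape.gradient(loss, a))  # equals -A^{-T} @ ones(2, 2) @ A^{-T}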

@ops.RegisterGradient("Einsum")
def _EinsumGrad(op, grad):
  """Gradient for Einsum."""
  ellipsis = "..."

  def _GetAxisFromLabel(subscripts, label):
    """Returns the axis (possibly negative) corresponding to a label.

    Returns the axis index of the axis label if it is before an ellipsis (or if
    the ellipsis is not present), and the negative index if it occurs after the
    ellipsis. E.g. index of `b` in `ab...cd` is `1`, but that of `c` is `-2`.

    For multiple occurrences, returns the leftmost one. If not found, returns
    None.

    Args:
      subscripts: A string denoting the einsum subscript (e.g. `ab...cd`)
      label: The single character axis label.
    """
    splits = subscripts.split(ellipsis)
    index = splits[0].find(label)
    if index != -1:
      return index
    if len(splits) < 2:
      return None
    index = splits[1].find(label)
    if index != -1:
      return index - len(splits[1])
    return None

  def _GetBcastSubshape(subscripts):
    """Returns a tuple denoting the slice mapping to ellipsis.

    For a given subscript, returns a tuple (start, end) denoting the start
    axis index and the (negative) end axis index respectively. For any input
    Tensor `x` described by the subscript, `x[start:end]` would be the slice
    represented by the ellipsis. E.g. for `ab...cd` returns `(2, -2)`.

    If ellipsis is not present in `subscripts`, returns `(0, 0)`.

    Args:
      subscripts: A string denoting the einsum subscript.
    """
    start = subscripts.find(ellipsis)
    if start == -1:
      return 0, 0
    remaining = len(subscripts) - (start + len(ellipsis))
    end = -remaining if remaining > 0 else None
    return start, end

  def _GetReducedSubscripts(reduced_label_set, input_shape, subscripts):
    """Returns reduced subscripts and their corresponding dimensions and axes.

    Given a set of axis labels, returns their concatenated subscript, their
    corresponding dimensions from input_shape, and their corresponding axes.
    Note that the concatenated subscript `reduced_subs` may have axis labels
    from `reduced_label_set` in any order. For example, for the reduced label
    set `{b, d}`, subscripts `aabbcd` and input shape `[2,2,5,5,3,4]`, returns
    subscripts `bd`, dimensions `[5,4]` and axes `[2,5]`.

    Args:
      reduced_label_set: Set of axis labels which appear in `subscripts`.
      input_shape: A `Tensor` representing the shape of the einsum operand
        corresponding to `subscripts`.
      subscripts: A string denoting the einsum subscript.

    Returns:
      reduced_subs: Subscripts formed by a concatenation of labels in
        `reduced_label_set`.
      reduced_dims: Dimensions from `input_shape` corresponding to each label
        in `reduced_subs`.
      reduced_axes: Axes described by `subscripts` corresponding to each label
        in `reduced_subs`. If there are multiple occurrences in `subscripts`,
        we consider only the leftmost one.

    """
    # Concatenate the sequence of reduced axis labels.
    reduced_subs = "".join(list(reduced_label_set))
    # Get the axis (may be positive, negative or zero) for each of the reduced
    # labels. If the same label appears multiple times, get the left-most axis.
    reduced_axes = [_GetAxisFromLabel(subscripts, s) for s in reduced_subs]
    # Get the corresponding dimensions for each reduced axis.
    reduced_dims = array_ops.stack([input_shape[ax] for ax in reduced_axes])
    return reduced_subs, reduced_dims, reduced_axes

  def _GetGradReduced(output_grad, output_subs, input_subs, input_shape,
                      reduced_label_set):
    """Returns the gradient wrt input for a unary einsum with reductions.

    Args:
      output_grad: The gradient wrt the output of a unary einsum operation.
      output_subs: The output subscript. (E.g. `ac` for equation `abc->ac`).
      input_subs: The input subscript. (E.g. `abc` for equation `abc->ac`).
      input_shape: A `Tensor` representing the shape of the input operand.
      reduced_label_set: The set of axis labels appearing in `input_subs` but
        not in `output_subs`.
    """
    # Let's say the einsum operation was "aabbcd->ca", where axis labels 'b' and
    # 'd' are reduced with input_shape [2,2,5,5,3,4]. Then obtain the reduced
    # subscripts "bd", corresponding dimensions [5,4] and axes [2,5].
    reduced_subs, reduced_dims, reduced_axes = _GetReducedSubscripts(
        reduced_label_set, input_shape, input_subs)
    # Whether either the input or the output subscripts have a repeated label.
    # This is true for "aabbcd->ca" or "abd->cca" but false for "abcd->ca".
    has_repeated_labels = (
        len(set(input_subs)) + len(set(output_subs)) <
        len(input_subs) + len(output_subs))
    # Compute the input subscripts without the reduced axis labels, e.g. "aac"
    # for the equation "aabbcd->ca".
    input_subs_without_reduced_labels = "".join(
        [s for s in input_subs if s not in reduced_label_set])

    # The gradient wrt the input for the equation "abc->ac" (or, equivalently
    # reduce_sum(..., axis=1)) is just the gradient of the output tiled N times
    # along axis 1, where label 'b' represents a dimension of size N.
    #
    # If we're not dealing with repeated labels, and the non-reduced labels
    # don't need to be transposed, then just tiling is enough and there is no
    # need to call another einsum. For example, tiling is sufficient for
    # "abcd->ac". But for equations like "aabbcd->ac" (generalized traces) or
    # "abc->ca" (transpose), we'd need another einsum operation after tiling.
    if (not has_repeated_labels and
        input_subs_without_reduced_labels == output_subs):
      # Obtain the shape of the output, as if keepdims=True on reduce sum. E.g.
      # for the equation "abcd->ac" with input shape [2,5,3,4], we get the
      # reduced shape [2,1,3,1].
      reduced_shape = math_ops.reduced_shape(
          input_shape, ops.convert_to_tensor(reduced_axes))
      # Reshaping the gradient (wrt "ac") to [2,1,3,1] and broadcasting it to
      # the shape [2,5,3,4] results in the gradient wrt "abcd".
      return array_ops.broadcast_to(
          array_ops.reshape(output_grad, reduced_shape), input_shape)

    # If we *do* have traces or transpose operations, then prepend the extra
    # reduced dimensions to the front. E.g. Given the equation "aabbcd->ca" we'd
    # first obtain the VJP for "bdca->ca", and then the VJP for "aabbcd->bdca".
    #
    # Obtain the input shape with reduced dimensions prepended, viz. [5,4,3,2].
    # This is the shape of the intermediate "bdca".
    grad_shape_with_reduced_labels = array_ops.concat(
        [reduced_dims, array_ops.shape(output_grad)], axis=0)
    # Obtain the output shape of the reduction-only equation "bdca->ca" as if
    # keepdims=True; viz. [1,1,3,2]. Since we prepended the reduced labels, we
    # just have to prepend that many 1s to the output shape.
    reduced_shape = (
        array_ops.concat([
            array_ops.ones(len(reduced_label_set), dtype=dtypes.int32),
            array_ops.shape(output_grad)
        ],
                         axis=0))
    # Compute the VJP for the intermediate (viz. "bdca->ca") for which
    # broadcasting is sufficient.
    broadcasted_grad = array_ops.broadcast_to(
        array_ops.reshape(output_grad, reduced_shape),
        grad_shape_with_reduced_labels)
    # Compute the VJP for the final step (viz. "aabbcd->bdca"). We can use
    # einsum with the input and output subscripts reversed (viz. "bdca->aabbcd")
    # since the output axis labels now appear in the input subscripts.
    return gen_linalg_ops.einsum([broadcasted_grad],
                                 "{}->{}".format(reduced_subs + output_subs,
                                                 input_subs))

  def _GetGradWrt(output_grad, other_operand, input_shape, input_subs,
                  other_subs, output_subs):
    """Returns the gradient wrt an input operand for a binary einsum.

    This function does not handle (un)broadcasting. This must be done separately
    on the returned gradient.

    Args:
      output_grad: The gradient wrt the output of a binary einsum operation.
      other_operand: The complementary `Tensor` operand, i.e. the one which is
        not the input operand.
      input_shape: A `Tensor` representing the shape of the input operand.
      input_subs: The subscripts of the input operand.
      other_subs: The subscripts of the complementary operand.
      output_subs: The output subscripts.
    """
    # Claim: For the einsum operation z = einsum("{eq_x},{eq_y}->{eq_z}", x, y),
    #   where the equation involves only Tensor contractions, generalized traces
    #   and transposes, the input gradients are given by the vector-Jacobian
    #   products (VJPs):
    #
    #     grad_wrt_x = einsum("{eq_y},{eq_z}->{eq_x}", y, grad_wrt_z)
    #     grad_wrt_y = einsum("{eq_x},{eq_z}->{eq_y}", x, grad_wrt_z)
    #
    #   where grad_wrt_x and grad_wrt_y are the gradients with respect to inputs
    #   x and y and grad_wrt_z is the given gradient with respect to output z.
    #
    # Proof: For unary einsum equations involving only transpose ("ij->ji") and
    #   traces ("ii->i"), the linear mapping's Jacobian at input x is given
    #   by the function itself. We can verify that the linear maps given by the
    #   VJPs are einsums with the equations "ji->ij" and "i->ii" respectively,
    #   where the latter represents 'un-tracing', or filling the diagonal with
    #   the input vector and setting the off-diagonal entries to zero.
    #        Furthermore, recall that matrix multiplication, which is
    #   represented by the equation "ab,bc->ac", has its VJPs given by the
    #   einsum equations "ac,bc->ab" and "ab,ac->bc" (see, for example
    #   https://math.stackexchange.com/a/2755680). Combined with transposes and
    #   traces we can rewrite Tensor contractions as regular matrix
    #   multiplication. Since each of these operations has its VJP described
    #   by einsums of the required pattern, the result follows.
    #
    # Accordingly, einsum operations, except for those with reductions (e.g.
    # "abc,cd->ad"), have their VJPs defined by:
    #   "{output_subs},{other_subs}->{input_subs}".
    #
    # But if there is a reduction, this would lead to the equation "ad,cd->abc"
    # which is invalid because the reduced axis label 'b' is present in the
    # output but not in any of the inputs. Therefore, we compute the VJP in two
    # steps: first we obtain VJP for "ac,cd->ad" and then we compute the VJP of
    # "abc->ac" or, equivalently, reduce_sum(..., axis=1).
    #
    # Compute the set of input axis labels which don't appear in either the
    # output subscripts or the other operand's subscript. E.g. the set {'b'} for
    # the equation "abc,cd->ad".
    reduced_label_set = set(input_subs).difference(
        set(output_subs + other_subs + "."))
    # Obtain the input subscripts with the reduced axis labels removed. E.g.
    # "ac" in the above example.
    left_subs = "".join(s for s in input_subs if s not in reduced_label_set)

    # Compute the gradient wrt the input, without accounting for the operation
    # "abc->ac". So, now we have the VJP of the operation "ac,cd->ad".
    grad_reduced = gen_linalg_ops.einsum([output_grad, other_operand],
                                         "{},{}->{}".format(
                                             output_subs, other_subs,
                                             left_subs))
    # If the reduced_label_set is empty, then we already have the gradient
    # wrt the input.
    if not reduced_label_set:
      return grad_reduced
    # Otherwise, we currently have the gradient wrt the output of the reduction
    # operation "abc->ac". Invoke the subroutine for the gradient for unary
    # einsum with reductions.
    return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape,
                           reduced_label_set)

  equation = op.get_attr("equation")
  if isinstance(equation, bytes):
    equation = equation.decode()
  input_subs, output_subs = equation.split("->")

  if len(op.inputs) == 1:
    # For the unary einsum z = einsum("{eq_x}->{eq_z}", x), the gradient wrt the
    # input (VJP) is given by the reversed equation:
    #   grad_wrt_x = einsum("{eq_z}->{eq_x}", grad_wrt_z)
    # (See the justification in _GetGradWrt). This is valid unless there are
    # reduced axis labels; i.e. axis labels appearing in the input but not in
    # the output subscripts.
    input_shape = array_ops.shape(op.inputs[0])
    # Find the axis labels which appear only in the input.
    reduced_label_set = set(input_subs).difference(set(output_subs + ellipsis))
    if not reduced_label_set:
      # Return the einsum given by the reversed equation, since we don't have
      # reduced axes.
      return gen_linalg_ops.einsum([grad],
                                   "{}->{}".format(output_subs, input_subs))
    # We do have reduced axes, so we invoke the subroutine for reduced unary
    # einsums.
    return _GetGradReduced(grad, output_subs, input_subs, input_shape,
                           reduced_label_set)

  x_subs, y_subs = input_subs.split(",")
  # Add ellipsis for broadcasted dimensions if any operand does not have it.
  # This is because the equation "...ij,jk->ik" may be valid if the 0th input's
  # batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid
  # because only the output subscripts contain ellipsis.
  if ellipsis in output_subs:
    if ellipsis not in x_subs:
      x_subs += ellipsis
    if ellipsis not in y_subs:
      y_subs += ellipsis

  # Obtain the gradients wrt the inputs x and y, without taking into account
  # the unbroadcasting.
  x, y = op.inputs[0], op.inputs[1]
  if grad.dtype.is_complex:
    x = math_ops.conj(x)
    y = math_ops.conj(y)

  x_shape = array_ops.shape(x)
  y_shape = array_ops.shape(y)
  grad_x = _GetGradWrt(grad, y, x_shape, x_subs, y_subs, output_subs)
  grad_y = _GetGradWrt(grad, x, y_shape, y_subs, x_subs, output_subs)

  if ellipsis not in output_subs:
    # If there is no ellipsis in the output, then there is no need to
    # unbroadcast.
    return grad_x, grad_y

  # Below we handle the case that broadcasting between x and y was necessary,
  # with x and y having possibly different batch shapes.

  # Obtain the range of axes which map to ellipsis. E.g. for subscripts 'ab...c'
  # and shape of rank 10; the range [2:-1] denotes the broadcasted axes.
  bx_start, bx_end = _GetBcastSubshape(x_subs)
  by_start, by_end = _GetBcastSubshape(y_subs)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  x_shape_static = x.get_shape()
  y_shape_static = y.get_shape()
  if (x_shape_static.is_fully_defined() and
      y_shape_static.is_fully_defined() and
      x_shape_static[bx_start:bx_end] == y_shape_static[by_start:by_end]):
    return grad_x, grad_y

  # Sum the gradient across the broadcasted axes.
  rx, ry = array_ops.broadcast_gradient_args(x_shape[bx_start:bx_end],
                                             y_shape[by_start:by_end])
  grad_x = array_ops.reshape(
      math_ops.reduce_sum(grad_x, bx_start + rx), x_shape)
  grad_y = array_ops.reshape(
      math_ops.reduce_sum(grad_y, by_start + ry), y_shape)
  return grad_x, grad_y

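# Editorial note (illustrative sketch, not part of the original module): the
# binary VJP rule used above can be sanity-checked against autodiff through the
# public API, e.g. for plain matrix multiplication:
#
#   import tensorflow as tf
#   x = tf.random.normal([2, 3])
#   y = tf.random.normal([3, 4])
#   with tf.GradientTape() as tape:
#     tape.watch([x, y])
#     z = tf.einsum("ab,bc->ac", x, y)
#   dz = tf.ones_like(z)
#   gx, gy = tape.gradient(z, [x, y], output_gradients=dz)
#   # gx == tf.einsum("ac,bc->ab", dz, y); gy == tf.einsum("ab,ac->bc", x, dz)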

@ops.RegisterGradient("MatrixDeterminant")
def _MatrixDeterminantGrad(op, grad):
  """Gradient for MatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[0]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(grad * c,
                                  array_ops.concat([array_ops.shape(c), [1, 1]],
                                                   0))
  return multipliers * a_adj_inv

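# Editorial note (sketch, not part of the original module): the rule above is
# Jacobi's formula, d det(A) = det(A) * tr(A^{-1} dA), so the gradient wrt A is
# grad * det(A) * A^{-H}. A quick numerical check, assuming TF2 eager mode:
#
#   import tensorflow as tf
#   a = tf.constant([[3.0, 1.0], [2.0, 4.0]])
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     d = tf.linalg.det(a)
#   print(tape.gradient(d, a))
#   # equals tf.linalg.det(a) * tf.transpose(tf.linalg.inv(a))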

@ops.RegisterGradient("MatrixSquareRoot")
def _MatrixSquareRootGrad(op, grad):
  """Gradient for MatrixSquareRoot."""

  # Let A be an m x m square matrix (or batch of matrices)
  # Let R = sqrtm(A)
  # By definition, A = RR
  # Take the differential: dA = d(RR) = R dR + dR R
  # Solve the resulting Sylvester equation for dR

  # Used to find Kronecker products within the Sylvester equation
  def _KroneckerProduct(b1, b2):
    """Computes the Kronecker product of two batches of square matrices."""
    b1_shape = array_ops.shape(b1)
    b2_shape = array_ops.shape(b2)
    b1_order = b1_shape[-1]
    b2_order = b2_shape[-1]

    shape_slice_size = [math_ops.subtract(array_ops.size(b1_shape), 2)]
    shape_slice = array_ops.slice(b1_shape, [0],
                                  shape_slice_size)  # Same for both batches
    b1_reshape_shape = array_ops.concat(
        [shape_slice, [b1_order], [1], [b1_order], [1]], 0)
    b2_reshape_shape = array_ops.concat(
        [shape_slice, [1], [b2_order], [1], [b2_order]], 0)

    b1_reshape = array_ops.reshape(b1, b1_reshape_shape)
    b2_reshape = array_ops.reshape(b2, b2_reshape_shape)

    order_prod = b1_order * b2_order
    kprod_shape = array_ops.concat([shape_slice, [order_prod], [order_prod]], 0)
    return array_ops.reshape(b1_reshape * b2_reshape, kprod_shape)

  sqrtm = op.outputs[0]  # R
  shape = array_ops.shape(sqrtm)
  order = shape[-1]  # m
  matrix_count = math_ops.reduce_prod(shape[0:-2])

  # Get batch of m x m identity matrices
  eye = linalg_ops.eye(order, dtype=sqrtm.dtype)  # m x m identity matrix
  eye_flat = array_ops.reshape(eye, [-1])
  eye_tiled = array_ops.tile(eye_flat, [matrix_count])
  eye_batch = array_ops.reshape(eye_tiled, shape)

  # The transpose of R is taken in the k1 term instead of k2 in
  # order to prevent redundant transposition of R (i.e. (R')' = R)
  sqrtm_transpose = array_ops.matrix_transpose(sqrtm)
  k1 = _KroneckerProduct(eye_batch, sqrtm_transpose)
  k2 = _KroneckerProduct(sqrtm, eye_batch)
  ksum = math_ops.add(k1, k2)

  # Vectorize dA
  shape_slice_size = [math_ops.subtract(array_ops.size(shape), 2)]
  shape_slice = array_ops.slice(shape, [0], shape_slice_size)
  shape_vec_da = array_ops.concat([shape_slice, [order * order], [1]], 0)
  vec_da = array_ops.reshape(array_ops.matrix_transpose(grad), shape_vec_da)

  # Solve for vec(dR)
  vec_dsqrtm = linalg_ops.matrix_solve(ksum, vec_da)

  # Solve for dR by inverse vectorizing vec(dR)
  dsqrtm_transpose = array_ops.reshape(vec_dsqrtm, shape)
  return array_ops.matrix_transpose(dsqrtm_transpose)

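# Editorial note (hedged sketch, not part of the original module): the linear
# system above comes from the column-major identity vec(AXB) = (B^T kron A)
# vec(X) applied to the Sylvester relation R dR + dR R = dA. The forward
# Jacobian is M = I kron R + R^T kron I, and backprop multiplies the incoming
# gradient by M^{-T}; note that M^T = I kron R^T + R kron I, which is exactly
# ksum = k1 + k2 assembled above. The extra matrix_transpose calls translate
# between row-major flattening and the column-major vec convention, so the
# reshaped solution is the gradient with respect to A.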

@ops.RegisterGradient("LogMatrixDeterminant")
def _LogMatrixDeterminantGrad(op, _, grad_b):
  """Gradient for LogMatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[1]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(
      grad_b, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
  return multipliers * a_adj_inv


@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""

  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(l,
                                                 linalg_ops.eye(
                                                     num_rows,
                                                     batch_shape=batch_shape,
                                                     dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += _linalg.adjoint(grad_a)
  return grad_a * 0.5

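# Editorial note (illustrative sketch, not part of the original module): the
# closed form in the comment above can be compared against autodiff directly:
#
#   import tensorflow as tf
#   a = tf.constant([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive definite
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     loss = tf.reduce_sum(tf.linalg.cholesky(a))
#   print(tape.gradient(loss, a))  # matches the symmetrized expression above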

@ops.RegisterGradient("Qr")
def _QrGrad(op, dq, dr):
  """Gradient for Qr."""

  # The methodology is explained in detail in https://arxiv.org/abs/2009.10071
  # QR and LQ Decomposition Matrix Backpropagation Algorithms for
  # Square, Wide, and Deep, Real and Complex, Matrices and Their Software
  # Implementation
  q, r = op.outputs
  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
      r.shape.as_list()[-1] is None):
    raise NotImplementedError("QrGrad not implemented with dynamic shapes. "
                              f"Received r.shape: {r.shape}")
  if (r.shape.dims[-2].value > r.shape.dims[-1].value and
      q.shape.dims[-2].value == q.shape.dims[-1].value):
    raise NotImplementedError("QrGrad not implemented when nrows > ncols "
                              "and full_matrices is true. Received r.shape="
                              f"{r.shape} with nrows={r.shape.dims[-2]} "
                              f"and ncols={r.shape.dims[-1]}.")

  def _TriangularSolve(x, r):
    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
    return _linalg.adjoint(
        linalg_ops.matrix_triangular_solve(
            r, _linalg.adjoint(x), lower=False, adjoint=False))

  def _QrGradSquareAndDeepMatrices(q, r, dq, dr):
    """Gradient for matrices with num_rows >= num_cols when full_matrices is
    false.
    """
    qdq = math_ops.matmul(q, dq, adjoint_a=True)
    qdq_ = qdq - _linalg.adjoint(qdq)
    rdr = math_ops.matmul(r, dr, adjoint_b=True)
    rdr_ = rdr - _linalg.adjoint(rdr)
    tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

    grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
    grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
    ret = grad_a + grad_b

    if q.dtype.is_complex:
      # Need to add a correction to the gradient formula for the complex case.
      m = rdr - _linalg.adjoint(qdq)
      eyem = _linalg.set_diag(array_ops.zeros_like(m), _linalg.diag_part(m))
      correction = eyem - math_ops.cast(math_ops.real(eyem), q.dtype)
      ret = ret + _TriangularSolve(
          math_ops.matmul(q, _linalg.adjoint(correction)), r)

    return ret

  num_rows, num_cols = q.shape.dims[-2].value, r.shape.dims[-1]

  if num_rows >= num_cols:
    return _QrGradSquareAndDeepMatrices(q, r, dq, dr)

  # Partition a = [x, y], r = [u, v] and reduce to the square case
  a = op.inputs[0]
  y = a[..., :, num_rows:]
  u = r[..., :, :num_rows]
  dv = dr[..., :, num_rows:]
  du = dr[..., :, :num_rows]
  dy = math_ops.matmul(q, dv)
  dx = _QrGradSquareAndDeepMatrices(q, u,
                                    dq + math_ops.matmul(y, dv, adjoint_b=True),
                                    du)
  return array_ops.concat([dx, dy], axis=-1)

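# Editorial note (illustrative sketch, not part of the original module): the
# square/deep branch above can be exercised through the public API:
#
#   import tensorflow as tf
#   a = tf.random.normal([5, 3])          # num_rows >= num_cols
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     q, r = tf.linalg.qr(a)              # full_matrices=False by default
#     loss = tf.reduce_sum(q) + tf.reduce_sum(r)
#   print(tape.gradient(loss, a).shape)   # (5, 3)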

@ops.RegisterGradient("MatrixSolve")
def _MatrixSolveGrad(op, grad):
  """Gradient for MatrixSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  return (grad_a, grad_b)


@ops.RegisterGradient("MatrixSolveLs")
def _MatrixSolveLsGrad(op, grad):
  """Gradients for MatrixSolveLs."""

  # TODO(rmlarsen): The implementation could be more efficient:
  #   a) Output the Cholesky factorization from forward op instead of
  #      recomputing it here.
  #   b) Implement a symmetric rank-k update op instead of computing
  #      x*z + transpose(x*z). This pattern occurs in other places in
  #      TensorFlow.

  def _Overdetermined(op, grad):
    """Gradients for the overdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the first
    kind:
       X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
    which solves the least squares problem
       min ||A * X - B||_F^2 + lambda * ||X||_F^2.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    x = op.outputs[0]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=True)
    # pylint: enable=protected-access
    # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
    z = linalg_ops.cholesky_solve(chol, grad)
    xzt = math_ops.matmul(x, z, adjoint_b=True)
    zx_sym = xzt + array_ops.matrix_transpose(xzt)
    grad_a = -math_ops.matmul(a, zx_sym) + math_ops.matmul(b, z, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
    grad_b = math_ops.matmul(a, z)
    return (grad_a, grad_b, None)

  def _Underdetermined(op, grad):
    """Gradients for the underdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the second
    kind:
      X = F(A, B) = A * (A*A^T + lambda*I)^{-1} * B
    which (for lambda=0) solves the minimum-norm problem
      min ||X||_F subject to A*X = B.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=False)
    # pylint: enable=protected-access
    grad_b = linalg_ops.cholesky_solve(chol, math_ops.matmul(a, grad))
    # Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
    tmp = linalg_ops.cholesky_solve(chol, b)
    a1 = math_ops.matmul(tmp, a, adjoint_a=True)
    a1 = -math_ops.matmul(grad_b, a1)  # pylint: disable=invalid-unary-operand-type
    a2 = grad - math_ops.matmul(a, grad_b, adjoint_a=True)
    a2 = math_ops.matmul(tmp, a2, adjoint_b=True)
    grad_a = a1 + a2
    return (grad_a, grad_b, None)

  fast = op.get_attr("fast")
  if fast is False:
    raise ValueError("Gradient not defined for fast=False")
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    if matrix_shape[-2] >= matrix_shape[-1]:
      return _Overdetermined(op, grad)
    else:
      return _Underdetermined(op, grad)
  else:
    # We have to defer determining the shape to runtime and use
    # conditional execution of the appropriate graph.
    matrix_shape = array_ops.shape(op.inputs[0])[-2:]
    return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1],
                                 lambda: _Overdetermined(op, grad),
                                 lambda: _Underdetermined(op, grad))


@ops.RegisterGradient("BandedTriangularSolve")
def _BandedTriangularSolveGrad(op, grad):
  """Gradient for BandedTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  num_bands = array_ops.shape(a)[-2]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.banded_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  if lower_a:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(-(num_bands - 1), 0), align="LEFT_RIGHT")
  else:
    grad_a = array_ops.matrix_diag_part(
        grad_a, k=(0, num_bands - 1), align="LEFT_RIGHT")
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


@ops.RegisterGradient("MatrixTriangularSolve")
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  b = op.inputs[1]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)  # pylint: disable=invalid-unary-operand-type
  if lower_a:
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  # If the static batch shapes are equal, we don't need to unbroadcast.
  if (a.shape.is_fully_defined() and b.shape.is_fully_defined() and
      a.shape[:-2] == b.shape[:-2]):
    return grad_a, grad_b
  a_shape = array_ops.shape(a)
  b_shape = array_ops.shape(b)
  ra, rb = array_ops.broadcast_gradient_args(a_shape[:-2], b_shape[:-2])
  grad_a = array_ops.reshape(math_ops.reduce_sum(grad_a, axis=ra), a_shape)
  grad_b = array_ops.reshape(math_ops.reduce_sum(grad_b, axis=rb), b_shape)
  return grad_a, grad_b


# To avoid NaN in cases with degenerate eigenvalues or
# degenerate/zero singular values in calculations of
# f and s_inv_mat, we introduce a Lorentz broadening.
def _SafeReciprocal(x, epsilon=1E-20):
  return x * math_ops.reciprocal(x * x + epsilon)

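# Editorial note (sketch, not part of the original module): _SafeReciprocal is
# a regularized 1/x. For |x| much larger than sqrt(epsilon) it behaves like
# 1/x, while near zero it stays finite instead of overflowing. With the default
# epsilon=1e-20:
#
#   _SafeReciprocal(2.0)    # ~= 0.5
#   _SafeReciprocal(0.0)    # == 0.0, rather than inf from 1/0
#   _SafeReciprocal(1e-12)  # ~= 1e8, well below 1/x = 1e12; the peak value is
#                           # about 1 / (2 * sqrt(epsilon))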

@ops.RegisterGradient("Eig")
def _EigGrad(op, grad_e, grad_v):
  """Gradient for Eig.

  Based on eq. 4.77 from the paper by Christoph Boeddeker et al.:
  https://arxiv.org/abs/1701.00392
  See also "Computation of eigenvalue and eigenvector derivatives
  for a general complex-valued eigensystem" by Nico van der Aa.
  For now, only the case of distinct eigenvalues is considered.
  """
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,:,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      vt = _linalg.adjoint(v)
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          _SafeReciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      f = math_ops.conj(f)
      vgv = math_ops.matmul(vt, grad_v)
      mid = array_ops.matrix_diag(grad_e)
      diag_grad_part = array_ops.matrix_diag(
          array_ops.matrix_diag_part(
              math_ops.cast(math_ops.real(vgv), vgv.dtype)))
      mid += f * (vgv - math_ops.matmul(math_ops.matmul(vt, v), diag_grad_part))
      # vt is formally invertible as long as the original matrix is
      # diagonalizable. However, in practice, vt may
      # be ill-conditioned when the original matrix is close to a
      # non-diagonalizable one.
      grad_a = linalg_ops.matrix_solve(vt, math_ops.matmul(mid, vt))
    else:
      _, v = linalg_ops.eig(op.inputs[0])
      vt = _linalg.adjoint(v)
      # vt is formally invertible as long as the original matrix is
      # diagonalizable. However, in practice, vt may
      # be ill-conditioned when the original matrix is close to a
      # non-diagonalizable one.
      grad_a = linalg_ops.matrix_solve(
          vt, math_ops.matmul(array_ops.matrix_diag(grad_e), vt))
    return math_ops.cast(grad_a, op.inputs[0].dtype)


@ops.RegisterGradient("SelfAdjointEigV2")
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,:,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          _SafeReciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(v,
                               math_ops.matmul(
                                   array_ops.matrix_diag(grad_e),
                                   v,
                                   adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle.
    grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(grad_a,
                                       0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a

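# Editorial note (illustrative sketch, not part of the original module): the
# symmetrized gradient above is what autodiff through tf.linalg.eigh returns.
# For the sum of eigenvalues (i.e. the trace) the gradient is the identity:
#
#   import tensorflow as tf
#   a = tf.constant([[2.0, 1.0], [1.0, 3.0]])
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     e, v = tf.linalg.eigh(a)
#     loss = tf.reduce_sum(e)
#   print(tape.gradient(loss, a))  # ~ [[1, 0], [0, 1]], i.e. d(trace(a))/da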

@ops.RegisterGradient("Svd")
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file).  A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  # The derivation for complex valued SVD can be found in
  # https://re-ra.xyz/misc/complexsvd.pdf or
  # https://giggleliu.github.io/2019/04/02/einsumbp.html
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s = math_ops.cast(grad_s, a.dtype)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
  n = a_shape.dims[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape.dims[-2].value
  n = a_shape.dims[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]
  s = math_ops.cast(s, a.dtype)

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          f"when full_matrices is True. Received: m={m} and n={n} from "
          f"op input={a} with shape={a_shape}.")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to a (k-dimensional) subspace. In practice, this can
    # lead to numerical instability when singular values are close but not
    # exactly equal.

    s_shape = array_ops.shape(s)
    f = array_ops.matrix_set_diag(
        _SafeReciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(_SafeReciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1, conjugate=True)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if a.dtype.is_complex:
      eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=a.dtype)
      l = eye * v_gv
      term3_nouv = math_ops.matmul(s_inv_mat, _linalg.adjoint(l) - l)
      term3 = 1 / 2. * math_ops.matmul(
          u, math_ops.matmul(term3_nouv, v1, adjoint_b=True))

      grad_a_before_transpose += term3

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(
          grad_a_before_transpose, conjugate=True)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a


def _LeftShift(x):
  """Shifts next-to-last dimension to the left, adding zero on the right."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[0, 1], [0, 0]])], axis=0)
  return array_ops.pad(x[..., 1:, :], pad)


def _RightShift(x):
  """Shifts next-to-last dimension to the right, adding zero on the left."""
  rank = array_ops.rank(x)
  zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
  pad = array_ops.concat([zeros, array_ops.constant([[1, 0], [0, 0]])], axis=0)
  return array_ops.pad(x[..., :-1, :], pad)


@ops.RegisterGradient("TridiagonalMatMul")
def _TridiagonalMatMulGrad(op, grad):
  """Gradient for TridiagonalMatMul."""
  superdiag_conj = array_ops.matrix_transpose(op.inputs[0], conjugate=True)
  maindiag_conj = array_ops.matrix_transpose(op.inputs[1], conjugate=True)
  subdiag_conj = array_ops.matrix_transpose(op.inputs[2], conjugate=True)
  rhs_conj = math_ops.conj(op.inputs[3])

  superdiag_grad = math_ops.reduce_sum(_LeftShift(rhs_conj) * grad, axis=-1)
  maindiag_grad = math_ops.reduce_sum(rhs_conj * grad, axis=-1)
  subdiag_grad = math_ops.reduce_sum(_RightShift(rhs_conj) * grad, axis=-1)
  rhs_grad = _RightShift(superdiag_conj * grad) + \
      maindiag_conj * grad + _LeftShift(subdiag_conj * grad)

  superdiag_grad = array_ops.expand_dims(superdiag_grad, -2)
  maindiag_grad = array_ops.expand_dims(maindiag_grad, -2)
  subdiag_grad = array_ops.expand_dims(subdiag_grad, -2)

  return superdiag_grad, maindiag_grad, subdiag_grad, rhs_grad


@ops.RegisterGradient("TridiagonalSolve")
def _TridiagonalSolveGrad(op, grad):
  """Gradient for TridiagonalSolve."""
  diags = op.inputs[0]
  x = op.outputs[0]
  partial_pivoting = op.get_attr("partial_pivoting")
  perturb_singular = op.get_attr("perturb_singular")

  # Transposing the matrix within the tridiagonal_solve kernel by interchanging
  # the superdiagonal and subdiagonal wouldn't work on GPU due to a mismatch
  # with the paddings required by the cusparse*gtsv routines, so we construct
  # the transposed matrix in Python instead.
  diags_transposed = _TransposeTridiagonalMatrix(diags)

  grad_rhs = linalg_ops.tridiagonal_solve(
      diags_transposed,
      grad,
      partial_pivoting=partial_pivoting,
      perturb_singular=perturb_singular)
  grad_diags = -_MatmulExtractingThreeDiagonals(grad_rhs, x)  # pylint: disable=invalid-unary-operand-type
  return grad_diags, grad_rhs


def _TransposeTridiagonalMatrix(diags):
  """Transposes a tridiagonal matrix.

  Args:
    diags: the diagonals of the input matrix in the compact form (see
      linalg_ops.tridiagonal_solve).

  Returns:
    Diagonals of the transposed matrix in the compact form.
  """

  diag = diags[..., 1, :]

  if diags.shape.is_fully_defined():
    # For a fully defined tensor we can concat with a tensor of zeros, which is
    # faster than using array_ops.pad().
    zeros = array_ops.zeros(list(diags.shape[:-2]) + [1], dtype=diags.dtype)
    superdiag = array_ops.concat((diags[..., 2, 1:], zeros), axis=-1)
    subdiag = array_ops.concat((zeros, diags[..., 0, :-1]), axis=-1)
  else:
    rank = array_ops.rank(diags)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat((zeros, array_ops.constant([[0, 1]])),
                                     axis=0)
    superdiag = array_ops.pad(diags[..., 2, 1:], superdiag_pad)
    subdiag_pad = array_ops.concat((zeros, array_ops.constant([[1, 0]])),
                                   axis=0)
    subdiag = array_ops.pad(diags[..., 0, :-1], subdiag_pad)
  return array_ops.stack([superdiag, diag, subdiag], axis=-2)

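# Editorial note (sketch, not part of the original module): in the compact form
# used by linalg_ops.tridiagonal_solve, diags[..., 0, :] holds the superdiagonal
# (last entry unused), diags[..., 1, :] the main diagonal, and diags[..., 2, :]
# the subdiagonal (first entry unused). E.g. the 3 x 3 matrix
#
#   [[d0, u0,  0],
#    [l0, d1, u1],
#    [ 0, l1, d2]]
#
# is represented as [[u0, u1, *], [d0, d1, d2], [*, l0, l1]], and the helper
# above simply swaps and re-pads the two off-diagonal rows to transpose it.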

def _MatmulExtractingThreeDiagonals(x, y_tr):
  """Multiplies matrices and extracts three diagonals from the product.

  With sizes M x K and K x M, this function takes O(MK) time and O(M) space,
  whereas using math_ops.matmul and then extracting the diagonals would take
  O(M^2 K) time and O(M^2) space.

  Args:
    x: first matrix
    y_tr: second matrix transposed

  Returns:
    Diagonals of the product in compact format (see
    linalg_ops.tridiagonal_solve)

  """
  diag = math_ops.reduce_sum(x * y_tr, axis=-1)

  if y_tr.shape.is_fully_defined():
    zeros = array_ops.zeros(
        list(x.shape[:-2]) + [1, x.shape[-1]], dtype=x.dtype)
    superdiag = math_ops.reduce_sum(
        x * array_ops.concat((y_tr[..., 1:, :], zeros), axis=-2), axis=-1)
    subdiag = math_ops.reduce_sum(
        x * array_ops.concat((zeros, y_tr[..., :-1, :]), axis=-2), axis=-1)
  else:
    rank = array_ops.rank(y_tr)
    zeros = array_ops.zeros((rank - 2, 2), dtype=dtypes.int32)
    superdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[0, 1], [0, 0]])), axis=0)
    superdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., 1:, :], superdiag_pad), axis=-1)
    subdiag_pad = array_ops.concat(
        (zeros, array_ops.constant([[1, 0], [0, 0]])), axis=0)
    subdiag = math_ops.reduce_sum(
        x * array_ops.pad(y_tr[..., :-1, :], subdiag_pad), axis=-1)
  return array_ops.stack([superdiag, diag, subdiag], axis=-2)