# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Internal utilities for `LinearOperator` classes."""

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.util import nest


################################################################################
# Helpers to make this module more friendly for TF2.
################################################################################


def convert_nonref_to_tensor(value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor` if input is nonreference type.

  This function converts Python objects of various types to `Tensor` objects,
  except if the input has reference semantics. Reference semantics are
  characterized by `is_ref`: any object which is a `tf.Variable` or an
  instance of `tf.Module` is returned unchanged. This function accepts any
  input which `tf.convert_to_tensor` would also accept.

  Note: This function diverges from default Numpy behavior for `float` and
    `string` types when `None` is present in a Python list or scalar. Rather
    than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    dtype_hint: Optional element type for the returned tensor,
      used when dtype is None. In some cases, a caller may not have a
      dtype in mind when converting to a tensor, so dtype_hint
      can be used as a soft preference.  If the conversion to
      `dtype_hint` is not possible, this argument has no effect.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    tensor: A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.


  #### Examples:

  ```python

  x = tf.Variable(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = tf.constant(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = np.array(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> False
  tf.is_tensor(y)
  # ==> True

  x = tfp.util.DeferredTensor(13.37, lambda x: x)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True
  tf.is_tensor(y)
  # ==> False
  tf.equal(y, 13.37)
  # ==> True
  ```

  """
  # We explicitly do not use a tf.name_scope to avoid graph clutter.
  if value is None:
    return None
  if is_ref(value):
    if dtype is None:
      return value
    dtype_base = base_dtype(dtype)
    value_dtype_base = base_dtype(value.dtype)
    if dtype_base != value_dtype_base:
      raise TypeError(
          f"Argument `value` must be of dtype `{dtype_name(dtype_base)}`. "
          f"Received: `{dtype_name(value_dtype_base)}`.")
    return value
  return ops.convert_to_tensor_v2_with_dispatch(
      value, dtype=dtype, dtype_hint=dtype_hint, name=name)


def base_dtype(dtype):
  """Returns a non-reference `dtype` based on this `dtype`."""
  dtype = dtypes.as_dtype(dtype)
  if hasattr(dtype, "base_dtype"):
    return dtype.base_dtype
  return dtype


def dtype_name(dtype):
  """Returns the string name for this `dtype`."""
  dtype = dtypes.as_dtype(dtype)
  if hasattr(dtype, "name"):
    return dtype.name
  if hasattr(dtype, "__name__"):
    return dtype.__name__
  return str(dtype)


def check_dtype(arg, dtype):
  """Check that arg.dtype == self.dtype."""
  if arg.dtype.base_dtype != dtype:
    raise TypeError(
        f"Expected argument to have dtype {dtype}. Found: {arg.dtype} in "
        f"tensor {arg}.")


def is_ref(x):
  """Evaluates if the object has reference semantics.

  An object is deemed "reference" if it is a `tf.Variable` instance or is
  derived from a `tf.Module` with `dtype` and `shape` properties.

  Args:
    x: Any object.

  Returns:
    is_ref: Python `bool` indicating the input has reference semantics, i.e.,
      is a `tf.Variable` or a `tf.Module` with `dtype` and `shape` properties.
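
  #### Example:

  A quick sketch of the intended behavior (values are illustrative):

  ```python
  is_ref(tf.Variable(0.))  # ==> True
  is_ref(tf.constant(0.))  # ==> False
  is_ref(np.array(0.))     # ==> False
  ```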
  """
  return (
      # Note: we check that tf.Variable is a class because we might be using a
      # different backend other than TF.
      isinstance(x, variables_module.Variable) or
      (isinstance(x, module.Module) and hasattr(x, "dtype") and
       hasattr(x, "shape")))


def assert_not_ref_type(x, arg_name):
  if is_ref(x):
    raise TypeError(
        f"Argument {arg_name} cannot be reference type. Found: {type(x)}.")


################################################################################
# Asserts.
################################################################################


def assert_no_entries_with_modulus_zero(
    x, message=None, name="assert_no_entries_with_modulus_zero"):
  """Returns `Op` that asserts Tensor `x` has no entries with modulus zero.

  Args:
    x:  Numeric `Tensor`, real, integer, or complex.
    message:  A string message to prepend to failure message.
    name:  A name to give this `Op`.

  Returns:
    An `Op` that asserts `x` has no entries with modulus zero.
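
  #### Example:

  A sketch of one intended use (names and values are illustrative), e.g. when
  checking that a diagonal is invertible before taking its reciprocal:

  ```python
  diag = tf.constant([1. + 1j, 0.5 + 0j])
  with tf.control_dependencies(
      [assert_no_entries_with_modulus_zero(diag, message="Singular.")]):
    inv_diag = 1. / diag
  ```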
  """
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
    dtype = x.dtype.base_dtype
    should_be_nonzero = math_ops.abs(x)
    zero = ops.convert_to_tensor_v2_with_dispatch(0, dtype=dtype.real_dtype)
    return check_ops.assert_less(zero, should_be_nonzero, message=message)


def assert_zero_imag_part(x, message=None, name="assert_zero_imag_part"):
  """Returns `Op` that asserts Tensor `x` has no non-zero imaginary parts.

  Args:
    x:  Numeric `Tensor`, real, integer, or complex.
    message:  A string message to prepend to failure message.
    name:  A name to give this `Op`.

  Returns:
    An `Op` that asserts `x` has no non-zero imaginary parts.
  """
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
    dtype = x.dtype.base_dtype

    if dtype.is_floating:
      return control_flow_ops.no_op()

    zero = ops.convert_to_tensor_v2_with_dispatch(0, dtype=dtype.real_dtype)
    return check_ops.assert_equal(zero, math_ops.imag(x), message=message)


def assert_compatible_matrix_dimensions(operator, x):
  """Assert that an argument to solve/matmul has proper domain dimension.

  If `operator.shape[-2:] = [M, N]`, and `x.shape[-2:] = [Q, R]`, then
  `operator.matmul(x)` is defined only if `N = Q`.  This `Op` returns an
  `Assert` that "fires" if this is not the case.  Static checks are already
  done by the base class `LinearOperator`.

  Args:
    operator:  `LinearOperator`.
    x:  `Tensor`.

  Returns:
    `Assert` `Op`.
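
  #### Example:

  An illustrative sketch (the operator and shapes are hypothetical):

  ```python
  operator = tf.linalg.LinearOperatorFullMatrix(tf.ones([3, 4]))
  x = tf.ones([4, 2])  # x.shape[-2] matches the domain dimension, 4.
  assert_op = assert_compatible_matrix_dimensions(operator, x)  # Passes.
  ```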
  """
  # Static checks are done in the base class.  Only tensor asserts here.
  assert_same_dd = check_ops.assert_equal(
      array_ops.shape(x)[-2],
      operator.domain_dimension_tensor(),
      # This error message is made to look similar to the error raised by the
      # static check in the base class.
      message=("Dimensions are not compatible.  "
               "Expected shape[-2] of argument to be the same as this "
               "operator's domain dimension."))

  return assert_same_dd


def assert_is_batch_matrix(tensor):
  """Static assert that `tensor` has rank `2` or higher."""
  sh = tensor.shape
  if sh.ndims is not None and sh.ndims < 2:
    raise ValueError(
        f"Expected [batch] matrix to have at least two dimensions. Found: "
        f"{tensor}.")


def shape_tensor(shape, name=None):
  """Convert to a `Tensor`, using `int32` if `shape` is an empty list/tuple."""
  # Works just like random_ops._ShapeTensor.
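  # An empty list or tuple would otherwise be inferred as a float32 `Tensor`,
  # which is not a valid shape dtype, so force int32 in that case.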
  if isinstance(shape, (tuple, list)) and not shape:
    dtype = dtypes.int32
  else:
    dtype = None
  return ops.convert_to_tensor_v2_with_dispatch(shape, dtype=dtype, name=name)


################################################################################
# Broadcasting versions of common linear algebra functions.
# TODO(b/77519145) Do this more efficiently in some special cases.
################################################################################


def broadcast_matrix_batch_dims(batch_matrices, name=None):
  """Broadcast leading dimensions of zero or more [batch] matrices.

  Example broadcasting one batch dim of two simple matrices.

  ```python
  x = [[1, 2],
       [3, 4]]  # Shape [2, 2], no batch dims

  y = [[[1]]]   # Shape [1, 1, 1], 1 batch dim of shape [1]

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc
  ==> [[[1, 2],
        [3, 4]]]  # Shape [1, 2, 2], 1 batch dim of shape [1].

  y_bc
  ==> same as y
  ```

  Example broadcasting many batch dims

  ```python
  x = tf.random.normal(shape=(2, 3, 1, 4, 4))
  y = tf.random.normal(shape=(1, 3, 2, 5, 5))
  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc.shape
  ==> (2, 3, 2, 4, 4)

  y_bc.shape
  ==> (2, 3, 2, 5, 5)
  ```

  Args:
    batch_matrices:  Iterable of `Tensor`s, each having two or more dimensions.
    name:  A string name to prepend to created ops.

  Returns:
    bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
      the values from `batch_matrices[i]`, with possibly broadcast batch dims.

  Raises:
    ValueError:  If any input `Tensor` is statically determined to have less
      than two dimensions.
  """
  with ops.name_scope(
      name or "broadcast_matrix_batch_dims", values=batch_matrices):
    check_ops.assert_proper_iterable(batch_matrices)
    batch_matrices = list(batch_matrices)

    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = ops.convert_to_tensor_v2_with_dispatch(mat)
      assert_is_batch_matrix(batch_matrices[i])

    if len(batch_matrices) < 2:
      return batch_matrices

    # Try static broadcasting.
    # bcast_batch_shape is the broadcast batch shape of ALL matrices.
    # E.g. if batch_matrices = [x, y], with
    # x.shape =    [2, j, k]  (batch shape =    [2])
    # y.shape = [3, 1, l, m]  (batch shape = [3, 1])
    # ==> bcast_batch_shape = [3, 2]
    bcast_batch_shape = batch_matrices[0].shape[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_static_shape(
          bcast_batch_shape,
          mat.shape[:-2])
    if bcast_batch_shape.is_fully_defined():
      for i, mat in enumerate(batch_matrices):
        if mat.shape[:-2] != bcast_batch_shape:
          bcast_shape = array_ops.concat(
              [bcast_batch_shape.as_list(), array_ops.shape(mat)[-2:]], axis=0)
          batch_matrices[i] = array_ops.broadcast_to(mat, bcast_shape)
      return batch_matrices

    # Since static didn't work, do dynamic, which always copies data.
    bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_dynamic_shape(
          bcast_batch_shape,
          array_ops.shape(mat)[:-2])
    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = array_ops.broadcast_to(
          mat,
          array_ops.concat(
              [bcast_batch_shape, array_ops.shape(mat)[-2:]], axis=0))

    return batch_matrices


def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
  """Solve systems of linear equations, broadcasting batch dimensions.
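
  A minimal sketch of the broadcasting behavior (shapes are illustrative):

  ```python
  matrix = tf.random.normal(shape=[2, 3, 3])  # Batch shape [2].
  rhs = tf.random.normal(shape=[3, 1])        # No batch dims.
  x = matrix_solve_with_broadcast(matrix, rhs)
  x.shape
  # ==> (2, 3, 1)
  ```
  """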
  with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]):
    matrix = ops.convert_to_tensor_v2_with_dispatch(matrix, name="matrix")
    rhs = ops.convert_to_tensor_v2_with_dispatch(
        rhs, name="rhs", dtype=matrix.dtype)

    # If either matrix/rhs has extra dims, we can reshape to get rid of them.
    matrix, rhs, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
        matrix, rhs, adjoint_a=adjoint)

    # This will broadcast by brute force if we still need to.
    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])

    solution = linalg_ops.matrix_solve(
        matrix, rhs, adjoint=adjoint and still_need_to_transpose)

    return reshape_inv(solution)


def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe reshape a, b, and return an inverse map.  For matmul/solve."""
  def identity(x):
    return x

  # At this point, we have not taken transpose/adjoint of a/b.
  still_need_to_transpose = True

  if a.shape.ndims is None or b.shape.ndims is None:
    return a, b, identity, still_need_to_transpose

  # This could be handled in the future, but seems less common.
  if a.shape.ndims >= b.shape.ndims:
    return a, b, identity, still_need_to_transpose

  # From now on, we might modify b, but will not modify a.

  # Suppose:
  #   a.shape =     C + [m, n]
  #   b.shape = S + C + [n, r]
  b_extra_ndims = b.shape.ndims - a.shape.ndims

  # b_extra_sh = S, b_main_sh = C + [n, r]
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # No reason to flip unless the extra dims of b are big enough.  Why?
  # Assume adjoint/transpose = False.  Then...
  # By not flipping, we have to replicate a to shape
  #   b_extra_sh + a.shape,
  # which could use extra memory.  But in all cases, the final output has shape
  #   b_extra_sh + a.shape[:-1] + [b.shape[-1]]
  # So we only end up creating a larger object if the end dim of b is smaller
  # than the end dim of a.  This often happens, e.g. if b was a vector that was
  # expanded to a matrix (by appending a singleton).

  # Since adjoint/transpose may not be False, we must make adjustments here.
  # The dim of b that holds the multiple equations.
  a_domain_sz_ = a.shape[-2 if adjoint_a or transpose_a else -1]
  b_eq_sz_ = b.shape[-2 if adjoint_b or transpose_b else -1]
  b_extra_sz_ = (
      np.prod(b.shape[:b_extra_ndims].as_list())
      if b.shape[:b_extra_ndims].is_fully_defined() else None)
  if (a_domain_sz_ is not None and b_eq_sz_ is not None and
      b_extra_sz_ is not None):
    if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
      return a, b, identity, still_need_to_transpose

  # At this point, we're flipping for sure!
  # Any transposes/adjoints will happen here explicitly, rather than in calling
  # code.  Why?  To avoid having to write separate complex code for each case.
  if adjoint_a:
    a = array_ops.matrix_transpose(a, conjugate=True)
  elif transpose_a:
    a = array_ops.matrix_transpose(a, conjugate=False)
  if adjoint_b:
    b = array_ops.matrix_transpose(b, conjugate=True)
  elif transpose_b:
    b = array_ops.matrix_transpose(b, conjugate=False)
  still_need_to_transpose = False

  # Recompute shapes, since the transpose/adjoint may have changed them.
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # Permutation to put the extra dims at the end.
  perm = (
      np.concatenate(
          (np.arange(b_extra_ndims, b.shape.ndims),
           np.arange(0, b_extra_ndims)), 0))
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Now squash this end into one long dim.
  b_squashed_end = array_ops.reshape(
      b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

  def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
    # could have different batch dims than a and b, because of broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    inverse_perm = np.argsort(perm)
    return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

  return a, b_squashed_end, reshape_inv, still_need_to_transpose


################################################################################
# Helpers for hints.
################################################################################


def use_operator_or_provided_hint_unless_contradicting(
    operator, hint_attr_name, provided_hint_value, message):
  """Get combined hint in the case where operator.hint should equal hint.

  Args:
    operator:  LinearOperator that a meta-operator was initialized with.
    hint_attr_name:  String name for the attribute.
    provided_hint_value:  Bool or None. Value passed by user in initialization.
    message:  Error message to print if hints contradict.

  Returns:
    True, False, or None.

  Raises:
    ValueError: If hints contradict.
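
  #### Example:

  An illustrative sketch (any `LinearOperator` exposing the hinted attribute
  would work; `LinearOperatorIdentity` is self-adjoint by construction):

  ```python
  operator = tf.linalg.LinearOperatorIdentity(num_rows=2)
  use_operator_or_provided_hint_unless_contradicting(
      operator, "is_self_adjoint", None, "Contradicting hints.")
  # ==> True
  ```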
  """
  op_hint = getattr(operator, hint_attr_name)
  # pylint: disable=g-bool-id-comparison
  if op_hint is False and provided_hint_value:
    raise ValueError(message)
  if op_hint and provided_hint_value is False:
    raise ValueError(message)
  if op_hint or provided_hint_value:
    return True
  if op_hint is False or provided_hint_value is False:
    return False
  # pylint: enable=g-bool-id-comparison
  return None


################################################################################
# Utilities for blockwise operators.
################################################################################


def arg_is_blockwise(block_dimensions, arg, arg_split_dim):
  """Detect if input should be interpreted as a list of blocks.
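
  A minimal sketch (dimensions are illustrative; `block_dimensions` would
  typically come from an operator's block domain/range dimensions):

  ```python
  block_dims = tf.TensorShape([2, 3]).dims  # Two blocks, of sizes 2 and 3.
  arg_is_blockwise(block_dims, [tf.ones([4, 2]), tf.ones([4, 3])], -1)
  # ==> True
  arg_is_blockwise(block_dims, tf.ones([4, 5]), -1)
  # ==> False
  ```
  """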
  # Tuples and lists of length equal to the number of operators may be
  # blockwise.
  if (isinstance(arg, (tuple, list)) and len(arg) == len(block_dimensions)):
    # If the elements of the iterable are not nested, interpret the input as
    # blockwise.
    if not any(nest.is_nested(x) for x in arg):
      return True
    else:
      arg_dims = [ops.convert_to_tensor_v2_with_dispatch(
          x).shape[arg_split_dim] for x in arg]
      self_dims = [dim.value for dim in block_dimensions]

      # If none of the operator dimensions are known, interpret the input as
      # blockwise if its matching dimensions are unequal.
      if all(self_d is None for self_d in self_dims):

        # A nested tuple/list with a single outermost element is not blockwise
        if len(arg_dims) == 1:
          return False
        elif any(dim != arg_dims[0] for dim in arg_dims):
          return True
        else:
          raise ValueError(
              "Parsing of the input structure is ambiguous. Please input "
              "a blockwise iterable of `Tensor`s or a single `Tensor`.")

      # If input dimensions equal the respective (known) blockwise operator
      # dimensions, then the input is blockwise.
      if all(self_d == arg_d or self_d is None
             for self_d, arg_d in zip(self_dims, arg_dims)):
        return True

      # If the input dimensions are all equal and are greater than or equal to
      # the sum of the known operator dimensions, interpret the input as a
      # single `Tensor` rather than as blockwise.
      self_dim = sum(self_d for self_d in self_dims if self_d is not None)
      if all(s == arg_dims[0] for s in arg_dims) and arg_dims[0] >= self_dim:
        return False

      # If none of these conditions is met, the input shape is mismatched.
      raise ValueError("Input dimension does not match operator dimension.")
  else:
    return False


def split_arg_into_blocks(block_dims, block_dims_fn, arg, axis=-1):
  """Split `arg` into blocks matching the operators' `domain_dimension`s.

  Specifically, if we have a blockwise lower-triangular matrix with block
  sizes along the diagonal `[M_j, M_j], j = 0, 1, 2, ..., J`, this method
  splits `arg` on `axis` into `J` tensors, whose shape at `axis` is `M_j`.

  Args:
    block_dims: Iterable of `TensorShapes`.
    block_dims_fn: Callable returning an iterable of `Tensor`s.
    arg: `Tensor`. `arg` is split into `J` tensors.
    axis: Python `Integer` representing the axis to split `arg` on.

  Returns:
    A list of `Tensor`s.
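
  #### Example:

  A minimal sketch (dimensions are illustrative):

  ```python
  block_dims = tf.TensorShape([2, 3]).dims  # Two blocks, of sizes 2 and 3.
  blocks = split_arg_into_blocks(
      block_dims, lambda: [2, 3], tf.ones([4, 5]), axis=-1)
  [b.shape for b in blocks]
  # ==> [TensorShape([4, 2]), TensorShape([4, 3])]
  ```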
  """
  block_sizes = [dim.value for dim in block_dims]
  if any(d is None for d in block_sizes):
    block_sizes = block_dims_fn()
  return array_ops.split(arg, block_sizes, axis=axis)