xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/linalg/slicing.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for slicing in to a `LinearOperator`."""
16
17import collections
18import functools
19import numpy as np
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import tensor_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.util import nest
25
26
27__all__ = ['batch_slice']
28
29
def _prefer_static_where(condition, x, y):
  """Computes `where(condition, x, y)` statically when all args are constant.

  Falls back to `tf.where` if any argument's value is unknown at graph
  construction time.
  """
  maybe_static = [tensor_util.constant_value(a) for a in (condition, x, y)]
  if any(value is None for value in maybe_static):
    # At least one argument is only known at execution time; build an op.
    return array_ops.where(condition, x, y)
  static_condition, static_x, static_y = maybe_static
  return np.where(static_condition, static_x, static_y)
38
39
def _broadcast_parameter_with_batch_shape(
    param, param_ndims_to_matrix_ndims, batch_shape):
  """Broadcasts `param` with the given batch shape, recursively."""
  if not hasattr(param, 'batch_shape_tensor'):
    # `param` is Tensor-like: append ones for the matrix dimensions, then
    # broadcast against the param's own shape.
    matrix_ones = array_ops.ones(
        [param_ndims_to_matrix_ndims], dtype=dtypes.int32)
    base_shape = array_ops.concat([batch_shape, matrix_ones], axis=0)
    target_shape = array_ops.broadcast_dynamic_shape(
        base_shape, array_ops.shape(param))
    return array_ops.broadcast_to(param, target_shape)

  # `param` is itself a batch object (e.g. a LinearOperator): rebuild it with
  # every one of its own parameters broadcast recursively.
  broadcast_fn = functools.partial(
      _broadcast_parameter_with_batch_shape, batch_shape=batch_shape)
  overrides = {}
  ndims_map = param._experimental_parameter_ndims_to_matrix_ndims  # pylint:disable=protected-access
  for name, ndims in ndims_map.items():
    sub_param = getattr(param, name)
    overrides[name] = nest.map_structure_up_to(
        sub_param, broadcast_fn, sub_param, ndims)
  return type(param)(**dict(param.parameters, **overrides))
61
62
def _sanitize_slices(slices, intended_shape, deficient_shape):
  """Restricts slices to avoid overflowing size-1 (broadcast) dimensions.

  Args:
    slices: iterable of slices received by `__getitem__`.
    intended_shape: int `Tensor` shape for which the slices were intended.
    deficient_shape: int `Tensor` shape to which the slices will be applied.
      Must have the same rank as `intended_shape`.
  Returns:
    sanitized_slices: Python `list` of slice objects.
  """
  sanitized_slices = []
  # `idx` tracks which dimension of the shapes the current slice refers to.
  # It becomes negative after an Ellipsis, counting from the right end.
  idx = 0
  for slc in slices:
    if slc is Ellipsis:  # Switch over to negative indexing.
      if idx < 0:
        # `idx` can only be negative if an earlier Ellipsis already switched
        # us to right-relative indexing.
        raise ValueError('Found multiple `...` in slices {}'.format(slices))
      # Count positional slices remaining after the Ellipsis; `tf.newaxis`
      # entries insert dimensions and consume none, so they are excluded.
      num_remaining_non_newaxis_slices = sum(
          s is not array_ops.newaxis for s in slices[
              slices.index(Ellipsis) + 1:])
      idx = -num_remaining_non_newaxis_slices
    elif slc is array_ops.newaxis:
      # Inserts a new dimension; does not advance `idx` over existing dims.
      pass
    else:
      # If the intended dim exceeds the deficient dim, the deficient dim must
      # be a size-1 broadcast dim, so any slice into it reduces to a trivial
      # slice that retains its single element.
      is_broadcast = intended_shape[idx] > deficient_shape[idx]
      if isinstance(slc, slice):
        # Slices are denoted by start:stop:step.
        start, stop, step = slc.start, slc.stop, slc.step
        # On a broadcast dim, replace with 0:1:1 componentwise (only the
        # components the caller actually specified).
        if start is not None:
          start = _prefer_static_where(is_broadcast, 0, start)
        if stop is not None:
          stop = _prefer_static_where(is_broadcast, 1, stop)
        if step is not None:
          step = _prefer_static_where(is_broadcast, 1, step)
        slc = slice(start, stop, step)
      else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
        # Any integer index into a size-1 broadcast dim becomes index 0.
        slc = _prefer_static_where(is_broadcast, 0, slc)
      idx += 1
    sanitized_slices.append(slc)
  return sanitized_slices
103
104
def _slice_single_param(
    param, param_ndims_to_matrix_ndims, slices, batch_shape):
  """Slices into the batch shape of a single parameter.

  Args:
    param: The original parameter to slice; either a `Tensor` or an object
      with batch shape (LinearOperator).
    param_ndims_to_matrix_ndims: `int` number of right-most dimensions used for
      inferring matrix shape of the `LinearOperator`. For non-Tensor
      parameters, this is the number of this param's batch dimensions used by
      the matrix shape of the parent object.
    slices: iterable of slices received by `__getitem__`.
    batch_shape: The parameterized object's batch shape `Tensor`.

  Returns:
    new_param: Instance of the same type as `param`, batch-sliced according to
      `slices`.
  """
  # Broadcast the parameter to have full batch rank (but leave each batch dim
  # size-1 where possible, by broadcasting against a shape of all ones).
  param = _broadcast_parameter_with_batch_shape(
      param, param_ndims_to_matrix_ndims, array_ops.ones_like(batch_shape))

  if hasattr(param, 'batch_shape_tensor'):
    param_batch_shape = param.batch_shape_tensor()
  else:
    param_batch_shape = array_ops.shape(param)
  # Truncate by param_ndims_to_matrix_ndims: the right-most dims contribute to
  # the matrix shape, not the batch shape.
  param_batch_rank = array_ops.size(param_batch_shape)
  param_batch_shape = param_batch_shape[
      :(param_batch_rank - param_ndims_to_matrix_ndims)]

  # At this point the param should have full batch rank, *unless* it's an
  # atomic object like `tfb.Identity()` incapable of having any batch rank.
  # Such an object is unaffected by batch slicing, so return it unchanged.
  if (tensor_util.constant_value(array_ops.size(batch_shape)) != 0 and
      tensor_util.constant_value(array_ops.size(param_batch_shape)) == 0):
    return param
  param_slices = _sanitize_slices(
      slices, intended_shape=batch_shape, deficient_shape=param_batch_shape)

  # Extend `param_slices` (which represents slicing into the
  # parameter's batch shape) with the parameter's event ndims. For example, if
  # `params_ndims == 1`, then `[i, ..., j]` would become `[i, ..., j, :]`.
  if param_ndims_to_matrix_ndims > 0:
    # Only append `...` if the caller didn't already provide one. Tensor
    # entries can't be compared with `in` (it would build an equality op), so
    # they are filtered out first.
    if Ellipsis not in [
        slc for slc in slices if not tensor_util.is_tensor(slc)]:
      param_slices.append(Ellipsis)
    param_slices += [slice(None)] * param_ndims_to_matrix_ndims
  return param.__getitem__(tuple(param_slices))
153
154
def batch_slice(linop, params_overrides, slices):
  """Slices `linop` along its batch dimensions.

  Args:
    linop: A `LinearOperator` instance.
    params_overrides: A `dict` of parameter overrides.
    slices: A `slice` or `int` or `int` `Tensor` or `tf.newaxis` or `tuple`
      thereof. (e.g. the argument of a `__getitem__` method).

  Returns:
    new_linop: A batch-sliced `LinearOperator`.
  """
  if not isinstance(slices, collections.abc.Sequence):
    slices = (slices,)
  override_dict = {}
  # A bare `[...]` leaves every batch dimension untouched, so no parameter
  # needs to be re-sliced.
  is_trivial_slice = len(slices) == 1 and slices[0] is Ellipsis
  if not is_trivial_slice:
    batch_shape = linop.batch_shape_tensor()
    slice_fn = functools.partial(
        _slice_single_param, slices=slices, batch_shape=batch_shape)
    ndims_map = linop._experimental_parameter_ndims_to_matrix_ndims  # pylint:disable=protected-access
    for param_name, param_ndims_to_matrix_ndims in ndims_map.items():
      param = getattr(linop, param_name)
      if param is None:
        # Optional `Tensor` parameter that was not supplied; nothing to slice.
        continue
      override_dict[param_name] = nest.map_structure_up_to(
          param, slice_fn, param, param_ndims_to_matrix_ndims)
  # Explicit overrides win over sliced parameters, which in turn win over the
  # operator's original constructor arguments.
  override_dict.update(params_overrides)
  return type(linop)(**dict(linop.parameters, **override_dict))
185