# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for slicing in to a `LinearOperator`."""

import collections
import functools
import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest


__all__ = ['batch_slice']


def _prefer_static_where(condition, x, y):
  """Computes `where(condition, x, y)`, statically when possible.

  Args:
    condition: Boolean (or boolean `Tensor`) selector.
    x: Value selected where `condition` holds.
    y: Value selected where `condition` does not hold.

  Returns:
    A `np.ndarray` if all three arguments have statically-known values,
    otherwise a `Tensor` from `array_ops.where`.
  """
  args = [condition, x, y]
  constant_args = [tensor_util.constant_value(a) for a in args]
  # Do this statically.
  if all(arg is not None for arg in constant_args):
    condition_, x_, y_ = constant_args
    return np.where(condition_, x_, y_)
  return array_ops.where(condition, x, y)


def _broadcast_parameter_with_batch_shape(
    param, param_ndims_to_matrix_ndims, batch_shape):
  """Broadcasts `param` with the given batch shape, recursively."""
  if hasattr(param, 'batch_shape_tensor'):
    # Recursively broadcast every parameter inside the operator.
    override_dict = {}
    for name, ndims in param._experimental_parameter_ndims_to_matrix_ndims.items():  # pylint:disable=protected-access,line-too-long
      sub_param = getattr(param, name)
      override_dict[name] = nest.map_structure_up_to(
          sub_param, functools.partial(
              _broadcast_parameter_with_batch_shape,
              batch_shape=batch_shape), sub_param, ndims)
    parameters = dict(param.parameters, **override_dict)
    return type(param)(**parameters)

  # Tensor leaf: append one broadcastable (size-1) dim per matrix dim, then
  # broadcast against the requested batch shape.
  base_shape = array_ops.concat(
      [batch_shape, array_ops.ones(
          [param_ndims_to_matrix_ndims], dtype=dtypes.int32)], axis=0)
  return array_ops.broadcast_to(
      param,
      array_ops.broadcast_dynamic_shape(base_shape, array_ops.shape(param)))


def _sanitize_slices(slices, intended_shape, deficient_shape):
  """Restricts slices to avoid overflowing size-1 (broadcast) dimensions.

  Args:
    slices: iterable of slices received by `__getitem__`.
    intended_shape: int `Tensor` shape for which the slices were intended.
    deficient_shape: int `Tensor` shape to which the slices will be applied.
      Must have the same rank as `intended_shape`.
  Returns:
    sanitized_slices: Python `list` of slice objects.
  """
  sanitized_slices = []
  idx = 0
  for slc in slices:
    if slc is Ellipsis:  # Switch over to negative indexing.
      if idx < 0:
        raise ValueError('Found multiple `...` in slices {}'.format(slices))
      num_remaining_non_newaxis_slices = sum(
          s is not array_ops.newaxis for s in slices[
              slices.index(Ellipsis) + 1:])
      idx = -num_remaining_non_newaxis_slices
    elif slc is array_ops.newaxis:
      pass
    else:
      is_broadcast = intended_shape[idx] > deficient_shape[idx]
      if isinstance(slc, slice):
        # Slices are denoted by start:stop:step.
        start, stop, step = slc.start, slc.stop, slc.step
        if start is not None:
          start = _prefer_static_where(is_broadcast, 0, start)
        if stop is not None:
          stop = _prefer_static_where(is_broadcast, 1, stop)
        if step is not None:
          step = _prefer_static_where(is_broadcast, 1, step)
        slc = slice(start, stop, step)
      else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
        slc = _prefer_static_where(is_broadcast, 0, slc)
      idx += 1
    sanitized_slices.append(slc)
  return sanitized_slices


def _slice_single_param(
    param, param_ndims_to_matrix_ndims, slices, batch_shape):
  """Slices into the batch shape of a single parameter.

  Args:
    param: The original parameter to slice; either a `Tensor` or an object
      with batch shape (LinearOperator).
    param_ndims_to_matrix_ndims: `int` number of right-most dimensions used for
      inferring matrix shape of the `LinearOperator`. For non-Tensor
      parameters, this is the number of this param's batch dimensions used by
      the matrix shape of the parent object.
    slices: iterable of slices received by `__getitem__`.
    batch_shape: The parameterized object's batch shape `Tensor`.

  Returns:
    new_param: Instance of the same type as `param`, batch-sliced according to
      `slices`.
  """
  # Broadcast the parameter to have full batch rank.
  param = _broadcast_parameter_with_batch_shape(
      param, param_ndims_to_matrix_ndims, array_ops.ones_like(batch_shape))

  if hasattr(param, 'batch_shape_tensor'):
    param_batch_shape = param.batch_shape_tensor()
  else:
    param_batch_shape = array_ops.shape(param)
    # Truncate by param_ndims_to_matrix_ndims
    param_batch_rank = array_ops.size(param_batch_shape)
    param_batch_shape = param_batch_shape[
        :(param_batch_rank - param_ndims_to_matrix_ndims)]

  # At this point the param should have full batch rank, *unless* it's an
  # atomic object like `tfb.Identity()` incapable of having any batch rank.
  if (tensor_util.constant_value(array_ops.size(batch_shape)) != 0 and
      tensor_util.constant_value(array_ops.size(param_batch_shape)) == 0):
    return param
  param_slices = _sanitize_slices(
      slices, intended_shape=batch_shape, deficient_shape=param_batch_shape)

  # Extend `param_slices` (which represents slicing into the
  # parameter's batch shape) with the parameter's event ndims. For example, if
  # `params_ndims == 1`, then `[i, ..., j]` would become `[i, ..., j, :]`.
  if param_ndims_to_matrix_ndims > 0:
    if Ellipsis not in [
        slc for slc in slices if not tensor_util.is_tensor(slc)]:
      param_slices.append(Ellipsis)
    param_slices += [slice(None)] * param_ndims_to_matrix_ndims
  return param.__getitem__(tuple(param_slices))


def batch_slice(linop, params_overrides, slices):
  """Slices `linop` along its batch dimensions.

  Args:
    linop: A `LinearOperator` instance.
    params_overrides: A `dict` of parameter overrides.
    slices: A `slice` or `int` or `int` `Tensor` or `tf.newaxis` or `tuple`
      thereof. (e.g. the argument of a `__getitem__` method).

  Returns:
    new_linop: A batch-sliced `LinearOperator`.
  """
  if not isinstance(slices, collections.abc.Sequence):
    slices = (slices,)
  if len(slices) == 1 and slices[0] is Ellipsis:
    # `linop[...]` is a no-op; skip slicing entirely.
    override_dict = {}
  else:
    batch_shape = linop.batch_shape_tensor()
    override_dict = {}
    for param_name, param_ndims_to_matrix_ndims in linop._experimental_parameter_ndims_to_matrix_ndims.items():  # pylint:disable=protected-access,line-too-long
      param = getattr(linop, param_name)
      # These represent optional `Tensor` parameters.
      if param is not None:
        override_dict[param_name] = nest.map_structure_up_to(
            param, functools.partial(
                _slice_single_param, slices=slices, batch_shape=batch_shape),
            param, param_ndims_to_matrix_ndims)
  override_dict.update(params_overrides)
  parameters = dict(linop.parameters, **override_dict)
  return type(linop)(**parameters)