# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operators specific to data structures: list append, subscripts, etc."""

import collections

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import tensor_array_ops


# TODO(mdan): Once control flow supports objects, repackage as a class.


def new_list(iterable=None):
  """The list constructor.

  Args:
    iterable: Optional elements to fill the list with.

  Returns:
    A list-like object. The exact return value depends on the initial
    elements: a Python list when at least one element is given, otherwise the
    result of `tf_tensor_list_new` called with no elements.
  """
  if iterable:
    elements = tuple(iterable)
  else:
    elements = ()

  if elements:
    # When the list contains elements, it is assumed to be a "Python" lvalue
    # list.
    return _py_list_new(elements)
  return tf_tensor_list_new(elements)


def tf_tensor_array_new(elements, element_dtype=None, element_shape=None):
  """Overload of new_list that stages a TensorArray creation.

  Args:
    elements: Iterable of values; each one is converted to a tensor.
    element_dtype: Optional dtype of the elements. Required when `elements` is
      empty; otherwise it must match the dtype inferred from `elements`.
    element_shape: Optional static shape of the elements. When given together
      with a non-empty `elements`, it must match the shape inferred from them.
      A list is accepted and compared as if it were a tuple.

  Returns:
    A TensorArray with each of `elements` written at its index.

  Raises:
    ValueError: if the elements have differing dtypes or shapes, if a
      specified dtype or shape disagrees with the inferred one, or if
      `elements` is empty and `element_dtype` is not specified.
  """
  elements = tuple(ops.convert_to_tensor(el) for el in elements)

  all_dtypes = set(el.dtype for el in elements)
  if len(all_dtypes) == 1:
    inferred_dtype, = tuple(all_dtypes)
    if element_dtype is not None and element_dtype != inferred_dtype:
      raise ValueError(
          'incompatible dtype; specified: {}, inferred from {}: {}'.format(
              element_dtype, elements, inferred_dtype))
  elif len(all_dtypes) > 1:
    raise ValueError(
        'TensorArray requires all elements to have the same dtype:'
        ' {}'.format(elements))
  else:
    # No elements: there is nothing to infer the dtype from.
    if element_dtype is None:
      raise ValueError('dtype is required to create an empty TensorArray')

  all_shapes = set(tuple(el.shape.as_list()) for el in elements)
  if len(all_shapes) == 1:
    inferred_shape, = tuple(all_shapes)
    # inferred_shape is a tuple, and in Python a list never compares equal to
    # a tuple. Normalize list arguments so that e.g. [2] is accepted for
    # elements of shape (2,). Other types are compared as given.
    if isinstance(element_shape, list):
      specified_shape = tuple(element_shape)
    else:
      specified_shape = element_shape
    if element_shape is not None and specified_shape != inferred_shape:
      raise ValueError(
          'incompatible shape; specified: {}, inferred from {}: {}'.format(
              element_shape, elements, inferred_shape))
  elif len(all_shapes) > 1:
    raise ValueError(
        'TensorArray requires all elements to have the same shape:'
        ' {}'.format(elements))
    # TODO(mdan): We may want to allow different shapes with infer_shape=False.
  else:
    inferred_shape = None

  if element_dtype is None:
    element_dtype = inferred_dtype
  if element_shape is None:
    element_shape = inferred_shape

  l = tensor_array_ops.TensorArray(
      dtype=element_dtype,
      size=len(elements),
      dynamic_size=True,
      infer_shape=(element_shape is None),
      element_shape=element_shape)
  for i, el in enumerate(elements):
    l = l.write(i, el)
  return l


def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
  """Overload of new_list that stages a Tensor list creation.

  Args:
    elements: Either a single tensor, which is converted with
      `tensor_list_from_tensor`, or an iterable of values that are each
      converted to a tensor and pushed onto a new list.
    element_dtype: Optional dtype of the elements. When not specified it is
      inferred from `elements`, falling back to `dtypes.variant` when the
      elements are heterogeneous or absent.
    element_shape: Optional shape of the elements. May not be specified when
      `elements` is a tensor. When not specified it is inferred from
      `elements`, falling back to -1 (unknown shape) when the elements are
      heterogeneous or absent.

  Returns:
    A tensor list holding `elements`.

  Raises:
    ValueError: if `element_shape` is specified while `elements` is a tensor,
      or if a specified dtype or shape is inconsistent with `elements`.
  """
  if tensor_util.is_tf_type(elements):
    if element_shape is not None:
      raise ValueError(
          'element shape may not be specified when creating list from tensor')
    # The leading dimension indexes the list items; the rest is the item shape.
    element_shape = array_ops.shape(elements)[1:]
    l = list_ops.tensor_list_from_tensor(elements, element_shape=element_shape)
    return l

  elements = tuple(ops.convert_to_tensor(el) for el in elements)

  all_dtypes = set(el.dtype for el in elements)
  if len(all_dtypes) == 1:
    inferred_dtype = tuple(all_dtypes)[0]
    if element_dtype is not None and element_dtype != inferred_dtype:
      raise ValueError(
          'incompatible dtype; specified: {}, inferred from {}: {}'.format(
              element_dtype, elements, inferred_dtype))
  elif all_dtypes:
    # Heterogeneous lists are ok.
    if element_dtype is not None:
      raise ValueError(
          'specified dtype {} is inconsistent with that of elements {}'.format(
              element_dtype, elements))
    inferred_dtype = dtypes.variant
  else:
    inferred_dtype = dtypes.variant

  all_shapes = set(tuple(el.shape.as_list()) for el in elements)
  if len(all_shapes) == 1:
    inferred_shape = array_ops.shape(elements[0])
    # NOTE(review): inferred_shape is a Tensor here, so this `!=` relies on
    # Tensor comparison semantics rather than on a static shape check --
    # confirm it behaves as intended for caller-specified shapes.
    if element_shape is not None and element_shape != inferred_shape:
      raise ValueError(
          'incompatible shape; specified: {}, inferred from {}: {}'.format(
              element_shape, elements, inferred_shape))
  elif all_shapes:
    # Heterogeneous lists are ok.
    if element_shape is not None:
      raise ValueError(
          'specified shape {} is inconsistent with that of elements {}'.format(
              element_shape, elements))
    inferred_shape = constant_op.constant(-1)  # unknown shape, by convention
  else:
    inferred_shape = constant_op.constant(-1)  # unknown shape, by convention

  if element_dtype is None:
    element_dtype = inferred_dtype
  if element_shape is None:
    element_shape = inferred_shape

  element_shape = ops.convert_to_tensor(element_shape, dtype=dtypes.int32)
  l = list_ops.empty_tensor_list(
      element_shape=element_shape, element_dtype=element_dtype)
  for el in elements:
    l = list_ops.tensor_list_push_back(l, el)
  return l


def _py_list_new(elements):
  """Overload of new_list that creates a Python list."""
  return list(elements)


def list_append(list_, x):
  """The list append function.

  Note: it is unspecified whether list_ will be mutated or not. If list_ is
  a TensorFlow entity, it will not be typically mutated. If list_ is a plain
  list, it will be. In general, if the list is mutated then the return value
  should point to the original entity.

  Args:
    list_: An entity that supports append semantics.
    x: The element to append.

  Returns:
    Same as list_, after the append was performed.

  Raises:
    ValueError: if list_ is a Tensor whose dtype is not tf.variant.
  """
  if isinstance(list_, tensor_array_ops.TensorArray):
    return _tf_tensorarray_append(list_, x)
  elif tensor_util.is_tf_type(list_):
    if list_.dtype == dtypes.variant:
      return _tf_tensor_list_append(list_, x)
    else:
      raise ValueError(
          'tensor lists are expected to be Tensors with dtype=tf.variant,'
          ' instead found %s' % list_)
  else:
    return _py_list_append(list_, x)


def _tf_tensor_list_append(list_, x):
  """Overload of list_append that stages a Tensor list write."""
  def empty_list_of_elements_like_x():
    tensor_x = ops.convert_to_tensor(x)
    return list_ops.empty_tensor_list(
        element_shape=array_ops.shape(tensor_x),
        element_dtype=tensor_x.dtype)

  # An empty list is swapped for a fresh empty list whose element shape and
  # dtype are taken from x; a non-empty list is used unchanged.
  list_ = control_flow_ops.cond(
      list_ops.tensor_list_length(list_) > 0,
      lambda: list_,
      empty_list_of_elements_like_x,
  )
  return list_ops.tensor_list_push_back(list_, x)


def _tf_tensorarray_append(list_, x):
  """Overload of list_append that stages a TensorArray write."""
  # Writing at index size() places x right after the current last element.
  return list_.write(list_.size(), x)


def _py_list_append(list_, x):
  """Overload of list_append that executes a Python list append."""
  # Revert to the original call.
  list_.append(x)
  return list_


class ListPopOpts(
    collections.namedtuple('ListPopOpts', ('element_dtype', 'element_shape'))):
  """Options for list_pop.

  Fields:
    element_dtype: dtype passed to the Tensor list pop operation.
    element_shape: static shape set on the element popped from a Tensor list.
  """
  pass


def list_pop(list_, i, opts):
  """The list pop function.

  Note: it is unspecified whether list_ will be mutated or not. If list_ is
  a TensorFlow entity, it will not be typically mutated. If list_ is a plain
  list, it will be. In general, if the list is mutated then the return value
  should point to the original entity.

  Args:
    list_: An entity that supports pop semantics.
    i: Optional index to pop from. May be None.
    opts: A ListPopOpts.

  Returns:
    Tuple (out_list_, x):
      out_list_: same as list_, after the removal was performed.
      x: the removed element value.

  Raises:
    ValueError: if list_ is not of a known list-like type or the operation is
      not supported for that type.
  """
  assert isinstance(opts, ListPopOpts)

  if isinstance(list_, tensor_array_ops.TensorArray):
    raise ValueError('TensorArray does not support item removal')
  elif tensor_util.is_tf_type(list_):
    if list_.dtype == dtypes.variant:
      return _tf_tensor_list_pop(list_, i, opts)
    else:
      raise ValueError(
          'tensor lists are expected to be Tensors with dtype=tf.variant,'
          ' instead found %s' % list_)
  else:
    return _py_list_pop(list_, i)


def _tf_tensor_list_pop(list_, i, opts):
  """Overload of list_pop that stages a Tensor list pop.

  Args:
    list_: The Tensor list to pop from.
    i: Must be None; only removal from the end is supported.
    opts: A ListPopOpts; both element_dtype and element_shape must be set.

  Returns:
    Tuple (list_out, x) of the list after the removal and the removed element.

  Raises:
    NotImplementedError: if i is not None.
    ValueError: if opts.element_dtype or opts.element_shape is None.
  """
  if i is not None:
    raise NotImplementedError('tensor lists only support removing from the end')

  if opts.element_dtype is None:
    raise ValueError('cannot pop from a list without knowing its element '
                     'type; use set_element_type to annotate it')
  if opts.element_shape is None:
    raise ValueError('cannot pop from a list without knowing its element '
                     'shape; use set_element_type to annotate it')
  list_out, x = list_ops.tensor_list_pop_back(
      list_, element_dtype=opts.element_dtype)
  x.set_shape(opts.element_shape)
  return list_out, x


def _py_list_pop(list_, i):
  """Overload of list_pop that executes a Python list pop."""
  if i is None:
    x = list_.pop()
  else:
    x = list_.pop(i)
  return list_, x


# TODO(mdan): Look into reducing duplication between all these containers.
class ListStackOpts(
    collections.namedtuple('ListStackOpts',
                           ('element_dtype', 'original_call'))):
  """Options for list_stack.

  Fields:
    element_dtype: dtype passed to the Tensor list stack operation.
    original_call: callable invoked with the list when the target is neither a
      TensorArray nor a Tensor.
  """
  pass


def list_stack(list_, opts):
  """The list stack function.

  This does not have a direct correspondent in Python. The closest idiom to
  this is tf.stack or np.stack. It's different from those in the sense that it
  accepts a Tensor list, rather than a list of tensors. It can also accept
  TensorArray. When the target is anything else, the dispatcher will rely on
  opts.original_call for fallback.

  Args:
    list_: The entity to stack: a TensorArray, a Tensor list, or any other
      object accepted by opts.original_call.
    opts: A ListStackOpts object.

  Returns:
    The output of the stack operation, typically a Tensor. A Tensor whose dtype
    is not tf.variant is returned unchanged.
  """
  assert isinstance(opts, ListStackOpts)

  if isinstance(list_, tensor_array_ops.TensorArray):
    return _tf_tensorarray_stack(list_)
  elif tensor_util.is_tf_type(list_):
    if list_.dtype == dtypes.variant:
      return _tf_tensor_list_stack(list_, opts)
    else:
      # No-op for primitive Tensor arguments.
      return list_
  else:
    return _py_list_stack(list_, opts)


def _tf_tensorarray_stack(list_):
  """Overload of list_stack that stages a TensorArray stack."""
  return list_.stack()


def _tf_tensor_list_stack(list_, opts):
  """Overload of list_stack that stages a Tensor list stack.

  Raises:
    ValueError: if opts.element_dtype is None.
  """
  if opts.element_dtype is None:
    raise ValueError('cannot stack a list without knowing its element type;'
                     ' use set_element_type to annotate it')
  return list_ops.tensor_list_stack(list_, element_dtype=opts.element_dtype)


def _py_list_stack(list_, opts):
  """Overload of list_stack that falls back to the original Python call."""
  # Revert to the original call.
  return opts.original_call(list_)