xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/operators/data_structures.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operators specific to data structures: list append, subscripts, etc."""
16
17import collections
18
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import control_flow_ops
25from tensorflow.python.ops import list_ops
26from tensorflow.python.ops import tensor_array_ops
27
28
29# TODO(mdan): Once control flow supports objects, repackage as a class.
30
31
def new_list(iterable=None):
  """The list constructor.

  Args:
    iterable: Optional elements to fill the list with. Any iterable is
      accepted and is consumed exactly once; None means "no elements".

  Returns:
    A list-like object. The exact return value depends on the initial elements:
    a Python list when at least one element is given, otherwise whatever
    tf_tensor_list_new returns for an empty tuple of elements.
  """
  # Only None means "no argument". Testing the truthiness of the iterable
  # itself would raise for objects whose bool() is ambiguous (for example
  # multi-element NumPy arrays), so materialize it first and test the tuple.
  elements = () if iterable is None else tuple(iterable)

  if elements:
    # When the list contains elements, it is assumed to be a "Python" lvalue
    # list.
    return _py_list_new(elements)
  return tf_tensor_list_new(elements)
51
52
def tf_tensor_array_new(elements, element_dtype=None, element_shape=None):
  """Overload of new_list that stages a TensorArray creation.

  Args:
    elements: Iterable of values. Each one is converted with
      ops.convert_to_tensor and written to the array in iteration order.
    element_dtype: Optional dtype of the elements. Required when elements is
      empty; otherwise it must equal the dtype shared by all elements.
    element_shape: Optional shape of the elements. When given and elements is
      not empty, it must equal the static shape shared by all elements. Lists
      and tuples are treated interchangeably for this check.

  Returns:
    A TensorArray (created with dynamic_size=True) holding the elements.

  Raises:
    ValueError: if the elements disagree on dtype or on shape, if a specified
      dtype or shape conflicts with the inferred one, or if elements is empty
      and element_dtype is not given.
  """
  elements = tuple(ops.convert_to_tensor(el) for el in elements)

  inferred_dtype = None
  all_dtypes = {el.dtype for el in elements}
  if len(all_dtypes) > 1:
    raise ValueError(
        'TensorArray requires all elements to have the same dtype:'
        ' {}'.format(elements))
  if all_dtypes:
    inferred_dtype, = all_dtypes
    if element_dtype is not None and element_dtype != inferred_dtype:
      raise ValueError(
          'incompatible dtype; specified: {}, inferred from {}: {}'.format(
              element_dtype, elements, inferred_dtype))
  elif element_dtype is None:
    # Nothing to infer the dtype from.
    raise ValueError('dtype is required to create an empty TensorArray')

  inferred_shape = None
  all_shapes = {tuple(el.shape.as_list()) for el in elements}
  if len(all_shapes) > 1:
    # TODO(mdan): We may want to allow different shapes with infer_shape=False.
    raise ValueError(
        'TensorArray requires all elements to have the same shape:'
        ' {}'.format(elements))
  if all_shapes:
    inferred_shape, = all_shapes
    if element_shape is not None:
      # The inferred shape is a tuple. A Python list never compares equal to a
      # tuple, so normalize list/tuple inputs before comparing; any other
      # shape type is compared with its own equality, as before.
      specified_shape = element_shape
      if isinstance(element_shape, (list, tuple)):
        specified_shape = tuple(element_shape)
      if specified_shape != inferred_shape:
        raise ValueError(
            'incompatible shape; specified: {}, inferred from {}: {}'.format(
                element_shape, elements, inferred_shape))

  if element_dtype is None:
    element_dtype = inferred_dtype
  if element_shape is None:
    element_shape = inferred_shape

  tensor_array = tensor_array_ops.TensorArray(
      dtype=element_dtype,
      size=len(elements),
      dynamic_size=True,
      # Let the TensorArray infer the shape only when none is known here.
      infer_shape=(element_shape is None),
      element_shape=element_shape)
  for i, el in enumerate(elements):
    # write() returns the updated TensorArray; keep threading it through.
    tensor_array = tensor_array.write(i, el)
  return tensor_array
101
102
def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
  """Overload of new_list that stages a Tensor list creation.

  Args:
    elements: Either a tensor (as reported by tensor_util.is_tf_type), which
      is turned into a list with tensor_list_from_tensor, or an iterable of
      values that are each converted to a tensor and pushed onto a new list.
    element_dtype: Optional dtype of the elements. Must match the common dtype
      of the elements when they all share one, and may not be given when they
      have differing dtypes.
    element_shape: Optional shape of the elements. May not be given when
      elements is a tensor or when the elements have differing static shapes.

  Returns:
    The tensor list produced by list_ops.

  Raises:
    ValueError: if a specified dtype or shape conflicts with the elements, or
      if element_shape is given together with a tensor argument.
  """
  if tensor_util.is_tf_type(elements):
    if element_shape is not None:
      raise ValueError(
          'element shape may not be specified when creating list from tensor')
    # Each list item is a slice along the first axis, so the element shape is
    # the tensor's shape without its leading dimension.
    return list_ops.tensor_list_from_tensor(
        elements, element_shape=array_ops.shape(elements)[1:])

  elements = tuple(ops.convert_to_tensor(el) for el in elements)

  distinct_dtypes = {el.dtype for el in elements}
  if len(distinct_dtypes) == 1:
    inferred_dtype = next(iter(distinct_dtypes))
    if element_dtype is not None and element_dtype != inferred_dtype:
      raise ValueError(
          'incompatible dtype; specified: {}, inferred from {}: {}'.format(
              element_dtype, elements, inferred_dtype))
  else:
    # Either no elements at all, or a heterogeneous list (which is ok). Only
    # the heterogeneous case rejects an explicitly specified dtype.
    if distinct_dtypes and element_dtype is not None:
      raise ValueError(
          'specified dtype {} is inconsistent with that of elements {}'.format(
              element_dtype, elements))
    inferred_dtype = dtypes.variant

  distinct_shapes = {tuple(el.shape.as_list()) for el in elements}
  if len(distinct_shapes) == 1:
    inferred_shape = array_ops.shape(elements[0])
    # NOTE(review): inferred_shape is the result of array_ops.shape, so this
    # `!=` compares element_shape against a tensor rather than against a
    # static shape - confirm this behaves as intended.
    if element_shape is not None and element_shape != inferred_shape:
      raise ValueError(
          'incompatible shape; specified: {}, inferred from {}: {}'.format(
              element_shape, elements, inferred_shape))
  else:
    # Either no elements at all, or a heterogeneous list (which is ok). Only
    # the heterogeneous case rejects an explicitly specified shape.
    if distinct_shapes and element_shape is not None:
      raise ValueError(
          'specified shape {} is inconsistent with that of elements {}'.format(
              element_shape, elements))
    inferred_shape = constant_op.constant(-1)  # unknown shape, by convention

  if element_dtype is None:
    element_dtype = inferred_dtype
  if element_shape is None:
    element_shape = inferred_shape

  element_shape = ops.convert_to_tensor(element_shape, dtype=dtypes.int32)
  tensor_list = list_ops.empty_tensor_list(
      element_shape=element_shape, element_dtype=element_dtype)
  for el in elements:
    # push_back returns the updated list handle; keep threading it through.
    tensor_list = list_ops.tensor_list_push_back(tensor_list, el)
  return tensor_list
160
161
def _py_list_new(elements):
  """Overload of new_list that creates a Python list.

  Args:
    elements: Iterable of the initial items.

  Returns:
    A new Python list containing the items of elements, in order.
  """
  return [*elements]
165
166
def list_append(list_, x):
  """The list append function.

  Dispatches on the type of list_: TensorArray objects, variant-dtype tensors
  (tensor lists), and anything else, which is treated as a Python list.

  Note: it is unspecified whether list_ will be mutated or not. If list_ is
  a TensorFlow entity, it will not be typically mutated. If list_ is a plain
  list, it will be. In general, if the list is mutated then the return value
  should point to the original entity.

  Args:
    list_: An entity that supports append semantics.
    x: The element to append.

  Returns:
    Same as list_, after the append was performed.

  Raises:
    ValueError: if list_ is a tensor whose dtype is not tf.variant.
  """
  if isinstance(list_, tensor_array_ops.TensorArray):
    return _tf_tensorarray_append(list_, x)

  if not tensor_util.is_tf_type(list_):
    # Anything that is not a TensorFlow type gets plain Python semantics.
    return _py_list_append(list_, x)

  if list_.dtype == dtypes.variant:
    return _tf_tensor_list_append(list_, x)

  raise ValueError(
      'tensor lists are expected to be Tensors with dtype=tf.variant,'
      ' instead found %s' % list_)
196
197
def _tf_tensor_list_append(list_, x):
  """Overload of list_append that stages a Tensor list push-back.

  When the staged length of list_ is zero, the list is first replaced by a
  fresh empty tensor list whose element shape and dtype are taken from x;
  otherwise list_ is used as is. x is then pushed onto the result.

  Args:
    list_: The tensor list to append to.
    x: The element to append.

  Returns:
    The tensor list resulting from the push-back.
  """
  def keep_existing_list():
    return list_

  def make_empty_list_like_x():
    x_tensor = ops.convert_to_tensor(x)
    return list_ops.empty_tensor_list(
        element_shape=array_ops.shape(x_tensor),
        element_dtype=x_tensor.dtype)

  has_elements = list_ops.tensor_list_length(list_) > 0
  target_list = control_flow_ops.cond(
      has_elements,
      keep_existing_list,
      make_empty_list_like_x,
  )
  return list_ops.tensor_list_push_back(target_list, x)
212
213
def _tf_tensorarray_append(list_, x):
  """Overload of list_append that stages a TensorArray write.

  Args:
    list_: The TensorArray to append to.
    x: The element to append.

  Returns:
    The result of writing x at index list_.size().
  """
  # Appending means writing one position past the current last element.
  next_index = list_.size()
  return list_.write(next_index, x)
217
218
def _py_list_append(list_, x):
  """Overload of list_append that executes a Python list append.

  Args:
    list_: The object to append to; it is mutated in place.
    x: The element to append.

  Returns:
    The same list_ object that was passed in.
  """
  # Plain Python semantics: defer to the object's own append().
  list_.append(x)
  return list_
224
225
class ListPopOpts(
    collections.namedtuple(
        'ListPopOpts',
        ('element_dtype', 'element_shape'),
    )):
  """Options accepted by list_pop.

  Fields:
    element_dtype: dtype of the popped element. The tensor list overload
      refuses to pop when this is None.
    element_shape: shape set on the popped element. The tensor list overload
      refuses to pop when this is None.
  """
229
230
def list_pop(list_, i, opts):
  """The list pop function.

  Dispatches on the type of list_: TensorArray objects are rejected,
  variant-dtype tensors are treated as tensor lists, and anything that is not
  a TensorFlow type is treated as a Python list.

  Note: it is unspecified whether list_ will be mutated or not. If list_ is
  a TensorFlow entity, it will not be typically mutated. If list_ is a plain
  list, it will be. In general, if the list is mutated then the return value
  should point to the original entity.

  Args:
    list_: An entity that supports pop semantics.
    i: Optional index to pop from. May be None.
    opts: A ListPopOpts.

  Returns:
    Tuple (out_list_, x):
      out_list_: same as list_, after the removal was performed.
      x: the removed element value.

  Raises:
    ValueError: if list_ is not of a known list-like type or the operation is
    not supported for that type.
  """
  assert isinstance(opts, ListPopOpts)

  if isinstance(list_, tensor_array_ops.TensorArray):
    raise ValueError('TensorArray does not support item removal')

  if not tensor_util.is_tf_type(list_):
    # Anything that is not a TensorFlow type gets plain Python semantics.
    return _py_list_pop(list_, i)

  if list_.dtype == dtypes.variant:
    return _tf_tensor_list_pop(list_, i, opts)

  raise ValueError(
      'tensor lists are expected to be Tensors with dtype=tf.variant,'
      ' instead found %s' % list_)
266
267
def _tf_tensor_list_pop(list_, i, opts):
  """Overload of list_pop that stages a Tensor list pop.

  Args:
    list_: The tensor list to pop from.
    i: Must be None; only removal from the end is supported.
    opts: Object providing element_dtype and element_shape, both of which must
      be set.

  Returns:
    Tuple (list_out, x): the list after the removal and the removed element,
    with its static shape set to opts.element_shape.

  Raises:
    NotImplementedError: if i is not None.
    ValueError: if opts.element_dtype or opts.element_shape is None.
  """
  if i is not None:
    raise NotImplementedError('tensor lists only support removing from the end')

  element_dtype = opts.element_dtype
  if element_dtype is None:
    raise ValueError('cannot pop from a list without knowing its element '
                     'type; use set_element_type to annotate it')

  element_shape = opts.element_shape
  if element_shape is None:
    raise ValueError('cannot pop from a list without knowing its element '
                     'shape; use set_element_type to annotate it')

  remaining, popped = list_ops.tensor_list_pop_back(
      list_, element_dtype=element_dtype)
  # Attach the declared static shape to the popped element.
  popped.set_shape(element_shape)
  return remaining, popped
283
284
def _py_list_pop(list_, i):
  """Overload of list_pop that executes a Python list pop.

  Args:
    list_: The object to pop from; it is mutated in place.
    i: Index to pop from, or None to pop the last element.

  Returns:
    Tuple (list_, x): the same list_ object and the removed element.
  """
  removed = list_.pop() if i is None else list_.pop(i)
  return list_, removed
292
293
294# TODO(mdan): Look into reducing duplication between all these containers.
class ListStackOpts(
    collections.namedtuple(
        'ListStackOpts',
        ('element_dtype', 'original_call'),
    )):
  """Options accepted by list_stack.

  Fields:
    element_dtype: dtype of the list elements. The tensor list overload
      refuses to stack when this is None.
    original_call: callable invoked with the list when the target is handled
      with plain Python semantics.
  """
299
300
def list_stack(list_, opts):
  """The list stack function.

  This does not have a direct correspondent in Python; the closest idiom is
  something like np.stack. It differs in that it accepts a Tensor list, rather
  than a list of tensors. It can also accept a TensorArray. Tensors that are
  not tensor lists are returned unchanged. When the target is anything else,
  the dispatcher relies on opts.original_call for fallback.

  Args:
    list_: The entity to stack.
    opts: A ListStackOpts object.

  Returns:
    The output of the stack operation, typically a Tensor.
  """
  assert isinstance(opts, ListStackOpts)

  if isinstance(list_, tensor_array_ops.TensorArray):
    return _tf_tensorarray_stack(list_)

  if not tensor_util.is_tf_type(list_):
    # Anything that is not a TensorFlow type falls back to the original call.
    return _py_list_stack(list_, opts)

  if list_.dtype == dtypes.variant:
    return _tf_tensor_list_stack(list_, opts)

  # No-op for primitive Tensor arguments.
  return list_
329
330
def _tf_tensorarray_stack(list_):
  """Overload of list_stack that stages a TensorArray stack.

  Args:
    list_: The TensorArray to stack.

  Returns:
    Whatever list_.stack() returns.
  """
  stacked = list_.stack()
  return stacked
334
335
def _tf_tensor_list_stack(list_, opts):
  """Overload of list_stack that stages a Tensor list stack.

  Args:
    list_: The tensor list to stack.
    opts: Object providing element_dtype, which must be set.

  Returns:
    The result of list_ops.tensor_list_stack for list_.

  Raises:
    ValueError: if opts.element_dtype is None.
  """
  element_dtype = opts.element_dtype
  if element_dtype is None:
    raise ValueError('cannot stack a list without knowing its element type;'
                     ' use set_element_type to annotate it')
  return list_ops.tensor_list_stack(list_, element_dtype=element_dtype)
342
343
def _py_list_stack(list_, opts):
  """Overload of list_stack that falls back to the original Python call.

  Args:
    list_: The object to stack.
    opts: Object providing original_call, the callable to delegate to.

  Returns:
    Whatever opts.original_call(list_) returns.
  """
  # Revert to the original call.
  fallback = opts.original_call
  return fallback(list_)
348