xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/data_flow_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#==============================================================================
15"""Data Flow Operations."""
16# pylint: disable=g-bad-name
17import functools
18import hashlib
19import threading
20
21from tensorflow.python.eager import context
22from tensorflow.python.framework import dtypes as _dtypes
23from tensorflow.python.framework import indexed_slices
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import random_seed
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.framework import tensor_util
28from tensorflow.python.lib.io import python_io
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import gen_data_flow_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import resource_variable_ops
34# go/tf-wildcard-import
35# pylint: disable=wildcard-import
36from tensorflow.python.ops.gen_data_flow_ops import *
37from tensorflow.python.util import deprecation
38from tensorflow.python.util.compat import collections_abc
39from tensorflow.python.util.tf_export import tf_export
40
41# pylint: enable=wildcard-import
42
43
44def _as_type_list(dtypes):
45  """Convert dtypes to a list of types."""
46  assert dtypes is not None
  if not isinstance(dtypes, (list, tuple)):
48    # We have a single type.
49    return [dtypes]
50  else:
51    # We have a list or tuple of types.
52    return list(dtypes)
53
54
55def _as_shape_list(shapes,
56                   dtypes,
57                   unknown_dim_allowed=False,
58                   unknown_rank_allowed=False):
59  """Convert shapes to a list of tuples of int (or None)."""
60  del dtypes
61  if unknown_dim_allowed:
62    if (not isinstance(shapes, collections_abc.Sequence) or not shapes or
63        any(shape is None or isinstance(shape, int) for shape in shapes)):
64      raise ValueError(
65          "When providing partial shapes, a list of shapes must be provided.")
66  if shapes is None:
67    return None
68  if isinstance(shapes, tensor_shape.TensorShape):
69    shapes = [shapes]
70  if not isinstance(shapes, (tuple, list)):
71    raise TypeError(
72        "Shapes must be a TensorShape or a list or tuple of TensorShapes, "
73        f"got {type(shapes)} instead.")
74  if all(shape is None or isinstance(shape, int) for shape in shapes):
75    # We have a single shape.
76    shapes = [shapes]
77  shapes = [tensor_shape.as_shape(shape) for shape in shapes]
78  if not unknown_dim_allowed:
79    if any(not shape.is_fully_defined() for shape in shapes):
80      raise ValueError(f"All shapes must be fully defined: {shapes}")
81  if not unknown_rank_allowed:
82    if any(shape.dims is None for shape in shapes):
83      raise ValueError(f"All shapes must have a defined rank: {shapes}")
84
85  return shapes
86
87
88def _as_name_list(names, dtypes):
89  if names is None:
90    return None
91  if not isinstance(names, (list, tuple)):
92    names = [names]
93  if len(names) != len(dtypes):
94    raise ValueError("List of names must have the same length as the list "
                     f"of dtypes, received len(names)={len(names)}, "
96                     f"len(dtypes)={len(dtypes)}")
97  return list(names)
98
99
100def _shape_common(s1, s2):
101  """The greatest lower bound (ordered by specificity) TensorShape."""
102  s1 = tensor_shape.TensorShape(s1)
103  s2 = tensor_shape.TensorShape(s2)
104  if s1.ndims is None or s2.ndims is None or s1.ndims != s2.ndims:
105    return tensor_shape.unknown_shape()
106  d = [
107      d1 if d1 is not None and d1 == d2 else None
108      for (d1, d2) in zip(s1.as_list(), s2.as_list())
109  ]
110  return tensor_shape.TensorShape(d)
111
112
113# pylint: disable=protected-access
114@tf_export("queue.QueueBase",
115           v1=["queue.QueueBase", "io.QueueBase", "QueueBase"])
116@deprecation.deprecated_endpoints(["io.QueueBase", "QueueBase"])
117class QueueBase:
118  """Base class for queue implementations.
119
120  A queue is a TensorFlow data structure that stores tensors across
121  multiple steps, and exposes operations that enqueue and dequeue
122  tensors.
123
124  Each queue element is a tuple of one or more tensors, where each
125  tuple component has a static dtype, and may have a static shape. The
  queue implementations support versions of enqueue and dequeue that
  handle single elements, and versions that support enqueuing and
  dequeuing a batch of elements at once.
129
130  See `tf.queue.FIFOQueue` and
131  `tf.queue.RandomShuffleQueue` for concrete
132  implementations of this class, and instructions on how to create
133  them.
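
  For example, a queue constructed with `names` consumes and produces
  dictionaries keyed by those names. A minimal sketch, assuming eager
  execution:

  ```python
  q = tf.queue.FIFOQueue(10, dtypes=[tf.int32, tf.float32],
                         names=["id", "value"])
  q.enqueue({"id": 1, "value": 0.5})
  element = q.dequeue()  # a dict with keys "id" and "value"
  print(element["id"].numpy(), element["value"].numpy())  # 1 0.5
  ```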
134  """
135
136  def __init__(self, dtypes, shapes, names, queue_ref):
137    """Constructs a queue object from a queue reference.
138
139    The two optional lists, `shapes` and `names`, must be of the same length
140    as `dtypes` if provided.  The values at a given index `i` indicate the
141    shape and name to use for the corresponding queue component in `dtypes`.
142
143    Args:
144      dtypes:  A list of types.  The length of dtypes must equal the number
145        of tensors in each element.
146      shapes: Constraints on the shapes of tensors in an element:
147        A list of shape tuples or None. This list is the same length
        as dtypes.  If the shape of any tensor in the element is constrained,
149        all must be; shapes can be None if the shapes should not be constrained.
150      names: Optional list of names.  If provided, the `enqueue()` and
151        `dequeue()` methods will use dictionaries with these names as keys.
152        Must be None or a list or tuple of the same length as `dtypes`.
153      queue_ref: The queue reference, i.e. the output of the queue op.
154
155    Raises:
156      ValueError: If one of the arguments is invalid.
157    """
158    self._dtypes = dtypes
159    if shapes is not None:
160      if len(shapes) != len(dtypes):
161        raise ValueError("Queue shapes must have the same length as dtypes, "
162                         f"received len(shapes)={len(shapes)}, "
163                         f"len(dtypes)={len(dtypes)}")
164      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
165    else:
166      self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
167    if names is not None:
168      if len(names) != len(dtypes):
        raise ValueError("Queue names must have the same length as dtypes, "
                         f"received len(names)={len(names)}, "
                         f"len(dtypes)={len(dtypes)}")
172      self._names = names
173    else:
174      self._names = None
175    self._queue_ref = queue_ref
176    if isinstance(queue_ref, ops.EagerTensor):
177      if context.context().scope_name:
178        self._name = context.context().scope_name
179      else:
180        self._name = "Empty"
181      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
182          queue_ref, None)
183    else:
184      self._name = self._queue_ref.op.name.split("/")[-1]
185
186  @staticmethod
187  def from_list(index, queues):
188    """Create a queue using the queue reference from `queues[index]`.
189
190    Args:
191      index: An integer scalar tensor that determines the input that gets
192        selected.
193      queues: A list of `QueueBase` objects.
194
195    Returns:
196      A `QueueBase` object.
197
198    Raises:
199      TypeError: When `queues` is not a list of `QueueBase` objects,
200        or when the data types of `queues` are not all the same.
201    """
202    if ((not queues) or (not isinstance(queues, list)) or
203        (not all(isinstance(x, QueueBase) for x in queues))):
      raise TypeError("A list of queues is expected.")
205
206    dtypes = queues[0].dtypes
207    if not all(dtypes == q.dtypes for q in queues[1:]):
208      raise TypeError("Queues do not have matching component dtypes.")
209
210    names = queues[0].names
211    if not all(names == q.names for q in queues[1:]):
212      raise TypeError("Queues do not have matching component names.")
213
214    queue_shapes = [q.shapes for q in queues]
215    reduced_shapes = [
216        functools.reduce(_shape_common, s) for s in zip(*queue_shapes)
217    ]
218
219    queue_refs = array_ops.stack([x.queue_ref for x in queues])
220    selected_queue = array_ops.gather(queue_refs, index)
221    return QueueBase(
222        dtypes=dtypes,
223        shapes=reduced_shapes,
224        names=names,
225        queue_ref=selected_queue)
226
227  @property
228  def queue_ref(self):
229    """The underlying queue reference."""
230    return self._queue_ref
231
232  @property
233  def name(self):
234    """The name of the underlying queue."""
235    if context.executing_eagerly():
236      return self._name
237    return self._queue_ref.op.name
238
239  @property
240  def dtypes(self):
241    """The list of dtypes for each component of a queue element."""
242    return self._dtypes
243
244  @property
245  def shapes(self):
246    """The list of shapes for each component of a queue element."""
247    return self._shapes
248
249  @property
250  def names(self):
251    """The list of names for each component of a queue element."""
252    return self._names
253
254  def _check_enqueue_dtypes(self, vals):
255    """Validate and convert `vals` to a list of `Tensor`s.
256
257    The `vals` argument can be a Tensor, a list or tuple of tensors, or a
258    dictionary with tensor values.
259
260    If it is a dictionary, the queue must have been constructed with a
261    `names` attribute and the dictionary keys must match the queue names.
262    If the queue was constructed with a `names` attribute, `vals` must
263    be a dictionary.
264
265    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary.
267
268    Returns:
269      A list of `Tensor` objects.
270
271    Raises:
272      ValueError: If `vals` is invalid.
273    """
274    if isinstance(vals, dict):
275      if not self._names:
276        raise ValueError("Queue must have names to enqueue a dictionary")
277      if sorted(self._names, key=str) != sorted(vals.keys(), key=str):
278        raise ValueError("Keys in dictionary to enqueue do not match "
                         f"names of Queue.  Dictionary: {sorted(vals.keys())}, "
280                         f"Queue: {sorted(self._names)}")
281      # The order of values in `self._names` indicates the order in which the
282      # tensors in the dictionary `vals` must be listed.
283      vals = [vals[k] for k in self._names]
284    else:
285      if self._names:
286        raise ValueError("You must enqueue a dictionary in a Queue with names")
287      if not isinstance(vals, (list, tuple)):
288        vals = [vals]
289
290    tensors = []
291    for i, (val, dtype) in enumerate(zip(vals, self._dtypes)):
292      tensors.append(
293          ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))
294
295    return tensors
296
297  def _scope_vals(self, vals):
298    """Return a list of values to pass to `name_scope()`.
299
300    Args:
301      vals: A tensor, a list or tuple of tensors, or a dictionary.
302
303    Returns:
304      The values in vals as a list.
305    """
306    if isinstance(vals, (list, tuple)):
307      return vals
308    elif isinstance(vals, dict):
309      return vals.values()
310    else:
311      return [vals]
312
313  def enqueue(self, vals, name=None):
314    """Enqueues one element to this queue.
315
316    If the queue is full when this operation executes, it will block
317    until the element has been enqueued.
318
319    At runtime, this operation may raise an error if the queue is
    closed (see `tf.QueueBase.close`) before or during its execution. If the
321    queue is closed before this operation runs,
322    `tf.errors.CancelledError` will be raised. If this operation is
323    blocked, and either (i) the queue is closed by a close operation
324    with `cancel_pending_enqueues=True`, or (ii) the session is
    closed (see `tf.Session.close`),
326    `tf.errors.CancelledError` will be raised.
327
328    Args:
329      vals: A tensor, a list or tuple of tensors, or a dictionary containing
330        the values to enqueue.
331      name: A name for the operation (optional).
332
333    Returns:
334      The operation that enqueues a new tuple of tensors to the queue.
335    """
336    with ops.name_scope(name, "%s_enqueue" % self._name,
337                        self._scope_vals(vals)) as scope:
338      vals = self._check_enqueue_dtypes(vals)
339
340      # NOTE(mrry): Not using a shape function because we need access to
341      # the `QueueBase` object.
342      for val, shape in zip(vals, self._shapes):
343        val.get_shape().assert_is_compatible_with(shape)
344
345      if self._queue_ref.dtype == _dtypes.resource:
346        return gen_data_flow_ops.queue_enqueue_v2(
347            self._queue_ref, vals, name=scope)
348      else:
349        return gen_data_flow_ops.queue_enqueue(
350            self._queue_ref, vals, name=scope)
351
352  def enqueue_many(self, vals, name=None):
353    """Enqueues zero or more elements to this queue.
354
355    This operation slices each component tensor along the 0th dimension to
356    make multiple queue elements. All of the tensors in `vals` must have the
357    same size in the 0th dimension.
358
359    If the queue is full when this operation executes, it will block
360    until all of the elements have been enqueued.
361
362    At runtime, this operation may raise an error if the queue is
    closed (see `tf.QueueBase.close`) before or during its execution. If the
364    queue is closed before this operation runs,
365    `tf.errors.CancelledError` will be raised. If this operation is
366    blocked, and either (i) the queue is closed by a close operation
367    with `cancel_pending_enqueues=True`, or (ii) the session is
    closed (see `tf.Session.close`),
369    `tf.errors.CancelledError` will be raised.
370
371    Args:
372      vals: A tensor, a list or tuple of tensors, or a dictionary
373        from which the queue elements are taken.
374      name: A name for the operation (optional).
375
376    Returns:
377      The operation that enqueues a batch of tuples of tensors to the queue.
378    """
379    with ops.name_scope(name, "%s_EnqueueMany" % self._name,
380                        self._scope_vals(vals)) as scope:
381      vals = self._check_enqueue_dtypes(vals)
382
383      # NOTE(mrry): Not using a shape function because we need access to
384      # the `QueueBase` object.
      # NOTE(fchollet): the code that follows is verbose because it needs to be
386      # compatible with both TF v1 TensorShape behavior and TF v2 behavior.
387      batch_dim = tensor_shape.dimension_value(
388          vals[0].get_shape().with_rank_at_least(1)[0])
389      batch_dim = tensor_shape.Dimension(batch_dim)
390      for val, shape in zip(vals, self._shapes):
391        val_batch_dim = tensor_shape.dimension_value(
392            val.get_shape().with_rank_at_least(1)[0])
393        val_batch_dim = tensor_shape.Dimension(val_batch_dim)
394        batch_dim = batch_dim.merge_with(val_batch_dim)
395        val.get_shape()[1:].assert_is_compatible_with(shape)
396
397      return gen_data_flow_ops.queue_enqueue_many_v2(
398          self._queue_ref, vals, name=scope)
399
400  def _dequeue_return_value(self, tensors):
401    """Return the value to return from a dequeue op.
402
403    If the queue has names, return a dictionary with the
404    names as keys.  Otherwise return either a single tensor
405    or a list of tensors depending on the length of `tensors`.
406
407    Args:
408      tensors: List of tensors from the dequeue op.
409
410    Returns:
411      A single tensor, a list of tensors, or a dictionary
412      of tensors.
413    """
414    if self._names:
415      # The returned values in `tensors` are in the same order as
416      # the names in `self._names`.
417      return {n: tensors[i] for i, n in enumerate(self._names)}
418    elif len(tensors) == 1:
419      return tensors[0]
420    else:
421      return tensors
422
423  def dequeue(self, name=None):
424    """Dequeues one element from this queue.
425
426    If the queue is empty when this operation executes, it will block
427    until there is an element to dequeue.
428
429    At runtime, this operation may raise an error if the queue is
    closed (see `tf.QueueBase.close`) before or during its execution. If the
431    queue is closed, the queue is empty, and there are no pending
432    enqueue operations that can fulfill this request,
433    `tf.errors.OutOfRangeError` will be raised. If the session is
    closed (see `tf.Session.close`),
435    `tf.errors.CancelledError` will be raised.
436
437    Args:
438      name: A name for the operation (optional).
439
440    Returns:
441      The tuple of tensors that was dequeued.
442    """
443    if name is None:
444      name = "%s_Dequeue" % self._name
445    if self._queue_ref.dtype == _dtypes.resource:
446      ret = gen_data_flow_ops.queue_dequeue_v2(
447          self._queue_ref, self._dtypes, name=name)
448    else:
449      ret = gen_data_flow_ops.queue_dequeue(
450          self._queue_ref, self._dtypes, name=name)
451
452    # NOTE(mrry): Not using a shape function because we need access to
453    # the `QueueBase` object.
454    if not context.executing_eagerly():
455      op = ret[0].op
456      for output, shape in zip(op.values(), self._shapes):
457        output.set_shape(shape)
458
459    return self._dequeue_return_value(ret)
460
461  def dequeue_many(self, n, name=None):
462    """Dequeues and concatenates `n` elements from this queue.
463
464    This operation concatenates queue-element component tensors along
465    the 0th dimension to make a single component tensor.  All of the
466    components in the dequeued tuple will have size `n` in the 0th dimension.
467
    If the queue is closed and there are fewer than `n` elements left, then an
469    `OutOfRange` exception is raised.
470
471    At runtime, this operation may raise an error if the queue is
    closed (see `tf.QueueBase.close`) before or during its execution. If the
473    queue is closed, the queue contains fewer than `n` elements, and
474    there are no pending enqueue operations that can fulfill this
475    request, `tf.errors.OutOfRangeError` will be raised. If the
    session is closed (see `tf.Session.close`),
477    `tf.errors.CancelledError` will be raised.
478
479    Args:
480      n: A scalar `Tensor` containing the number of elements to dequeue.
481      name: A name for the operation (optional).
482
483    Returns:
484      The list of concatenated tensors that was dequeued.
485    """
486    if name is None:
487      name = "%s_DequeueMany" % self._name
488
489    ret = gen_data_flow_ops.queue_dequeue_many_v2(
490        self._queue_ref, n=n, component_types=self._dtypes, name=name)
491
492    # NOTE(mrry): Not using a shape function because we need access to
493    # the Queue object.
494    if not context.executing_eagerly():
495      op = ret[0].op
496      batch_dim = tensor_shape.Dimension(
497          tensor_util.constant_value(op.inputs[1]))
498      for output, shape in zip(op.values(), self._shapes):
499        output.set_shape(
500            tensor_shape.TensorShape([batch_dim]).concatenate(shape))
501
502    return self._dequeue_return_value(ret)
503
504  def dequeue_up_to(self, n, name=None):
505    """Dequeues and concatenates `n` elements from this queue.
506
    **Note**: This operation is not supported by all queues.  If a queue does not
508    support DequeueUpTo, then a `tf.errors.UnimplementedError` is raised.
509
510    This operation concatenates queue-element component tensors along
511    the 0th dimension to make a single component tensor. If the queue
512    has not been closed, all of the components in the dequeued tuple
513    will have size `n` in the 0th dimension.
514
515    If the queue is closed and there are more than `0` but fewer than
516    `n` elements remaining, then instead of raising a
517    `tf.errors.OutOfRangeError` like `tf.QueueBase.dequeue_many`,
    fewer than `n` elements are returned immediately.  If the queue is
519    closed and there are `0` elements left in the queue, then a
520    `tf.errors.OutOfRangeError` is raised just like in `dequeue_many`.
521    Otherwise the behavior is identical to `dequeue_many`.
522
523    Args:
524      n: A scalar `Tensor` containing the number of elements to dequeue.
525      name: A name for the operation (optional).
526
527    Returns:
528      The tuple of concatenated tensors that was dequeued.
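
    Example (a minimal sketch, assuming eager execution and fully-defined
    element shapes):

    ```python
    q = tf.queue.FIFOQueue(10, tf.int32, shapes=[()])
    q.enqueue_many(tf.constant([1, 2, 3]))
    q.close()
    print(q.dequeue_up_to(5).numpy())  # [1 2 3]: only 3 elements remained
    ```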
529    """
530    if name is None:
531      name = "%s_DequeueUpTo" % self._name
532
533    ret = gen_data_flow_ops.queue_dequeue_up_to_v2(
534        self._queue_ref, n=n, component_types=self._dtypes, name=name)
535
536    # NOTE(mrry): Not using a shape function because we need access to
537    # the Queue object.
538    if not context.executing_eagerly():
539      op = ret[0].op
540      for output, shape in zip(op.values(), self._shapes):
541        output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape))
542
543    return self._dequeue_return_value(ret)
544
545  def close(self, cancel_pending_enqueues=False, name=None):
546    """Closes this queue.
547
548    This operation signals that no more elements will be enqueued in
549    the given queue. Subsequent `enqueue` and `enqueue_many`
550    operations will fail. Subsequent `dequeue` and `dequeue_many`
551    operations will continue to succeed if sufficient elements remain
    in the queue. Subsequent `dequeue` and `dequeue_many` operations
553    that would otherwise block waiting for more elements (if close
554    hadn't been called) will now fail immediately.
555
556    If `cancel_pending_enqueues` is `True`, all pending requests will also
557    be canceled.
558
559    Args:
560      cancel_pending_enqueues: (Optional.) A boolean, defaulting to
561        `False` (described above).
562      name: A name for the operation (optional).
563
564    Returns:
565      The operation that closes the queue.
566    """
567    if name is None:
568      name = "%s_Close" % self._name
569    if self._queue_ref.dtype == _dtypes.resource:
570      return gen_data_flow_ops.queue_close_v2(
571          self._queue_ref,
572          cancel_pending_enqueues=cancel_pending_enqueues,
573          name=name)
574    else:
575      return gen_data_flow_ops.queue_close(
576          self._queue_ref,
577          cancel_pending_enqueues=cancel_pending_enqueues,
578          name=name)
579
580  def is_closed(self, name=None):
581    """Returns true if queue is closed.
582
583    This operation returns true if the queue is closed and false if the queue
584    is open.
585
586    Args:
587      name: A name for the operation (optional).
588
589    Returns:
      True if the queue is closed and False if the queue is open.
591    """
592    if name is None:
593      name = "%s_Is_Closed" % self._name
594    if self._queue_ref.dtype == _dtypes.resource:
595      return gen_data_flow_ops.queue_is_closed_v2(self._queue_ref, name=name)
596    else:
597      return gen_data_flow_ops.queue_is_closed_(self._queue_ref, name=name)
598
599  def size(self, name=None):
600    """Compute the number of elements in this queue.
601
602    Args:
603      name: A name for the operation (optional).
604
605    Returns:
606      A scalar tensor containing the number of elements in this queue.
607    """
608    if name is None:
609      name = "%s_Size" % self._name
610    if self._queue_ref.dtype == _dtypes.resource:
611      return gen_data_flow_ops.queue_size_v2(self._queue_ref, name=name)
612    else:
613      return gen_data_flow_ops.queue_size(self._queue_ref, name=name)
614

def _shared_name(shared_name):
616  if context.executing_eagerly():
617    return str(ops.uid())
618  return shared_name
619
620
621@tf_export(
622    "queue.RandomShuffleQueue",
623    v1=["queue.RandomShuffleQueue",
624        "io.RandomShuffleQueue", "RandomShuffleQueue"])
625@deprecation.deprecated_endpoints(
626    ["io.RandomShuffleQueue", "RandomShuffleQueue"])
627class RandomShuffleQueue(QueueBase):
628  """A queue implementation that dequeues elements in a random order.
629
630  See `tf.queue.QueueBase` for a description of the methods on
631  this class.
632  """
633
634  def __init__(self,
635               capacity,
636               min_after_dequeue,
637               dtypes,
638               shapes=None,
639               names=None,
640               seed=None,
641               shared_name=None,
642               name="random_shuffle_queue"):
643    """Create a queue that dequeues elements in a random order.
644
645    A `RandomShuffleQueue` has bounded capacity; supports multiple
646    concurrent producers and consumers; and provides exactly-once
647    delivery.
648
649    A `RandomShuffleQueue` holds a list of up to `capacity`
650    elements. Each element is a fixed-length tuple of tensors whose
651    dtypes are described by `dtypes`, and whose shapes are optionally
652    described by the `shapes` argument.
653
654    If the `shapes` argument is specified, each component of a queue
655    element must have the respective fixed shape. If it is
656    unspecified, different queue elements may have different shapes,
657    but the use of `dequeue_many` is disallowed.
658
659    The `min_after_dequeue` argument allows the caller to specify a
660    minimum number of elements that will remain in the queue after a
661    `dequeue` or `dequeue_many` operation completes, to ensure a
662    minimum level of mixing of elements. This invariant is maintained
663    by blocking those operations until sufficient elements have been
664    enqueued. The `min_after_dequeue` argument is ignored after the
665    queue has been closed.
666
667    Args:
668      capacity: An integer. The upper bound on the number of elements
669        that may be stored in this queue.
670      min_after_dequeue: An integer (described above).
671      dtypes:  A list of `DType` objects. The length of `dtypes` must equal
672        the number of tensors in each queue element.
673      shapes: (Optional.) A list of fully-defined `TensorShape` objects
674        with the same length as `dtypes`, or `None`.
      names: (Optional.) A list of strings naming the components in the queue
676        with the same length as `dtypes`, or `None`.  If specified the dequeue
677        methods return a dictionary with the names as keys.
678      seed: A Python integer. Used to create a random seed. See
679        `tf.compat.v1.set_random_seed`
680        for behavior.
681      shared_name: (Optional.) If non-empty, this queue will be shared under
682        the given name across multiple sessions.
683      name: Optional name for the queue operation.
684    """
685    dtypes = _as_type_list(dtypes)
686    shapes = _as_shape_list(shapes, dtypes)
687    names = _as_name_list(names, dtypes)
688    seed1, seed2 = random_seed.get_seed(seed)
689    if seed1 is None and seed2 is None:
690      seed1, seed2 = 0, 0
691    elif seed is None and shared_name is not None:
692      # This means that graph seed is provided but op seed is not provided.
693      # If shared_name is also provided, make seed2 depend only on the graph
694      # seed and shared_name. (seed2 from get_seed() is generally dependent on
695      # the id of the last op created.)
696      string = (str(seed1) + shared_name).encode("utf-8")
697      seed2 = int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
698    queue_ref = gen_data_flow_ops.random_shuffle_queue_v2(
699        component_types=dtypes,
700        shapes=shapes,
701        capacity=capacity,
702        min_after_dequeue=min_after_dequeue,
703        seed=seed1,
704        seed2=seed2,
705        shared_name=_shared_name(shared_name),
706        name=name)
707
708    super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref)
709
710
711@tf_export("queue.FIFOQueue", v1=["queue.FIFOQueue", "FIFOQueue"])
712@deprecation.deprecated_endpoints("FIFOQueue")
713class FIFOQueue(QueueBase):
714  """A queue implementation that dequeues elements in first-in first-out order.
715
716  See `tf.queue.QueueBase` for a description of the methods on
717  this class.
718  """
719
720  def __init__(self,
721               capacity,
722               dtypes,
723               shapes=None,
724               names=None,
725               shared_name=None,
726               name="fifo_queue"):
727    """Creates a queue that dequeues elements in a first-in first-out order.
728
729    A `FIFOQueue` has bounded capacity; supports multiple concurrent
730    producers and consumers; and provides exactly-once delivery.
731
732    A `FIFOQueue` holds a list of up to `capacity` elements. Each
733    element is a fixed-length tuple of tensors whose dtypes are
734    described by `dtypes`, and whose shapes are optionally described
735    by the `shapes` argument.
736
737    If the `shapes` argument is specified, each component of a queue
738    element must have the respective fixed shape. If it is
739    unspecified, different queue elements may have different shapes,
740    but the use of `dequeue_many` is disallowed.
741
742    Args:
743      capacity: An integer. The upper bound on the number of elements
744        that may be stored in this queue.
745      dtypes:  A list of `DType` objects. The length of `dtypes` must equal
746        the number of tensors in each queue element.
747      shapes: (Optional.) A list of fully-defined `TensorShape` objects
748        with the same length as `dtypes`, or `None`.
      names: (Optional.) A list of strings naming the components in the queue
750        with the same length as `dtypes`, or `None`.  If specified the dequeue
751        methods return a dictionary with the names as keys.
752      shared_name: (Optional.) If non-empty, this queue will be shared under
753        the given name across multiple sessions.
754      name: Optional name for the queue operation.
755    """
756    dtypes = _as_type_list(dtypes)
757    shapes = _as_shape_list(shapes, dtypes)
758    names = _as_name_list(names, dtypes)
759    with ops.init_scope(), ops.device("CPU"):
760      queue_ref = gen_data_flow_ops.fifo_queue_v2(
761          component_types=dtypes,
762          shapes=shapes,
763          capacity=capacity,
764          shared_name=_shared_name(shared_name),
765          name=name)
766
767    super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
768
769
770# TODO(allenl): If GPU-compatible queues turn out to be useful, we should
771# implement GPU kernels for EnqueueMany and DequeueMany so we can make the
772# public FIFOQueue GPU-compatible and remove this internal version.
773class GPUCompatibleFIFOQueue(QueueBase):
774  """A queue implementation that dequeues elements in first-in first-out order.
775
776  GPUCompatibleFIFOQueue is like FIFOQueue, but the queue resource may be placed
777  either on a CPU or on a GPU. It is not cross-device: enqueues and dequeues
778  will be colocated with the queue resource. GPUCompatibleFIFOQueue only
779  supports enqueue and dequeue at the moment, not enqueue_many or dequeue_many.
780
781  See `tf.queue.QueueBase` for a description of the methods on this class.
782  """
783
784  def __init__(self,
785               capacity,
786               dtypes,
787               shapes=None,
788               names=None,
789               shared_name=None,
790               name="fifo_queue"):
791    """Creates a queue that dequeues elements in a first-in first-out order.
792
793    A `FIFOQueue` has bounded capacity; supports multiple concurrent
794    producers and consumers; and provides exactly-once delivery.
795
796    A `FIFOQueue` holds a list of up to `capacity` elements. Each
797    element is a fixed-length tuple of tensors whose dtypes are
798    described by `dtypes`, and whose shapes are optionally described
799    by the `shapes` argument.
800
801    If the `shapes` argument is specified, each component of a queue
802    element must have the respective fixed shape. If it is
803    unspecified, different queue elements may have different shapes,
804    but the use of `dequeue_many` is disallowed.
805
806    Args:
807      capacity: An integer. The upper bound on the number of elements
808        that may be stored in this queue.
809      dtypes:  A list of `DType` objects. The length of `dtypes` must equal
810        the number of tensors in each queue element.
811      shapes: (Optional.) A list of fully-defined `TensorShape` objects
812        with the same length as `dtypes`, or `None`.
      names: (Optional.) A list of strings naming the components in the queue
814        with the same length as `dtypes`, or `None`.  If specified the dequeue
815        methods return a dictionary with the names as keys.
816      shared_name: (Optional.) If non-empty, this queue will be shared under
817        the given name across multiple sessions.
818      name: Optional name for the queue operation.
819    """
820    dtypes = _as_type_list(dtypes)
821    shapes = _as_shape_list(shapes, dtypes)
822    names = _as_name_list(names, dtypes)
823    with ops.init_scope():
824      queue_ref = gen_data_flow_ops.fifo_queue_v2(
825          component_types=dtypes,
826          shapes=shapes,
827          capacity=capacity,
828          shared_name=_shared_name(shared_name),
829          name=name)
830
831    super(GPUCompatibleFIFOQueue, self).__init__(
832        dtypes, shapes, names, queue_ref)
833
834  def enqueue_many(self, vals, name=None):
835    """enqueue_many is not supported on GPUCompatibleFIFOQueue."""
836    raise NotImplementedError(
837        "GPUCompatibleFIFOQueue does not support enqueue_many or dequeue_many, "
838        "only enqueue and dequeue.")
839
840  def dequeue_many(self, n, name=None):
841    """dequeue_many is not supported on GPUCompatibleFIFOQueue."""
842    raise NotImplementedError(
843        "GPUCompatibleFIFOQueue does not support enqueue_many or dequeue_many, "
844        "only enqueue and dequeue.")
845
846
847@tf_export(
848    "queue.PaddingFIFOQueue",
849    v1=["queue.PaddingFIFOQueue", "io.PaddingFIFOQueue", "PaddingFIFOQueue"])
850@deprecation.deprecated_endpoints(["io.PaddingFIFOQueue", "PaddingFIFOQueue"])
851class PaddingFIFOQueue(QueueBase):
852  """A FIFOQueue that supports batching variable-sized tensors by padding.
853
854  A `PaddingFIFOQueue` may contain components with dynamic shape, while also
855  supporting `dequeue_many`.  See the constructor for more details.
856
857  See `tf.queue.QueueBase` for a description of the methods on
858  this class.
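
  A minimal sketch of padded batching, assuming eager execution:

  ```python
  q = tf.queue.PaddingFIFOQueue(capacity=10, dtypes=tf.int32, shapes=[[None]])
  q.enqueue(tf.constant([1, 2]))
  q.enqueue(tf.constant([3]))
  print(q.dequeue_many(2).numpy())  # [[1 2] [3 0]]; the short row is zero-padded
  ```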
859  """
860
861  def __init__(self,
862               capacity,
863               dtypes,
864               shapes,
865               names=None,
866               shared_name=None,
867               name="padding_fifo_queue"):
868    """Creates a queue that dequeues elements in a first-in first-out order.
869
870    A `PaddingFIFOQueue` has bounded capacity; supports multiple concurrent
871    producers and consumers; and provides exactly-once delivery.
872
873    A `PaddingFIFOQueue` holds a list of up to `capacity` elements. Each
874    element is a fixed-length tuple of tensors whose dtypes are
875    described by `dtypes`, and whose shapes are described by the `shapes`
876    argument.
877
878    The `shapes` argument must be specified; each component of a queue
879    element must have the respective shape.  Shapes of fixed
880    rank but variable size are allowed by setting any shape dimension to None.
881    In this case, the inputs' shape may vary along the given dimension, and
882    `dequeue_many` will pad the given dimension with zeros up to the maximum
883    shape of all elements in the given batch.
884
885    Args:
886      capacity: An integer. The upper bound on the number of elements
887        that may be stored in this queue.
888      dtypes:  A list of `DType` objects. The length of `dtypes` must equal
889        the number of tensors in each queue element.
890      shapes: A list of `TensorShape` objects, with the same length as
891        `dtypes`.  Any dimension in the `TensorShape` containing value
892        `None` is dynamic and allows values to be enqueued with
893         variable size in that dimension.
      names: (Optional.) A list of strings naming the components in the queue
895        with the same length as `dtypes`, or `None`.  If specified the dequeue
896        methods return a dictionary with the names as keys.
897      shared_name: (Optional.) If non-empty, this queue will be shared under
898        the given name across multiple sessions.
899      name: Optional name for the queue operation.
900
901    Raises:
902      ValueError: If shapes is not a list of shapes, or the lengths of dtypes
903        and shapes do not match, or if names is specified and the lengths of
904        dtypes and names do not match.
905    """
906    dtypes = _as_type_list(dtypes)
907    shapes = _as_shape_list(shapes, dtypes, unknown_dim_allowed=True)
908    names = _as_name_list(names, dtypes)
909    if len(dtypes) != len(shapes):
910      raise ValueError("Shapes must be provided for all components, "
911                       f"but received {len(dtypes)} dtypes and "
912                       f"{len(shapes)} shapes.")
913    queue_ref = gen_data_flow_ops.padding_fifo_queue_v2(
914        component_types=dtypes,
915        shapes=shapes,
916        capacity=capacity,
917        shared_name=_shared_name(shared_name),
918        name=name)
919
920    super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
921
922
923@tf_export("queue.PriorityQueue",
924           v1=["queue.PriorityQueue", "io.PriorityQueue", "PriorityQueue"])
925@deprecation.deprecated_endpoints(["io.PriorityQueue", "PriorityQueue"])
926class PriorityQueue(QueueBase):
927  """A queue implementation that dequeues elements in prioritized order.
928
929  See `tf.queue.QueueBase` for a description of the methods on
930  this class.
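
  A minimal usage sketch, assuming eager execution. Note the extra leading
  int64 priority component in each enqueued and dequeued tuple:

  ```python
  q = tf.queue.PriorityQueue(capacity=10, types=[tf.string], shapes=[()])
  q.enqueue((tf.constant(2, tf.int64), "second"))
  q.enqueue((tf.constant(1, tf.int64), "first"))
  priority, value = q.dequeue()
  print(priority.numpy(), value.numpy())  # 1 b'first' (lowest priority first)
  ```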
931  """
932
933  def __init__(self,
934               capacity,
935               types,
936               shapes=None,
937               names=None,
938               shared_name=None,
939               name="priority_queue"):
940    """Creates a queue that dequeues elements in a first-in first-out order.
941
942    A `PriorityQueue` has bounded capacity; supports multiple concurrent
943    producers and consumers; and provides exactly-once delivery.
944
945    A `PriorityQueue` holds a list of up to `capacity` elements. Each
946    element is a fixed-length tuple of tensors whose dtypes are
947    described by `types`, and whose shapes are optionally described
948    by the `shapes` argument.
949
950    If the `shapes` argument is specified, each component of a queue
951    element must have the respective fixed shape. If it is
952    unspecified, different queue elements may have different shapes,
953    but the use of `dequeue_many` is disallowed.
954
955    Enqueues and Dequeues to the `PriorityQueue` must include an additional
956    tuple entry at the beginning: the `priority`.  The priority must be
957    an int64 scalar (for `enqueue`) or an int64 vector (for `enqueue_many`).
958
959    Args:
960      capacity: An integer. The upper bound on the number of elements
961        that may be stored in this queue.
962      types:  A list of `DType` objects. The length of `types` must equal
        the number of tensors in each queue element, excluding the leading
        priority component.  The first tensor in each element is the priority,
        which must be of type int64.
966      shapes: (Optional.) A list of fully-defined `TensorShape` objects,
967        with the same length as `types`, or `None`.
968      names: (Optional.) A list of strings naming the components in the queue
        with the same length as `types`, or `None`.  If specified, the dequeue
970        methods return a dictionary with the names as keys.
971      shared_name: (Optional.) If non-empty, this queue will be shared under
972        the given name across multiple sessions.
973      name: Optional name for the queue operation.
974    """
975    types = _as_type_list(types)
976    shapes = _as_shape_list(shapes, types)
977
978    queue_ref = gen_data_flow_ops.priority_queue_v2(
979        component_types=types,
980        shapes=shapes,
981        capacity=capacity,
982        shared_name=_shared_name(shared_name),
983        name=name)
984
985    priority_dtypes = [_dtypes.int64] + types
986    priority_shapes = [()] + shapes if shapes else shapes
987
988    super(PriorityQueue, self).__init__(priority_dtypes, priority_shapes, names,
989                                        queue_ref)
990
991
992# TODO(josh11b): class BatchQueue(QueueBase):
993
994
995class Barrier:
996  """Represents a key-value map that persists across graph executions."""
997
998  def __init__(self, types, shapes=None, shared_name=None, name="barrier"):
999    """Creates a barrier that persists across different graph executions.
1000
1001    A barrier represents a key-value map, where each key is a string, and
1002    each value is a tuple of tensors.
1003
1004    At runtime, the barrier contains 'complete' and 'incomplete'
1005    elements. A complete element has defined tensors for all
1006    components of its value tuple, and may be accessed using
1007    take_many. An incomplete element has some undefined components in
1008    its value tuple, and may be updated using insert_many.
1009
1010    The barrier call `take_many` outputs values in a particular order.
1011    First, it only outputs completed values.  Second, the order in which
1012    completed values are returned matches the order in which their very
1013    first component was inserted into the barrier.  So, for example, for this
1014    sequence of insertions and removals:
1015
1016      barrier = Barrier((tf.string, tf.int32), shapes=((), ()))
1017      barrier.insert_many(0, keys=["k1", "k2"], values=["a", "b"]).run()
1018      barrier.insert_many(1, keys=["k1"], values=[1]).run()
1019      barrier.insert_many(0, keys=["k3"], values=["c"]).run()
1020      barrier.insert_many(1, keys=["k3"], values=[3]).run()
1021      barrier.insert_many(1, keys=["k2"], values=[2]).run()
1022
1023      (indices, keys, values) = barrier.take_many(2)
      (indices_val, keys_val, values0_val, values1_val) = session.run(
          [indices, keys, values[0], values[1]])
1026
1027    The output will be (up to permutation of "k1" and "k2"):
1028
1029      indices_val == (-2**63, -2**63)
1030      keys_val == ("k1", "k2")
1031      values0_val == ("a", "b")
1032      values1_val == (1, 2)
1033
1034    Note the key "k2" was inserted into the barrier before "k3".  Even though
1035    "k3" was completed first, both are complete by the time
1036    take_many is called.  As a result, "k2" is prioritized and "k1" and "k2"
1037    are returned first.  "k3" remains in the barrier until the next execution
1038    of `take_many`.  Since "k1" and "k2" had their first insertions into
1039    the barrier together, their indices are the same (-2**63).  The index
1040    of "k3" will be -2**63 + 1, because it was the next new inserted key.
1041
1042    Args:
1043      types: A single dtype or a tuple of dtypes, corresponding to the
1044        dtypes of the tensor elements that comprise a value in this barrier.
1045      shapes: Optional. Constraints on the shapes of tensors in the values:
1046        a single tensor shape tuple; a tuple of tensor shape tuples
1047        for each barrier-element tuple component; or None if the shape should
1048        not be constrained.
1049      shared_name: Optional. If non-empty, this barrier will be shared under
1050        the given name across multiple sessions.
1051      name: Optional name for the barrier op.
1052
1053    Raises:
      ValueError: If one of the `shapes` indicates no elements.
1055    """
1056    self._types = _as_type_list(types)
1057
1058    if shapes is not None:
1059      shapes = _as_shape_list(shapes, self._types)
1060      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
1061      for i, shape in enumerate(self._shapes):
1062        if shape.num_elements() == 0:
1063          raise ValueError("Empty tensors are not supported, but received "
1064                           f"shape '{shape}' at index {i}")
1065    else:
1066      self._shapes = [tensor_shape.unknown_shape() for _ in self._types]
1067
1068    self._barrier_ref = gen_data_flow_ops.barrier(
1069        component_types=self._types,
1070        shapes=self._shapes,
1071        shared_name=shared_name,
1072        name=name)
1073    if context.executing_eagerly():
1074      self._name = context.context().scope_name
1075    else:
1076      self._name = self._barrier_ref.op.name.split("/")[-1]
1077
1078  @property
1079  def barrier_ref(self):
1080    """Get the underlying barrier reference."""
1081    return self._barrier_ref
1082
1083  @property
1084  def name(self):
1085    """The name of the underlying barrier."""
1086    if context.executing_eagerly():
1087      return self._name
1088    return self._barrier_ref.op.name
1089
1090  def insert_many(self, component_index, keys, values, name=None):
1091    """For each key, assigns the respective value to the specified component.
1092
1093    This operation updates each element at component_index.
1094
1095    Args:
1096      component_index: The component of the value that is being assigned.
1097      keys: A vector of keys, with length n.
1098      values: An any-dimensional tensor of values, which are associated with the
1099        respective keys. The first dimension must have length n.
1100      name: Optional name for the op.
1101
1102    Returns:
1103      The operation that performs the insertion.
1104    Raises:
      InvalidArgumentError: If inserting keys and values without elements.
1106    """
1107    if name is None:
1108      name = "%s_BarrierInsertMany" % self._name
1109    return gen_data_flow_ops.barrier_insert_many(
1110        self._barrier_ref, keys, values, component_index, name=name)
1111
1112  def take_many(self,
1113                num_elements,
1114                allow_small_batch=False,
1115                timeout=None,
1116                name=None):
1117    """Takes the given number of completed elements from this barrier.
1118
1119    This operation concatenates completed-element component tensors along
1120    the 0th dimension to make a single component tensor.
1121
    If the barrier has no completed elements, this operation will block
    until there are `num_elements` elements to take.
1124
1125    TODO(b/25743580): the semantics of `allow_small_batch` are experimental
1126    and may be extended to other cases in the future.
1127
1128    TODO(ebrevdo): If a take_many(allow_small_batch=True) is blocking
    already when the barrier is closed, it will block forever. Fix this
1130    by using asynchronous operations.
1131
1132    Args:
1133      num_elements: The number of elements to take.
      allow_small_batch: If the barrier is closed, don't block if there are
        fewer completed elements than requested, but instead return all
        available completed elements.
1137      timeout: This specifies the number of milliseconds to block
1138        before returning with DEADLINE_EXCEEDED. (This option is not
1139        supported yet.)
1140      name: A name for the operation (optional).
1141
1142    Returns:
1143      A tuple of (index, key, value_list).
1144      "index" is a int64 tensor of length num_elements containing the
1145        index of the insert_many call for which the very first component of
1146        the given element was inserted into the Barrier, starting with
1147        the value -2**63.  Note, this value is different from the
1148        index of the insert_many call for which the element was completed.
1149      "key" is a string tensor of length num_elements containing the keys.
1150      "value_list" is a tuple of tensors, each one with size num_elements
1151        in the 0th dimension for each component in the barrier's values.
1152
1153    """
1154    if name is None:
1155      name = "%s_BarrierTakeMany" % self._name
1156    ret = gen_data_flow_ops.barrier_take_many(
1157        self._barrier_ref,
1158        num_elements,
1159        self._types,
1160        allow_small_batch,
1161        timeout,
1162        name=name)
1163
1164    # NOTE(mrry): Not using a shape function because we need access to
1165    # the Barrier object.
1166    if not context.executing_eagerly():
1167      op = ret[0].op
1168      if allow_small_batch:
1169        batch_dim = None
1170      else:
1171        batch_dim = tensor_shape.Dimension(
1172            tensor_util.constant_value(op.inputs[1]))
1173      op.outputs[0].set_shape(tensor_shape.TensorShape([batch_dim]))  # indices
1174      op.outputs[1].set_shape(tensor_shape.TensorShape([batch_dim]))  # keys
1175      for output, shape in zip(op.outputs[2:], self._shapes):  # value_list
1176        output.set_shape(
1177            tensor_shape.TensorShape([batch_dim]).concatenate(shape))
1178
1179    return ret
1180
1181  def close(self, cancel_pending_enqueues=False, name=None):
1182    """Closes this barrier.
1183
1184    This operation signals that no more new key values will be inserted in the
1185    given barrier. Subsequent InsertMany operations with new keys will fail.
1186    InsertMany operations that just complement already existing keys with other
    components will continue to succeed. Subsequent TakeMany operations will
1188    continue to succeed if sufficient elements remain in the barrier. Subsequent
1189    TakeMany operations that would block will fail immediately.
1190
1191    If `cancel_pending_enqueues` is `True`, all pending requests to the
    underlying queue will also be canceled, and completing already started
    values will no longer be accepted.
1194
1195    Args:
1196      cancel_pending_enqueues: (Optional.) A boolean, defaulting to
1197        `False` (described above).
1198      name: Optional name for the op.
1199
1200    Returns:
1201      The operation that closes the barrier.
1202    """
1203    if name is None:
1204      name = "%s_BarrierClose" % self._name
1205    return gen_data_flow_ops.barrier_close(
1206        self._barrier_ref,
1207        cancel_pending_enqueues=cancel_pending_enqueues,
1208        name=name)
1209
1210  def ready_size(self, name=None):
1211    """Compute the number of complete elements in the given barrier.
1212
1213    Args:
1214      name: A name for the operation (optional).
1215
1216    Returns:
1217      A single-element tensor containing the number of complete elements in the
1218      given barrier.
1219    """
1220    if name is None:
1221      name = "%s_BarrierReadySize" % self._name
1222    return gen_data_flow_ops.barrier_ready_size(self._barrier_ref, name=name)
1223
1224  def incomplete_size(self, name=None):
1225    """Compute the number of incomplete elements in the given barrier.
1226
1227    Args:
1228      name: A name for the operation (optional).
1229
1230    Returns:
1231      A single-element tensor containing the number of incomplete elements in
1232      the given barrier.
1233    """
1234    if name is None:
1235      name = "%s_BarrierIncompleteSize" % self._name
1236    return gen_data_flow_ops.barrier_incomplete_size(
1237        self._barrier_ref, name=name)
1238
1239
1240@tf_export(v1=["ConditionalAccumulatorBase"])
1241class ConditionalAccumulatorBase:
1242  """A conditional accumulator for aggregating gradients.
1243
1244  Up-to-date gradients (i.e., time step at which gradient was computed is
1245  equal to the accumulator's time step) are added to the accumulator.
1246
1247  Extraction of the average gradient is blocked until the required number of
1248  gradients has been accumulated.
1249  """
1250
1251  def __init__(self, dtype, shape, accumulator_ref):
1252    """Creates a new ConditionalAccumulator.
1253
1254    Args:
1255      dtype: Datatype of the accumulated gradients.
1256      shape: Shape of the accumulated gradients.
      accumulator_ref: A handle to the conditional accumulator, created by
        subclasses.
1259    """
1260    self._dtype = dtype
1261    if shape is not None:
1262      self._shape = tensor_shape.TensorShape(shape)
1263    else:
1264      self._shape = tensor_shape.unknown_shape()
1265    self._accumulator_ref = accumulator_ref
1266    if context.executing_eagerly():
1267      self._name = context.context().scope_name
1268    else:
1269      self._name = self._accumulator_ref.op.name.split("/")[-1]
1270
1271  @property
1272  def accumulator_ref(self):
1273    """The underlying accumulator reference."""
1274    return self._accumulator_ref
1275
1276  @property
1277  def name(self):
1278    """The name of the underlying accumulator."""
1279    return self._name
1280
1281  @property
1282  def dtype(self):
1283    """The datatype of the gradients accumulated by this accumulator."""
1284    return self._dtype
1285
1286  def num_accumulated(self, name=None):
1287    """Number of gradients that have currently been aggregated in accumulator.
1288
1289    Args:
1290      name: Optional name for the operation.
1291
1292    Returns:
1293      Number of accumulated gradients currently in accumulator.
1294    """
1295    if name is None:
1296      name = "%s_NumAccumulated" % self._name
1297
1298    return gen_data_flow_ops.resource_accumulator_num_accumulated(
1299        self._accumulator_ref, name=name)
1300
1301  def set_global_step(self, new_global_step, name=None):
1302    """Sets the global time step of the accumulator.
1303
1304    The operation logs a warning if we attempt to set to a time step that is
1305    lower than the accumulator's own time step.
1306
1307    Args:
1308      new_global_step: Value of new time step. Can be a variable or a constant
1309      name: Optional name for the operation.
1310
1311    Returns:
1312      Operation that sets the accumulator's time step.
1313    """
1314    return gen_data_flow_ops.resource_accumulator_set_global_step(
1315        self._accumulator_ref,
1316        math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
1317        name=name)
1318
1319
1320@tf_export(v1=["ConditionalAccumulator"])
1321class ConditionalAccumulator(ConditionalAccumulatorBase):
1322  """A conditional accumulator for aggregating gradients.
1323
1324  Up-to-date gradients (i.e., time step at which gradient was computed is
1325  equal to the accumulator's time step) are added to the accumulator.
1326
1327  Extraction of the average gradient is blocked until the required number of
1328  gradients has been accumulated.
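
  A minimal usage sketch, assuming eager execution and the default `MEAN`
  reduction:

  ```python
  acc = tf.compat.v1.ConditionalAccumulator(dtype=tf.float32, shape=())
  acc.apply_grad(2.0, local_step=0)
  acc.apply_grad(4.0, local_step=0)
  print(acc.take_grad(num_required=2).numpy())  # 3.0, the mean of both gradients
  ```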
1329  """
1330
1331  def __init__(self,
1332               dtype,
1333               shape=None,
1334               shared_name=None,
1335               name="conditional_accumulator",
1336               reduction_type="MEAN"):
1337    """Creates a new ConditionalAccumulator.
1338
1339    Args:
1340      dtype: Datatype of the accumulated gradients.
1341      shape: Shape of the accumulated gradients.
1342      shared_name: Optional. If non-empty, this accumulator will be shared under
1343        the given name across multiple sessions.
1344      name: Optional name for the accumulator.
1345      reduction_type: Reduction type to use when taking the gradient.
1346    """
1347    accumulator_ref = gen_data_flow_ops.resource_conditional_accumulator(
1348        dtype=dtype,
1349        shape=shape,
1350        shared_name=shared_name,
1351        name=name,
1352        reduction_type=reduction_type)
1353    if context.executing_eagerly():
1354      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
1355          handle=accumulator_ref, handle_device=context.context().device_name)
1356
1357    super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref)
1358
1359  def apply_grad(self, grad, local_step=0, name=None):
1360    """Attempts to apply a gradient to the accumulator.
1361
1362    The attempt is silently dropped if the gradient is stale, i.e., local_step
1363    is less than the accumulator's global time step.
1364
1365    Args:
1366      grad: The gradient tensor to be applied.
1367      local_step: Time step at which the gradient was computed.
1368      name: Optional name for the operation.
1369
1370    Returns:
1371      The operation that (conditionally) applies a gradient to the accumulator.
1372
1373    Raises:
1374      ValueError: If grad is of the wrong shape
1375    """
1376    grad = ops.convert_to_tensor(grad, self._dtype)
1377    grad.get_shape().assert_is_compatible_with(self._shape)
1378    local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64)
1379
1380    return gen_data_flow_ops.resource_accumulator_apply_gradient(
1381        self._accumulator_ref, local_step=local_step, gradient=grad, name=name)
1382
1383  def take_grad(self, num_required, name=None):
1384    """Attempts to extract the average gradient from the accumulator.
1385
1386    The operation blocks until a sufficient number of gradients have been
1387    successfully applied to the accumulator.
1388
1389    Once successful, the following actions are also triggered:
1390
1391    - Counter of accumulated gradients is reset to 0.
1392    - Aggregated gradient is reset to a zero tensor.
1393    - Accumulator's internal time step is incremented by 1.
1394
1395    Args:
1396      num_required: Number of gradients that need to have been aggregated.
1397      name: Optional name for the operation.
1398
1399    Returns:
1400      A tensor holding the value of the average gradient.
1401
1402    Raises:
1403      InvalidArgumentError: If num_required < 1
1404    """
1405    out = gen_data_flow_ops.resource_accumulator_take_gradient(
1406        self._accumulator_ref, num_required, dtype=self._dtype, name=name)
1407    out.set_shape(self._shape)
1408    return out
1409
1410
1411@tf_export(
1412    v1=["sparse.SparseConditionalAccumulator", "SparseConditionalAccumulator"])
1413class SparseConditionalAccumulator(ConditionalAccumulatorBase):
1414  """A conditional accumulator for aggregating sparse gradients.
1415
1416  Sparse gradients are represented by `IndexedSlices`.
1417
1418  Up-to-date gradients (i.e., time step at which gradient was computed is
1419  equal to the accumulator's time step) are added to the accumulator.
1420
1421  Extraction of the average gradient is blocked until the required number of
1422  gradients has been accumulated.
1423
1424  Args:
1425    dtype: Datatype of the accumulated gradients.
1426    shape: Shape of the accumulated gradients.
1427    shared_name: Optional. If non-empty, this accumulator will be shared under
1428      the given name across multiple sessions.
1429    name: Optional name for the accumulator.
1430    reduction_type: Reduction type to use when taking the gradient.
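
  Example (an illustrative sketch; assumes TF1-style graph execution):

  ```python
  import tensorflow as tf

  acc = tf.compat.v1.SparseConditionalAccumulator(
      dtype=tf.float32, shape=[3, 2])
  grad = tf.IndexedSlices(
      values=tf.constant([[0.0, 1.0], [2.0, 3.0]]),
      indices=tf.constant([1, 2], dtype=tf.int64),
      dense_shape=tf.constant([3, 2], dtype=tf.int64))
  apply_op = acc.apply_indexed_slices_grad(grad)
  avg = acc.take_indexed_slices_grad(num_required=1)  # also an IndexedSlices
  ```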
1431  """
1432
1433  def __init__(self,
1434               dtype,
1435               shape=None,
1436               shared_name=None,
1437               name="sparse_conditional_accumulator",
1438               reduction_type="MEAN"):
1439    accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator(
1440        dtype=dtype,
1441        shape=shape,
1442        shared_name=shared_name,
1443        name=name,
1444        reduction_type=reduction_type)
1445    super(SparseConditionalAccumulator, self).__init__(dtype, shape,
1446                                                       accumulator_ref)
1447
1448  def apply_indexed_slices_grad(self, grad, local_step=0, name=None):
1449    """Attempts to apply a gradient to the accumulator.
1450
1451    The attempt is silently dropped if the gradient is stale, i.e., `local_step`
1452    is less than the accumulator's global time step.
1453
1454    Args:
1455      grad: The gradient `IndexedSlices` to be applied.
1456      local_step: Time step at which the gradient was computed.
1457      name: Optional name for the operation.
1458
1459    Returns:
1460      The operation that (conditionally) applies a gradient to the accumulator.
1461
1462    Raises:
1463      InvalidArgumentError: If grad is of the wrong shape
1464    """
1465    return self.apply_grad(
1466        grad_indices=grad.indices,
1467        grad_values=grad.values,
1468        grad_shape=grad.dense_shape,
1469        local_step=local_step,
1470        name=name)
1471
1472  def apply_grad(self,
1473                 grad_indices,
1474                 grad_values,
1475                 grad_shape=None,
1476                 local_step=0,
1477                 name=None):
1478    """Attempts to apply a sparse gradient to the accumulator.
1479
1480    The attempt is silently dropped if the gradient is stale, i.e., `local_step`
1481    is less than the accumulator's global time step.
1482
1483    A sparse gradient is represented by its indices, values, and a possibly
1484    empty or None shape. Indices must be a vector representing the locations of
1485    non-zero entries in the tensor. Values are the non-zero slices of the
1486    gradient, and must have the same first dimension as indices, i.e., the nnz
1487    represented by indices and values must be consistent. Shape, if not empty or
1488    None, must be consistent with the accumulator's shape (if also provided).
1489
1490    Example:
1491      The tensor [[0, 0], [0, 1], [2, 3]] can be represented as
1492        indices: [1, 2]
1493        values: [[0, 1], [2, 3]]
1494        shape: [3, 2]
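
    Expressed as a call, the representation above could be applied with
    (an illustrative sketch, assuming `accumulator` is a
    `SparseConditionalAccumulator` with a matching dtype and shape [3, 2]):

    ```python
    accumulator.apply_grad(grad_indices=[1, 2],
                           grad_values=[[0, 1], [2, 3]],
                           grad_shape=[3, 2])
    ```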
1495
1496    Args:
1497      grad_indices: Indices of the sparse gradient to be applied.
1498      grad_values: Values of the sparse gradient to be applied.
1499      grad_shape: Shape of the sparse gradient to be applied.
1500      local_step: Time step at which the gradient was computed.
1501      name: Optional name for the operation.
1502
1503    Returns:
1504      The operation that (conditionally) applies a gradient to the accumulator.
1505
1506    Raises:
1507      InvalidArgumentError: If grad is of the wrong shape
1508    """
1509    local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64)
1510    return gen_data_flow_ops.sparse_accumulator_apply_gradient(
1511        self._accumulator_ref,
1512        local_step=local_step,
1513        gradient_indices=math_ops.cast(grad_indices, _dtypes.int64),
1514        gradient_values=grad_values,
1515        gradient_shape=math_ops.cast(
1516            [] if grad_shape is None else grad_shape, _dtypes.int64),
1517        has_known_shape=(grad_shape is not None),
1518        name=name)
1519
1520  def take_grad(self, num_required, name=None):
1521    """Attempts to extract the average gradient from the accumulator.
1522
1523    The operation blocks until a sufficient number of gradients have been
1524    successfully applied to the accumulator.
1525
1526    Once successful, the following actions are also triggered:
1527    - Counter of accumulated gradients is reset to 0.
1528    - Aggregated gradient is reset to a zero tensor.
1529    - Accumulator's internal time step is incremented by 1.
1530
1531    Args:
1532      num_required: Number of gradients that need to have been aggregated.
1533      name: Optional name for the operation.
1534
1535    Returns:
1536      A tuple of indices, values, and shape representing the average gradient.
1537
1538    Raises:
1539      InvalidArgumentError: If `num_required` < 1
1540    """
1541    return gen_data_flow_ops.sparse_accumulator_take_gradient(
1542        self._accumulator_ref, num_required, dtype=self._dtype, name=name)
1543
1544  def take_indexed_slices_grad(self, num_required, name=None):
1545    """Attempts to extract the average gradient from the accumulator.
1546
1547    The operation blocks until a sufficient number of gradients have been
1548    successfully applied to the accumulator.
1549
1550    Once successful, the following actions are also triggered:
1551    - Counter of accumulated gradients is reset to 0.
1552    - Aggregated gradient is reset to a zero tensor.
1553    - Accumulator's internal time step is incremented by 1.
1554
1555    Args:
1556      num_required: Number of gradients that need to have been aggregated.
1557      name: Optional name for the operation.
1558
1559    Returns:
1560      An `IndexedSlices` holding the value of the average gradient.
1561
1562    Raises:
1563      InvalidArgumentError: If `num_required` < 1
1564    """
1565    return_val = gen_data_flow_ops.sparse_accumulator_take_gradient(
1566        self._accumulator_ref, num_required, dtype=self._dtype, name=name)
1567    return indexed_slices.IndexedSlices(
1568        indices=return_val.indices,
1569        values=return_val.values,
1570        dense_shape=return_val.shape)
1571
1572  # SparseConditionalAccumulator is not switched to resource. Use old kernels.
1573  def num_accumulated(self, name=None):
1574    """Number of gradients that have currently been aggregated in accumulator.
1575
1576    Args:
1577      name: Optional name for the operation.
1578
1579    Returns:
1580      Number of accumulated gradients currently in accumulator.
1581    """
1582    if name is None:
1583      name = "%s_NumAccumulated" % self._name
1584
1585    return gen_data_flow_ops.accumulator_num_accumulated(
1586        self._accumulator_ref, name=name)
1587
1588  def set_global_step(self, new_global_step, name=None):
1589    """Sets the global time step of the accumulator.
1590
1591    The operation logs a warning if we attempt to set the global step to a
1592    time step that is lower than the accumulator's own time step.
1593
1594    Args:
1595      new_global_step: Value of new time step. Can be a variable or a constant.
1596      name: Optional name for the operation.
1597
1598    Returns:
1599      Operation that sets the accumulator's time step.
1600    """
1601    return gen_data_flow_ops.accumulator_set_global_step(
1602        self._accumulator_ref,
1603        math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
1604        name=name)
1605
1606
1607class BaseStagingArea:
1608  """Base class for Staging Areas."""
1609  _identifier = 0
1610  _lock = threading.Lock()
1611
1612  def __init__(self,
1613               dtypes,
1614               shapes=None,
1615               names=None,
1616               shared_name=None,
1617               capacity=0,
1618               memory_limit=0):
1619    if shared_name is None:
1620      self._name = (
1621          ops.get_default_graph().unique_name(self.__class__.__name__))
1622    elif isinstance(shared_name, str):
1623      self._name = shared_name
1624    else:
1625      raise ValueError(f"shared_name must be a string, got {shared_name}")
1626
1627    self._dtypes = dtypes
1628
1629    if shapes is not None:
1630      if len(shapes) != len(dtypes):
1631        raise ValueError("StagingArea shapes must be the same length as dtypes")
1632      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
1633    else:
1634      self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
1635
1636    if names is not None:
1637      if len(names) != len(dtypes):
1638        raise ValueError("StagingArea names must be the same length as dtypes")
1639      self._names = names
1640    else:
1641      self._names = None
1642
1643    self._capacity = capacity
1644    self._memory_limit = memory_limit
1645
1646    # all get and put ops must colocate with this op
1647    with ops.name_scope("%s_root" % self._name):
1648      self._coloc_op = control_flow_ops.no_op()
1649
1650  @property
1651  def name(self):
1652    """The name of the staging area."""
1653    return self._name
1654
1655  @property
1656  def dtypes(self):
1657    """The list of dtypes for each component of a staging area element."""
1658    return self._dtypes
1659
1660  @property
1661  def shapes(self):
1662    """The list of shapes for each component of a staging area element."""
1663    return self._shapes
1664
1665  @property
1666  def names(self):
1667    """The list of names for each component of a staging area element."""
1668    return self._names
1669
1670  @property
1671  def capacity(self):
1672    """The maximum number of elements of this staging area."""
1673    return self._capacity
1674
1675  @property
1676  def memory_limit(self):
1677    """The maximum number of bytes of this staging area."""
1678    return self._memory_limit
1679
1680  def _check_put_dtypes(self, vals, indices=None):
1681    """Validate and convert `vals` to a list of `Tensor`s.
1682
1683    The `vals` argument can be a Tensor, a list or tuple of tensors, or a
1684    dictionary with tensor values.
1685
1686    If `vals` is a list, then the appropriate indices associated with the
1687    values must be provided.
1688
1689    If it is a dictionary, the staging area must have been constructed with a
1690    `names` attribute and the dictionary keys must match the staging area names.
1691    `indices` will be inferred from the dictionary keys.
1692    If the staging area was constructed with a `names` attribute, `vals` must
1693    be a dictionary.
1694
1695    Checks that the dtype and shape of each value matches that
1696    of the staging area.
1697
1698    Args:
1699      vals: A tensor, a list or tuple of tensors, or a dictionary.
      indices: The indices associated with the tensors in `vals` when `vals`
        is a list or tuple; inferred from the dictionary keys when `vals` is a
        dictionary.
1700
1701    Returns:
1702      A (tensors, indices) tuple where `tensors` is a list of `Tensor` objects
1703      and `indices` is a list of indices associated with the tensors.
1704
1705    Raises:
1706      ValueError: If `vals` or `indices` is invalid.
1707    """
1708    if isinstance(vals, dict):
1709      if not self._names:
1710        raise ValueError(
1711            "Staging areas must have names to enqueue a dictionary")
1712      if not set(vals.keys()).issubset(self._names):
1713        raise ValueError("Keys in dictionary to put do not match names "
1714                         f"of staging area. Dictionary: {sorted(vals.keys())} "
1715                         f"Queue: {sorted(self._names)}")
1716      # The order of values in `self._names` indicates the order in which the
1717      # tensors in the dictionary `vals` must be listed.
1718      vals, indices, _ = zip(*[(vals[k], i, k)
1719                               for i, k in enumerate(self._names)
1720                               if k in vals])
1721    else:
1722      if self._names:
1723        raise ValueError("You must enqueue a dictionary in a staging area "
1724                         "with names")
1725
1726      if indices is None:
1727        raise ValueError("Indices must be supplied when inserting a list "
1728                         "of tensors")
1729
1730      if len(indices) != len(vals):
1731        raise ValueError(f"Number of indices {len(indices)} doesn't match "
1732                         f"number of values {len(vals)}")
1733
1734      if not isinstance(vals, (list, tuple)):
1735        vals = [vals]
1736        indices = [0]
1737
1738    # Sanity check number of values
1739    if not len(vals) <= len(self._dtypes):
1740      raise ValueError(f"Unexpected number of inputs {len(vals)} vs "
1741                       f"{len(self._dtypes)}")
1742
1743    tensors = []
1744
1745    for val, i in zip(vals, indices):
1746      dtype, shape = self._dtypes[i], self._shapes[i]
1747      # Check dtype
1748      if val.dtype != dtype:
1749        raise ValueError(f"Datatypes do not match. "
1750                         f"Received val.dtype {str(val.dtype)} and "
1751                         f"dtype {str(dtype)}")
1752      # Check shape
1753      val.get_shape().assert_is_compatible_with(shape)
1754
1755      tensors.append(
1756          ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))
1757
1758    return tensors, indices
1759
1760  def _create_device_transfers(self, tensors):
1761    """Encode inter-device transfers if the current device
1762    is not the same as the Staging Area's device.
1763    """
1764
1765    if not isinstance(tensors, (tuple, list)):
1766      tensors = [tensors]
1767
1768    curr_device_scope = control_flow_ops.no_op().device
1769
1770    if curr_device_scope != self._coloc_op.device:
1771      tensors = [array_ops.identity(t) for t in tensors]
1772
1773    return tensors
1774
1775  def _get_return_value(self, tensors, indices):
1776    """Return the value to return from a get op.
1777
1778    If the staging area has names, return a dictionary with the
1779    names as keys.  Otherwise return either a single tensor
1780    or a list of tensors depending on the length of `tensors`.
1781
1782    Args:
1783      tensors: List of tensors from the get op.
1784      indices: Indices of associated names and shapes
1785
1786    Returns:
1787      A single tensor, a list of tensors, or a dictionary
1788      of tensors.
1789    """
1790
1791    tensors = self._create_device_transfers(tensors)
1792
1793    # Sets shape
1794    for output, i in zip(tensors, indices):
1795      output.set_shape(self._shapes[i])
1796
1797    if self._names:
1798      # The returned values in `tensors` are in the same order as
1799      # the names in `self._names`.
1800      return {self._names[i]: t for t, i in zip(tensors, indices)}
1801    return tensors
1802
1803  def _scope_vals(self, vals):
1804    """Return a list of values to pass to `name_scope()`.
1805
1806    Args:
1807      vals: A tensor, a list or tuple of tensors, or a dictionary.
1808
1809    Returns:
1810      The values in vals as a list.
1811    """
1812    if isinstance(vals, (list, tuple)):
1813      return vals
1814    elif isinstance(vals, dict):
1815      return vals.values()
1816    else:
1817      return [vals]
1818
1819
1820class StagingArea(BaseStagingArea):
1821  """Class for staging inputs. No ordering guarantees.
1822
1823  A `StagingArea` is a TensorFlow data structure that stores tensors across
1824  multiple steps, and exposes operations that can put and get tensors.
1825
1826  Each `StagingArea` element is a tuple of one or more tensors, where each
1827  tuple component has a static dtype, and may have a static shape.
1828
1829  The capacity of a `StagingArea` may be bounded or unbounded.
1830  It supports multiple concurrent producers and consumers; and
1831  provides exactly-once delivery.
1832
1833  Each element of a `StagingArea` is a fixed-length tuple of tensors whose
1834  dtypes are described by `dtypes`, and whose shapes are optionally described
1835  by the `shapes` argument.
1836
1837  If the `shapes` argument is specified, each component of a staging area
1838  element must have the respective fixed shape. If it is
1839  unspecified, different elements may have different shapes.
1840
1841  It can be configured with a capacity in which case
1842  put(values) will block until space becomes available.
1843
1844  Similarly, it can be configured with a memory limit which
1845  will block put(values) until space is available.
1846  This is mostly useful for limiting the number of tensors on
1847  devices such as GPUs.
1848
1849  All get() and peek() commands block if the requested data
1850  is not present in the Staging Area.
1851
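  Example (an illustrative sketch; `StagingArea` is not exported as a public
  symbol, so it is accessed through this module, and TF1-style graph execution
  is assumed):

  ```python
  import tensorflow as tf
  from tensorflow.python.ops import data_flow_ops

  area = data_flow_ops.StagingArea(dtypes=[tf.float32], shapes=[[2]])
  put_op = area.put([tf.constant([1.0, 2.0])])
  get_op = area.get()
  ```
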
1852  """
1853
1854  def __init__(self,
1855               dtypes,
1856               shapes=None,
1857               names=None,
1858               shared_name=None,
1859               capacity=0,
1860               memory_limit=0):
1861    """Constructs a staging area object.
1862
1863    The two optional lists, `shapes` and `names`, must be of the same length
1864    as `dtypes` if provided.  The values at a given index `i` indicate the
1865    shape and name to use for the corresponding queue component in `dtypes`.
1866
1867    The device scope at the time of object creation determines where the
1868    storage for the `StagingArea` will reside.  Calls to `put` will incur a copy
1869    to this memory space, if necessary.  Tensors returned by `get` will be
1870    placed according to the device scope when `get` is called.
1871
1872    Args:
1873      dtypes:  A list of types.  The length of dtypes must equal the number
1874        of tensors in each element.
1875      shapes: (Optional.) Constraints on the shapes of tensors in an element.
1876        A list of shape tuples or None. This list is the same length
1877        as dtypes.  If the shape of any tensor in the element is constrained,
1878        all must be; shapes can be None if the shapes should not be constrained.
1879      names: (Optional.) If provided, the `get()` and
1880        `put()` methods will use dictionaries with these names as keys.
1881        Must be None or a list or tuple of the same length as `dtypes`.
1882      shared_name: (Optional.) A name to be used for the shared object. By
1883        passing the same name to two different python objects they will share
1884        the underlying staging area. Must be a string.
1885      capacity: (Optional.) Maximum number of elements.
1886        An integer. If zero, the Staging Area is unbounded
1887      memory_limit: (Optional.) Maximum number of bytes of all tensors
1888        in the Staging Area.
1889        An integer. If zero, the Staging Area is unbounded
1890
1891    Raises:
1892      ValueError: If one of the arguments is invalid.
1893    """
1894
1895    super(StagingArea, self).__init__(dtypes, shapes, names, shared_name,
1896                                      capacity, memory_limit)
1897
1898  def put(self, values, name=None):
1899    """Create an op that places a value into the staging area.
1900
1901    This operation will block if the `StagingArea` has reached
1902    its capacity.
1903
1904    Args:
1905      values: A single tensor, a list or tuple of tensors, or a dictionary with
1906        tensor values. The number of elements must match the length of the
1907        list provided to the dtypes argument when creating the StagingArea.
1908      name: A name for the operation (optional).
1909
1910    Returns:
1911        The created op.
1912
1913    Raises:
1914      ValueError: If the number or type of inputs don't match the staging area.
1915    """
1916    with ops.name_scope(name, "%s_put" % self._name,
1917                        self._scope_vals(values)) as scope:
1918
1919      if not isinstance(values, (list, tuple, dict)):
1920        values = [values]
1921
1922      # Hard-code indices for this staging area
1923      indices = list(range(len(values)))
1924      vals, _ = self._check_put_dtypes(values, indices)
1925
1926      with ops.colocate_with(self._coloc_op):
1927        op = gen_data_flow_ops.stage(
1928            values=vals,
1929            shared_name=self._name,
1930            name=scope,
1931            capacity=self._capacity,
1932            memory_limit=self._memory_limit)
1933
1934      return op
1935
1936  def __internal_get(self, get_fn, name):
1937    with ops.colocate_with(self._coloc_op):
1938      ret = get_fn()
1939
1940    indices = list(range(len(self._dtypes)))  # Hard coded
1941    return self._get_return_value(ret, indices)
1942
1943  def get(self, name=None):
1944    """Gets one element from this staging area.
1945
1946    If the staging area is empty when this operation executes, it will block
1947    until there is an element to dequeue.
1948
1949    Note that unlike other ops that can block, like the queue Dequeue
1950    operations, this can stop other work from happening.  To avoid this, the
1951    intended use is for this to be called only when there will be an element
1952    already available.  One method for doing this in a training loop would be to
1953    run a `put()` call during a warmup session.run call, and then call both
1954    `get()` and `put()` in each subsequent step.
1955
1956    The placement of the returned tensor will be determined by the current
1957    device scope when this function is called.
1958
1959    Args:
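    For example, the warm-up pattern described above might look like the
    following (an illustrative sketch; `area`, `next_batch`, `build_train_step`
    and `num_steps` are hypothetical names assumed to be defined elsewhere):

    ```python
    put_op = area.put([next_batch])
    train_op = build_train_step(area.get())

    with tf.compat.v1.Session() as sess:
      sess.run(put_op)                  # warm-up: stage the first element
      for _ in range(num_steps):
        sess.run([train_op, put_op])    # get one element, stage the next
    ```
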
1960      name: A name for the operation (optional).
1961
1962    Returns:
1963      The tuple of tensors that was gotten.
1964    """
1965    if name is None:
1966      name = "%s_get" % self._name
1967
1968    # pylint: disable=bad-continuation
1969    fn = lambda: gen_data_flow_ops.unstage(dtypes=self._dtypes,
1970                    shared_name=self._name, name=name,
1971                    capacity=self._capacity,
1972                    memory_limit=self._memory_limit)
1973    # pylint: enable=bad-continuation
1974
1975    return self.__internal_get(fn, name)
1976
1977  def peek(self, index, name=None):
1978    """Peeks at an element in the staging area.
1979
1980    If the staging area is too small to contain the element at
1981    the specified index, it will block until enough elements
1982    are inserted to complete the operation.
1983
1984    The placement of the returned tensor will be determined by
1985    the current device scope when this function is called.
1986
1987    Args:
1988      index: The index of the tensor within the staging area
1989              to look up.
1990      name: A name for the operation (optional).
1991
1992    Returns:
1993      The tuple of tensors that was gotten.
1994    """
1995    if name is None:
1996      name = "%s_peek" % self._name
1997
1998    # pylint: disable=bad-continuation
1999    fn = lambda: gen_data_flow_ops.stage_peek(index,
2000                    dtypes=self._dtypes, shared_name=self._name,
2001                    name=name, capacity=self._capacity,
2002                    memory_limit=self._memory_limit)
2003    # pylint: enable=bad-continuation
2004
2005    return self.__internal_get(fn, name)
2006
2007  def size(self, name=None):
2008    """Returns the number of elements in the staging area.
2009
2010    Args:
2011        name: A name for the operation (optional)
2012
2013    Returns:
2014        The created op
2015    """
2016    if name is None:
2017      name = "%s_size" % self._name
2018
2019    return gen_data_flow_ops.stage_size(
2020        name=name,
2021        shared_name=self._name,
2022        dtypes=self._dtypes,
2023        capacity=self._capacity,
2024        memory_limit=self._memory_limit)
2025
2026  def clear(self, name=None):
2027    """Clears the staging area.
2028
2029    Args:
2030        name: A name for the operation (optional)
2031
2032    Returns:
2033        The created op
2034    """
2035    if name is None:
2036      name = "%s_clear" % self._name
2037
2038    return gen_data_flow_ops.stage_clear(
2039        name=name,
2040        shared_name=self._name,
2041        dtypes=self._dtypes,
2042        capacity=self._capacity,
2043        memory_limit=self._memory_limit)
2044
2045
2046class MapStagingArea(BaseStagingArea):
2047  """A `MapStagingArea` is a TensorFlow data structure that stores tensors
2048  across multiple steps, and exposes operations that can put and get tensors.
2049
2050  Each `MapStagingArea` element is a (key, value) pair.
2051  Only int64 keys are supported; other types should be
2052  hashed to produce a key.
2053  Values are a tuple of one or more tensors.
2054  Each tuple component has a static dtype,
2055  and may have a static shape.
2056
2057  The capacity of a `MapStagingArea` may be bounded or unbounded.
2058  It supports multiple concurrent producers and consumers; and
2059  provides exactly-once delivery.
2060
2061  Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors
2062  whose dtypes are described by `dtypes`, and whose shapes are optionally
2063  described by the `shapes` argument.
2065
2066  If the `shapes` argument is specified, each component of a staging area
2067  element must have the respective fixed shape. If it is
2068  unspecified, different elements may have different shapes.
2069
2070  It behaves like an associative container with support for:
2071
2072   - put(key, values)
2073   - peek(key)         like dict.get(key)
2074   - get(key)          like dict.pop(key)
2075   - get(key=None)     like dict.popitem()
2076   - size()
2077   - clear()
2078
2079  If ordered, a tree structure ordered by key will be used and
2080  get(key=None) will remove (key, value) pairs in increasing key order.
2081  Otherwise a hashtable is used and get(key=None) removes an arbitrary pair.
2082
2083  It can be configured with a capacity in which case
2084  put(key, values) will block until space becomes available.
2085
2086  Similarly, it can be configured with a memory limit which
2087  will block put(key, values) until space is available.
2088  This is mostly useful for limiting the number of tensors on
2089  devices such as GPUs.
2090
2091  All get() and peek() commands block if the requested
2092  (key, value) pair is not present in the staging area.
2093
2094  Partial puts are supported and will be placed in an incomplete
2095  map until such time as all values associated with the key have
2096  been inserted. Once completed, this (key, value) pair will be
2097  inserted into the map. Data in the incomplete map
2098  counts towards the memory limit, but not towards capacity limit.
2099
2100  Partial gets from the map are also supported.
2101  This removes the partially requested tensors from the entry,
2102  but the entry is only removed from the map once all tensors
2103  associated with it are removed.
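
  Example (an illustrative sketch; `MapStagingArea` is not exported as a
  public symbol, so it is accessed through this module, and TF1-style graph
  execution is assumed):

  ```python
  import tensorflow as tf
  from tensorflow.python.ops import data_flow_ops

  m = data_flow_ops.MapStagingArea(dtypes=[tf.float32, tf.int32],
                                   names=["x", "y"])
  key = tf.constant(0, dtype=tf.int64)
  put_op = m.put(key, {"x": tf.constant([1.0]), "y": tf.constant([2])})
  got_key, vals = m.get(key)  # vals is a dict with keys "x" and "y"
  ```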
2104  """
2105
2106  def __init__(self,
2107               dtypes,
2108               shapes=None,
2109               names=None,
2110               shared_name=None,
2111               ordered=False,
2112               capacity=0,
2113               memory_limit=0):
2114    """Args:
2115
2116      dtypes:  A list of types.  The length of dtypes must equal the number
2117        of tensors in each element.
2118      capacity: (Optional.) Maximum number of elements.
2119        An integer. If zero, the Staging Area is unbounded
2120      memory_limit: (Optional.) Maximum number of bytes of all tensors
2121        in the Staging Area (excluding keys).
2122        An integer. If zero, the Staging Area is unbounded
2123      ordered: (Optional.) If True the underlying data structure
2124        is a tree ordered on key. Otherwise assume a hashtable.
2125      shapes: (Optional.) Constraints on the shapes of tensors in an element.
2126        A list of shape tuples or None. This list is the same length
2127        as dtypes.  If the shape of any tensor in the element is constrained,
2128        all must be; shapes can be None if the shapes should not be constrained.
2129      names: (Optional.) If provided, the `get()` and
2130        `put()` methods will use dictionaries with these names as keys.
2131        Must be None or a list or tuple of the same length as `dtypes`.
2132      shared_name: (Optional.) A name to be used for the shared object. By
2133        passing the same name to two different python objects they will share
2134        the underlying staging area. Must be a string.
2135
2136    Raises:
2137      ValueError: If one of the arguments is invalid.
2138
2139    """
2140
2141    super(MapStagingArea, self).__init__(dtypes, shapes, names, shared_name,
2142                                         capacity, memory_limit)
2143
2144    # Defer to different methods depending if the map is ordered
2145    self._ordered = ordered
2146
2147    if ordered:
2148      self._put_fn = gen_data_flow_ops.ordered_map_stage
2149      self._pop_fn = gen_data_flow_ops.ordered_map_unstage
2150      self._popitem_fn = gen_data_flow_ops.ordered_map_unstage_no_key
2151      self._peek_fn = gen_data_flow_ops.ordered_map_peek
2152      self._size_fn = gen_data_flow_ops.ordered_map_size
2153      self._incomplete_size_fn = gen_data_flow_ops.ordered_map_incomplete_size
2154      self._clear_fn = gen_data_flow_ops.ordered_map_clear
2155    else:
2156      self._put_fn = gen_data_flow_ops.map_stage
2157      self._pop_fn = gen_data_flow_ops.map_unstage
2158      self._popitem_fn = gen_data_flow_ops.map_unstage_no_key
2159      self._peek_fn = gen_data_flow_ops.map_peek
2160      self._size_fn = gen_data_flow_ops.map_size
2161      self._incomplete_size_fn = gen_data_flow_ops.map_incomplete_size
2162      self._clear_fn = gen_data_flow_ops.map_clear
2163
2164  def put(self, key, vals, indices=None, name=None):
2165    """Create an op that stores the (key, vals) pair in the staging area.
2166
2167    Incomplete puts are possible, preferably using a dictionary for vals,
2168    as the appropriate dtypes and shapes can be inferred from the
2169    dictionary keys. If vals is a list or tuple, indices must
2170    also be specified so that the op knows at which element positions
2171    to perform the insert.
2172
2173    This operation will block if the capacity or memory limit of this
2174    container is reached.
2175
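    For example, a partial put can stage the components of one entry in two
    steps (an illustrative sketch; `m` is assumed to be a `MapStagingArea`
    created with dtypes [tf.float32, tf.int32] and names ["x", "y"]):

    ```python
    k = tf.constant(7, dtype=tf.int64)
    put_x = m.put(k, {"x": tf.constant([1.0])})  # entry is still incomplete
    put_y = m.put(k, {"y": tf.constant([2])})    # entry becomes complete
    ```
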
2176    Args:
2177        key: Key associated with the data
2178        vals: Tensor (or a dict/tuple of Tensors) to place
2179                into the staging area.
2180        indices: Indices for `vals`; required if `vals` is a tuple or list.
2181        name: A name for the operation (optional)
2182
2183    Returns:
2184        The created op
2185
2186    Raises:
2187        ValueError: If the number or type of inputs don't match the staging
2188        area.
2189    """
2190
2191    with ops.name_scope(name, "%s_put" % self._name,
2192                        self._scope_vals(vals)) as scope:
2193
2194      vals, indices = self._check_put_dtypes(vals, indices)
2195
2196      with ops.colocate_with(self._coloc_op):
2197        op = self._put_fn(
2198            key,
2199            indices,
2200            vals,
2201            dtypes=self._dtypes,
2202            shared_name=self._name,
2203            name=scope,
2204            capacity=self._capacity,
2205            memory_limit=self._memory_limit)
2206    return op
2207
2208  def _get_indices_and_dtypes(self, indices=None):
2209    if indices is None:
2210      indices = list(range(len(self._dtypes)))
2211
2212    if not isinstance(indices, (tuple, list)):
2213      raise TypeError(f"Invalid indices type {type(indices)}")
2214
2215    if len(indices) == 0:
2216      raise ValueError("Empty indices")
2217
2218    if all(isinstance(i, str) for i in indices):
2219      if self._names is None:
2220        raise ValueError(f"String indices provided {indices}, but "
2221                         "this Staging Area was not created with names.")
2222
2223      try:
2224        indices = [self._names.index(n) for n in indices]
2225      except ValueError:
2226        raise ValueError(f"Named index not in "
2227                         f"Staging Area names {self._names}")
2228    elif all(isinstance(i, int) for i in indices):
2229      pass
2230    else:
2231      raise TypeError(f"Mixed types in indices {indices}. "
2232                      "May only be str or int")
2233
2234    dtypes = [self._dtypes[i] for i in indices]
2235
2236    return indices, dtypes
2237
2238  def peek(self, key, indices=None, name=None):
2239    """Peeks at staging area data associated with the key.
2240
2241    If the key is not in the staging area, it will block
2242    until the associated (key, value) is inserted.
2243
2244    Args:
2245        key: Key associated with the required data
2246        indices: Partial list of tensors to retrieve (optional).
2247                A list of integer or string indices.
2248                String indices are only valid if the Staging Area
2249                has names associated with it.
2250        name: A name for the operation (optional)
2251
2252    Returns:
2253        The tensors (or dictionary of tensors) associated with the key.
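
    For example (an illustrative sketch; `m` is assumed to be a
    `MapStagingArea` created with names ["x", "y"] and `key` an int64 scalar):

    ```python
    x_only = m.peek(key, indices=["x"])  # retrieves only the "x" component
    ```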
2254    """
2255
2256    if name is None:
2257      name = "%s_pop" % self._name
2258
2259    indices, dtypes = self._get_indices_and_dtypes(indices)
2260
2261    with ops.colocate_with(self._coloc_op):
2262      result = self._peek_fn(
2263          key,
2264          shared_name=self._name,
2265          indices=indices,
2266          dtypes=dtypes,
2267          name=name,
2268          capacity=self._capacity,
2269          memory_limit=self._memory_limit)
2270
2271    return self._get_return_value(result, indices)
2272
2273  def get(self, key=None, indices=None, name=None):
2274    """If the key is provided, the associated (key, value) is returned from the staging area.
2275
2276    If the key is not in the staging area, this method will block until
2277    the associated (key, value) is inserted.
2278    If no key is provided and the staging area is ordered,
2279    the (key, value) with the smallest key will be returned.
2280    Otherwise, a random (key, value) will be returned.
2281
2282    If the staging area is empty when this operation executes,
2283    it will block until there is an element to dequeue.
2284
2285    Args:
2286        key: Key associated with the required data (Optional)
2287        indices: Partial list of tensors to retrieve (optional).
2288                A list of integer or string indices.
2289                String indices are only valid if the Staging Area
2290                has names associated with it.
2291        name: A name for the operation (optional)
2292
2293    Returns:
2294        The (key, value) pair retrieved from the staging area.
2295    """
2296    if key is None:
2297      return self._popitem(indices=indices, name=name)
2298    else:
2299      return self._pop(key, indices=indices, name=name)
2300
2301  def _pop(self, key, indices=None, name=None):
2302    """Removes and returns the (key, value) pair associated with the key.
2303
2304    If the key is not in the staging area, this method will block until
2305    the associated (key, value) is inserted.
2306    Args:
2307        key: Key associated with the required data
2308        indices: Partial list of tensors to retrieve (optional).
2309                A list of integer or string indices.
2310                String indices are only valid if the Staging Area
2311                has names associated with it.
2312        name: A name for the operation (optional)
2313
2314    Returns:
2315        The (key, value) pair removed from the staging area.
2316    """
2317    if name is None:
2318      name = "%s_get" % self._name
2319
2320    indices, dtypes = self._get_indices_and_dtypes(indices)
2321
2322    with ops.colocate_with(self._coloc_op):
2323      result = self._pop_fn(
2324          key,
2325          shared_name=self._name,
2326          indices=indices,
2327          dtypes=dtypes,
2328          name=name,
2329          capacity=self._capacity,
2330          memory_limit=self._memory_limit)
2331
2332    return key, self._get_return_value(result, indices)
2333
2334  def _popitem(self, indices=None, name=None):
2335    """If the staging area is ordered, the (key, value) with the smallest key will be returned.
2336
2337    Otherwise, a random (key, value) will be returned.
2338    If the staging area is empty when this operation executes,
2339    it will block until there is an element to dequeue.
2340
2341    Args:
2343        indices: Partial list of tensors to retrieve (optional).
2344                A list of integer or string indices.
2345                String indices are only valid if the Staging Area
2346                has names associated with it.
2347        name: A name for the operation (optional)
2348
2349    Returns:
2350        The (key, value) pair removed from the staging area.
2351    """
2352    if name is None:
2353      name = "%s_get_nokey" % self._name
2354
2355    indices, dtypes = self._get_indices_and_dtypes(indices)
2356
2357    with ops.colocate_with(self._coloc_op):
2358      key, result = self._popitem_fn(
2359          shared_name=self._name,
2360          indices=indices,
2361          dtypes=dtypes,
2362          name=name,
2363          capacity=self._capacity,
2364          memory_limit=self._memory_limit)
2365
2366    # Separate keys and results out from
2367    # underlying namedtuple
2368    key = self._create_device_transfers(key)[0]
2369    result = self._get_return_value(result, indices)
2370
2371    return key, result
2372
2373  def size(self, name=None):
2374    """Returns the number of elements in the staging area.
2375
2376    Args:
2377        name: A name for the operation (optional)
2378
2379    Returns:
2380        The created op
2381    """
2382    if name is None:
2383      name = "%s_size" % self._name
2384
2385    return self._size_fn(
2386        shared_name=self._name,
2387        name=name,
2388        dtypes=self._dtypes,
2389        capacity=self._capacity,
2390        memory_limit=self._memory_limit)
2391
2392  def incomplete_size(self, name=None):
2393    """Returns the number of incomplete elements in the staging area.
2394
2395    Args:
2396        name: A name for the operation (optional)
2397
2398    Returns:
2399        The created op
2400    """
2401    if name is None:
2402      name = "%s_incomplete_size" % self._name
2403
2404    return self._incomplete_size_fn(
2405        shared_name=self._name,
2406        name=name,
2407        dtypes=self._dtypes,
2408        capacity=self._capacity,
2409        memory_limit=self._memory_limit)
2410
2411  def clear(self, name=None):
2412    """Clears the staging area.
2413
2414    Args:
2415        name: A name for the operation (optional)
2416
2417    Returns:
2418        The created op
2419    """
2420    if name is None:
2421      name = "%s_clear" % self._name
2422
2423    return self._clear_fn(
2424        shared_name=self._name,
2425        name=name,
2426        dtypes=self._dtypes,
2427        capacity=self._capacity,
2428        memory_limit=self._memory_limit)
2429
2430
2431class RecordInput:
2432  """RecordInput asynchronously reads and randomly yields TFRecords.
2433
2434  A RecordInput Op will continuously read a batch of records asynchronously
2435  into a buffer of some fixed capacity. It can also asynchronously yield
2436  random records from this buffer.
2437
2438  It will not start yielding until at least `buffer_size / 2` elements have been
2439  placed into the buffer so that sufficient randomization can take place.
2440
2441  The order the files are read will be shifted each epoch by `shift_ratio` so
2442  that the data is presented in a different order every epoch.
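
  Example (an illustrative sketch; `RecordInput` is not exported as a public
  symbol, so it is accessed through this module, and the file pattern below is
  hypothetical):

  ```python
  import tensorflow as tf
  from tensorflow.python.ops import data_flow_ops

  record_input = data_flow_ops.RecordInput(
      file_pattern="/tmp/data/train-*.tfrecord",
      batch_size=32, buffer_size=10000, parallelism=4)
  records = record_input.get_yield_op()  # a batch of serialized records
  ```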
2443  """
2444
2445  def __init__(self,
2446               file_pattern,
2447               batch_size=1,
2448               buffer_size=1,
2449               parallelism=1,
2450               shift_ratio=0,
2451               seed=0,
2452               name=None,
2453               batches=None,
2454               compression_type=None):
2455    """Constructs a RecordInput Op.
2456
2457    Args:
2458      file_pattern: File path to the dataset, possibly containing wildcards.
2459        All matching files will be iterated over each epoch.
2460      batch_size: How many records to return at a time.
2461      buffer_size: The maximum number of records the buffer will contain.
2462      parallelism: How many reader threads to use for reading from files.
2463      shift_ratio: What percentage of the total number of files to move the
2464        start file forward by each epoch.
2465      seed: Specify the random number seed used by the generator that
2466        randomizes records.
2467      name: Optional name for the operation.
2468      batches: None by default, creating a single batch op. Otherwise specifies
2469        how many batches to create, which are returned as a list when
2470        `get_yield_op()` is called. An example use case is to split processing
2471        between devices on one computer.
2472      compression_type: The type of compression for the file. Currently ZLIB and
2473        GZIP are supported. Defaults to none.
2474
2475    Raises:
2476      ValueError: If one of the arguments is invalid.
2477    """
2478    self._batch_size = batch_size
2479    if batches is not None:
2480      self._batch_size *= batches
2481    self._batches = batches
2482    self._file_pattern = file_pattern
2483    self._buffer_size = buffer_size
2484    self._parallelism = parallelism
2485    self._shift_ratio = shift_ratio
2486    self._seed = seed
2487    self._name = name
2488    self._compression_type = python_io.TFRecordCompressionType.NONE
2489    if compression_type is not None:
2490      self._compression_type = compression_type
2491
2492  def get_yield_op(self):
2493    """Adds a node that yields a group of records every time it is executed.
2494    If the `batches` parameter is not None, it yields a list of
2495    record batches with the specified `batch_size`.
2496    """
2497    compression_type = python_io.TFRecordOptions.get_compression_type_string(
2498        python_io.TFRecordOptions(self._compression_type))
2499    records = gen_data_flow_ops.record_input(
2500        file_pattern=self._file_pattern,
2501        file_buffer_size=self._buffer_size,
2502        file_parallelism=self._parallelism,
2503        file_shuffle_shift_ratio=self._shift_ratio,
2504        batch_size=self._batch_size,
2505        file_random_seed=self._seed,
2506        compression_type=compression_type,
2507        name=self._name)
2508    if self._batches is None:
2509      return records
2510    else:
2511      with ops.name_scope(self._name):
2512        batch_list = [[] for _ in range(self._batches)]
2513        records = array_ops.split(records, self._batch_size, 0)
2514        for index, protobuf in enumerate(records):
2515          batch_index = index % self._batches
2516          batch_list[batch_index].append(array_ops.reshape(protobuf, []))
2517        return batch_list
2518