xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tpu/tpu_feed.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ===================================================================
15
16"""Helper library for handling infeed between hosts and TPUs.
17"""
18
19import itertools
20
21import numpy as np
22
23from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.ops import array_ops
28from tensorflow.python.tpu import tpu_name_util
29from tensorflow.python.tpu import tpu_sharding
30from tensorflow.python.tpu.ops import tpu_ops
31
32from tensorflow.python.util import nest
33
34
def partition_or_replicate_on_host(tensor, dims):
  """Splits `tensor` into per-partition pieces, or replicates it endlessly.

  The ops created by this function are placed on the host side.

  Args:
    tensor: The input tensor to partition or replicate. Its static shape is
      read through `tensor.shape.as_list()`.
    dims: A list of integers giving the number of partitions along each axis
      of `tensor`, or None to replicate the tensor instead.

  Returns:
    `itertools.repeat(tensor)` when `dims` is None; otherwise a flat list of
    the partitioned tensors.
  """
  if dims is None:
    return itertools.repeat(tensor)

  partition_counts = np.array(dims)
  full_shape = np.array(tensor.shape.as_list())
  floor_sizes, leftovers = np.divmod(full_shape, partition_counts)

  pieces = [tensor]
  per_axis = zip(floor_sizes, leftovers, partition_counts, full_shape)
  for axis, (floor_size, leftover, count, axis_size) in enumerate(per_axis):
    if count <= 1:
      # This axis is not partitioned.
      continue
    if leftover > 0:
      # The axis cannot be split evenly. XLA assumes a greedy partitioning
      # that hands out ceil(size / count) elements per partition first and
      # gives whatever is left to the following partition. E.g. for a tensor
      # of shape (5, 14) and dims (2, 4), axis 0 is split into sizes [3, 2]
      # and axis 1 into sizes [4, 4, 4, 2].
      chunk = floor_size + 1
      full_chunks, tail = np.divmod(axis_size, chunk)
      split_spec = [chunk] * full_chunks + [tail]
      # Pad with empty partitions so there is one entry per partition; a
      # non-positive repeat count appends nothing.
      split_spec += [0] * (count - len(split_spec))
      pieces = [
          array_ops.split(piece, num_or_size_splits=split_spec, axis=axis)
          for piece in pieces
      ]
    else:
      pieces = [
          array_ops.split(piece, int(count), axis=axis) for piece in pieces
      ]
    # Each split returned a list; flatten back to a single list of tensors.
    pieces = nest.flatten(pieces)
  return pieces
79
80
def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
  """Attaches the matching XLA sharding attribute to one dequeued tensor.

  Every sharding call is made with `assign_tuple_sharding=True`, so the
  sharding attribute of the dequeued tensor will be a tuple.

  Args:
    tensor: The dequeued tensor on TPU.
    dims: A list of integers describing how the tensor is partitioned, or
      None if the tensor is replicated.

  Returns:
    The same tensor with the xla_sharding attribute.
  """
  if dims is None:
    return xla_sharding.replicate(tensor, assign_tuple_sharding=True)

  num_tiles = np.prod(dims)
  if num_tiles == 1:
    # A single tile: the tensor is assigned to device 0.
    return xla_sharding.assign_device(tensor, 0, assign_tuple_sharding=True)

  # Tiles are numbered 0..num_tiles-1 and laid out in the shape of `dims`.
  return xla_sharding.tile(
      tensor=tensor,
      tile_assignment=np.arange(num_tiles).reshape(dims),
      assign_tuple_sharding=True)
103
104
def tag_sharding_attribute_for_dequeued_tensors(dequeues, dims):
  """Attaches the matching XLA sharding attribute to each dequeued tensor.

  Args:
    dequeues: A list of dequeued tensors on TPU.
    dims: A structure of partition descriptions matching `dequeues`; each
      entry is a list of integers describing how that tensor is partitioned.

  Returns:
    The same dequeues with appropriate xla_sharding attribute.
  """
  # `dims` must be at least as deep as `dequeues` so each tensor is paired
  # with its own partition description.
  nest.assert_shallow_structure(dequeues, dims)
  tag_fn = _tag_sharding_attribute_for_dequeued_tensor
  return nest.map_structure_up_to(dequeues, tag_fn, dequeues, dims)
118
119
120class InfeedQueue(object):
121  """A helper object to build a device infeed queue.
122
123  The InfeedQueue builds the host-side and device-side Ops to enqueue and
124  dequeue elements, respectively, and ensures that their types and
125  shapes match.
126  """
127
128  def __init__(self,
129               number_of_tuple_elements=None,
130               tuple_types=None,
131               tuple_shapes=None,
132               shard_dimensions=None,
133               number_of_partitions=None,
134               name=None):
135    """Creates a new InfeedQueue with the given configuration.
136
137    The configuration need not be fully specified at creation since it
138    can be modified subsequently by methods that set the values
139    explicitly or infer them from the shapes of inputs.
140
141    Args:
142      number_of_tuple_elements: the number of Tensors fed atomically through the
143        queue, must be present unless it can be inferred from other arguments.
144      tuple_types: if not None, a list of types of the elements of the queue.
145      tuple_shapes: if not None, a list of shapes of the elements of the queue.
146      shard_dimensions: if not None, a list of dimensions on which the
147        elements of the queue should be sharded during automatic
148        parallelization.
149      number_of_partitions: if > 1, the infeed dequeue shape will contain
150        the full shape that includes all partitions and add corresponding XLA
151        annotation on the infeed dequeue op. In this case, the infeed is still
152        data parallel that feeds per-core batch size to each core while the XLA
153        computation may be partitioned. As XLA requires infeed dequeue shape to
154        be per-replica shape, thus we need number_of_partitions here to
155        calculate the per-replica unpartitioned shape.
156      name: the name of the queue.
157
158    Raises:
159      ValueError: if number_of_tuple_elements <= 0; or
160        number_of_tuple_arguments, tuple_types, tuple_shapes, and
161        shard_dimensions are all None; or the length of tuple_types,
162        tuple_shapes, or shard_dimensions is not equal to
163        number_of_tuple_elements; or any element of shard_dimensions
164        can't be converted to a Dimension.
165      TypeError: if any element of tuple_types or tuple_shapes can't
166        be converted to a dtype or TensorShape, respectively.
167    """
168    self._frozen = False
169    self._generated_enqueue_ops = False
170    self._generated_dequeue_op = False
171    self._name = "InfeedQueue" if name is None else name
172    if number_of_partitions is None:
173      self._number_of_partitions = 1
174    else:
175      self._number_of_partitions = number_of_partitions
176    if number_of_tuple_elements is None:
177      if tuple_types is not None:
178        number_of_tuple_elements = len(tuple_types)
179      elif tuple_shapes is not None:
180        number_of_tuple_elements = len(tuple_shapes)
181      elif shard_dimensions is not None:
182        number_of_tuple_elements = len(shard_dimensions)
183      else:
184        raise ValueError(
185            "number of tuple elements cannot be inferred from InfeedQueue "
186            "constructor")
187    if number_of_tuple_elements <= 0:
188      raise ValueError(f"number_of_tuple_elements {number_of_tuple_elements} "
189                       "must be > 0")
190    # Make an empty sharding policy for each tuple element.
191    self._sharding_policies = [
192        tpu_sharding.ShardingPolicy() for _ in range(number_of_tuple_elements)
193    ]
194    if tuple_types is not None:
195      self.set_tuple_types(tuple_types)
196    else:
197      self._tuple_types = None
198    if tuple_shapes is not None:
199      self.set_tuple_shapes(tuple_shapes)
200    else:
201      self._tuple_shapes = None
202    if shard_dimensions is not None:
203      self.set_shard_dimensions(shard_dimensions)
204    self._validate()
205
206  def _validate(self):
207    """Checks that the configuration is self-consistent.
208
209    Raises:
210      ValueError: if the shapes and sharding policies don't match.
211    """
212    if self.tuple_shapes is not None:
213      for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes):
214        # Raise an error if the policy is incompatible with the shape.
215        _ = policy.get_sharded_shape(shape)
216
217  @property
218  def number_of_tuple_elements(self):
219    """Returns the number of InfeedQueue tuple elements."""
220    return len(self._sharding_policies)
221
222  @property
223  def tuple_types(self):
224    """Returns the types of the InfeedQueue tuple elements."""
225    return self._tuple_types
226
227  def set_tuple_types(self, tuple_types):
228    """Sets the type of each element of the queue.
229
230    tuple_types must be a list of length
231    self.number_of_tuple_elements, and each element must be
232    convertible to a dtype.
233
234    Args:
235      tuple_types: the types of each queue element.
236
237    Raises:
238      ValueError: if tuple_types is not of length
239        self.number_of_tuple_elements.
240      TypeError: if an element of tuple_types cannot be converted to a
241        dtype.
242    """
243    if len(tuple_types) != self.number_of_tuple_elements:
244      raise ValueError(
245          f"tuple_types is {str(tuple_types)}, but must be a list of "
246          f"length {self.number_of_tuple_elements}"
247      )
248    if self._frozen:
249      for (frozen, updated) in zip(self._tuple_types, tuple_types):
250        if frozen != updated:
251          raise ValueError(
252              "Trying to update InfeedQueue with frozen configuration with an "
253              f"incompatible type. Frozen types are {str(self._tuple_types)}, "
254              f"updated types are {str(tuple_types)}")
255    else:
256      try:
257        self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types]
258      except (TypeError) as e:
259        raise TypeError(
260            f"tuple_types is {str(tuple_types)}, but must be a list of "
261            f"elements each convertible to dtype: got error {str(e)}") from e
262
263  @property
264  def tuple_shapes(self):
265    """Returns the shapes of the InfeedQueue tuple elements."""
266    return self._tuple_shapes
267
268  def set_tuple_shapes(self, tuple_shapes):
269    """Sets the shape of each element of the queue.
270
271    tuple_shapes must be a list of length
272    self.number_of_tuple_elements, and each element must be
273    convertible to a TensorShape.
274
275    Args:
276      tuple_shapes: the shapes of each queue element.
277
278    Raises:
279      ValueError: if tuple_shapes is not of length
280        self.number_of_tuple_elements.
281      TypeError: if an element of tuple_shapes cannot be converted to
282        a TensorShape.
283    """
284    if len(tuple_shapes) != self.number_of_tuple_elements:
285      raise ValueError(
286          f"tuple_shapes is {str(tuple_shapes)}, but must be a list of "
287          f"length {self.number_of_tuple_elements}"
288      )
289    try:
290      tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes]
291    except (ValueError, TypeError) as e:
292      raise TypeError(
293          f"tuple_shapes is {str(tuple_shapes)}, but must be a list of "
294          "elements each convertible to TensorShape: got error "
295          f"{str(e)}") from e
296    if self._frozen:
297      for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes):
298        if frozen != updated:
299          raise ValueError(
300              "Trying to update InfeedQueue with frozen configuration with an "
301              "incompatible shape. Frozen shapes are "
302              f"{str(self._tuple_shapes)}, updated shapes are "
303              f"{str(tuple_shapes)}")
304
305    else:
306      self._tuple_shapes = tuple_shapes
307    self._validate()
308
309  @property
310  def sharding_policies(self):
311    """Returns the sharding policies of the InfeedQueue tuple elements."""
312    return self._sharding_policies
313
314  @property
315  def shard_dimensions(self):
316    """Gets the shard dimension of each tuple element.
317
318    Returns:
319      A list of length number_of_tuple_elements, where each list entry
320      is the shard dimension of that tuple element or None if the
321      shard dimension has not been set.
322    """
323    # The number of shards is always the same for all the policies.
324    return [policy.shard_dimension for policy in self._sharding_policies]
325
326  def set_shard_dimensions(self, shard_dimensions):
327    """Sets the shard_dimension of each element of the queue.
328
329    shard_dimensions must be a list of length
330    self.number_of_tuple_elements, and each element must be
331    convertible to a Dimension compatible with self.tuple_shapes.
332
333    Args:
334      shard_dimensions: the dimensions of each queue element.
335
336    Raises:
337      ValueError: if shard_dimensions is not of length
338        self.number_of_tuple_elements; or an element of
339        shard_dimensions cannot be converted to a Dimension; or an
340        element of shard_dimensions is a Dimension that is out of
341        range for the corresponding tuple element shape.
342    """
343    if len(shard_dimensions) != self.number_of_tuple_elements:
344      raise ValueError(f"shard_dimensions is {str(shard_dimensions)}, but must "
345                       f"be a list of length {self.number_of_tuple_elements}")
346    for (policy, dimension) in zip(self._sharding_policies, shard_dimensions):
347      policy.set_shard_dimension(dimension)
348    self._validate()
349
350  @property
351  def number_of_shards(self):
352    """Gets the number of shards to use for the InfeedQueue.
353
354    Returns:
355      Number of shards or None if the number of shards has not been set.
356    """
357    # The number of shards is always the same for all the policies.
358    return self._sharding_policies[0].number_of_shards
359
360  def set_number_of_shards(self, number_of_shards):
361    """Sets the number of shards to use for the InfeedQueue.
362
363    Args:
364      number_of_shards: number of ways to shard the InfeedQueue.
365
366    Raises:
367      ValueError: if number_of_shards is not > 0; or the policies have
368        been frozen and number_of_shards was already set to something
369        else.
370    """
371    for policy in self._sharding_policies:
372      policy.set_number_of_shards(number_of_shards)
373      policy.set_number_of_partitions(self._number_of_partitions)
374    self._validate()
375
376  def set_configuration_from_input_tensors(self, input_tensors):
377    """Sets the shapes and types of the queue tuple elements.
378
379    input_tensors is a list of Tensors whose types and shapes are used
380    to set the queue configuration.
381
382    Args:
383      input_tensors: list of Tensors of the same types and shapes as
384        the desired queue Tuple.
385
386    Raises:
387      ValueError: if input_tensors is not a list of length
388        self.number_of_tuple_elements
389    """
390    if len(input_tensors) != self.number_of_tuple_elements:
391      raise ValueError(f"input_tensors is {str(input_tensors)}, but should be "
392                       f"a list of {self.number_of_tuple_elements} Tensors")
393    self.set_tuple_shapes([t.shape for t in input_tensors])
394    self.set_tuple_types([t.dtype for t in input_tensors])
395
396  def set_configuration_from_sharded_input_tensors(self, input_tensors):
397    """Sets the shapes and types of the queue tuple elements.
398
399    input_tensors is a list of lists of Tensors whose types and shapes are used
400    to set the queue configuration. The length of the outer list is the number
401    of shards required, and each inner list is the tuple of Tensors to use to
402    determine the types and shapes of the corresponding shard. This method
403    depends on the shard dimension, and calling it freezes the shard policy.
404
405    Args:
406      input_tensors: list of lists of Tensors. The outer list length corresponds
407        to the desired number of shards, and each inner list is the size
408        and shape of the desired configuration of the corresponding shard.
409
410    Raises:
411      ValueError: if any inner list is not a list of length
412        self.number_of_tuple_elements; or the inner lists do not combine to
413        form a consistent unsharded shape.
414      TypeError: if the types of the Tensors in the inner lists do not match.
415    """
416    if not self._frozen:
417      # Unset the tuple shapes in case the configuration becomes
418      # transiently inconsistent.
419      self._tuple_shapes = None
420    number_of_shards = len(input_tensors)
421    self.set_number_of_shards(number_of_shards)
422    for t in input_tensors:
423      if len(t) != self.number_of_tuple_elements:
424        raise ValueError(
425            f"input_tensors is {str(input_tensors)} but must be a list of "
426            "lists, where each inner list has length "
427            f"number_of_tuple_elements={self.number_of_tuple_elements}")
428    # Transpose the inputs to make a list of shard shapes for each tuple
429    # element.
430    sharded_shapes = [[t[i].shape
431                       for t in input_tensors]
432                      for i in range(self.number_of_tuple_elements)]
433    # For each tuple, get the unsharded shape using that tuple's policy.
434    unsharded_shapes = [
435        policy.get_unsharded_shape(s)
436        for (policy, s) in zip(self._sharding_policies, sharded_shapes)
437    ]
438    self.set_tuple_shapes(unsharded_shapes)
439    for i in range(1, self.number_of_shards):
440      for (t1, t2) in zip(input_tensors[0], input_tensors[i]):
441        if t1.dtype != t2.dtype:
442          raise TypeError(
443              "types of the tuple elements of input_tensors "
444              f"{str(input_tensors)} are not consistent")
445    self.set_tuple_types([t.dtype for t in input_tensors[0]])
446
447  def freeze(self):
448    """Freezes the InfeedQueue so it can no longer be modified.
449
450    The configuration is implicitly frozen before any host-side or
451    device-side Ops are generated. The configuration cannot be frozen
452    until the types and shapes of the tuple elements have been set.
453
454    Raises:
455      ValueError: if the types or shapes of the tuple elements have not been
456      set.
457    """
458    self._frozen = True
459    if self._tuple_types is None:
460      raise ValueError(
461          "Can't freeze an InfeedQueue without setting all tuple types.")
462    if self._tuple_shapes is None:
463      raise ValueError(
464          "Can't freeze an InfeedQueue without setting all tuple shapes.")
465    for shape in self._tuple_shapes:
466      if shape.dims is None:
467        raise ValueError(
468            "Can't freeze an InfeedQueue without setting all tuple shapes.")
469    for policy in self._sharding_policies:
470      policy.freeze()
471    self._validate()
472
473  def generate_dequeue_op(self, tpu_device=0):
474    """Generates the device-side Op to dequeue a tuple from the queue.
475
476    Implicitly freezes the queue configuration if it is not already
477    frozen, which will raise errors if the shapes and types have not
478    been fully specified.
479
480    Args:
481      tpu_device: The TPU device ordinal where the infeed instruction should be
482        placed. If None, no explicit placement will be performed, and it is up
483        to the user to call this API from within a proper TPU device scope.
484        The XLA code will fail if the TPU dequeue instruction is not bound to
485        any device.
486
487    Returns:
488      A list of Outputs corresponding to a shard of infeed dequeued
489      into XLA, suitable for use within a replicated block.
490
491    Raises:
492      ValueError: if the types or shapes of the tuple elements have not been
493      set; or if a dequeue op has already been generated.
494    """
495    self.freeze()
496    if self._generated_dequeue_op and not ops.inside_function():
497      raise ValueError("Can't generate two dequeue Ops from the same queue")
498    self._generated_dequeue_op = True
499    full_name = "%s/dequeue" % self._name
500    sharded_shapes = [
501        policy.get_unpartitioned_shape(policy.get_sharded_shape(shape))
502        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
503    ]
504    if tpu_device is not None:
505      with ops.device(tpu_name_util.core(tpu_device)):
506        dequeue_op = tpu_ops.infeed_dequeue_tuple(
507            dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
508    else:
509      dequeue_op = tpu_ops.infeed_dequeue_tuple(
510          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
511    if self._number_of_partitions <= 1:
512      return dequeue_op
513    partitions = [
514        policy.get_unpartitioned_shape([1] * shape.ndims).as_list()
515        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
516    ]
517    return tag_sharding_attribute_for_dequeued_tensors(dequeue_op, partitions)
518
519  def _generate_enqueue_op(self,
520                           inputs,
521                           name_prefix,
522                           index,
523                           device=None,
524                           tpu_ordinal=-1):
525    """Generate a host-side Op to enqueue a tuple to the queue.
526
527    If device is None the inputs are all required to have the same
528    device specification, and the enqueue Op is colocated with
529    inputs[0]. Otherwise the enqueue Op is placed on 'device'.
530
531    Args:
532      inputs: a list of Tensors with the types and shapes of the tuple elements.
533      name_prefix: the base name for the Op.
534      index: the shard index, used to uniquify the Op name.
535      device: device to place the Op on, or None if it should be
536        colocated with the inputs.
537      tpu_ordinal: ordinal of the TPU device on the host to use for
538      infeed if device is a CPU device. Should be set to -1 if device
539      is a TPU device.
540
541    Returns:
542      An Op corresponding to a shard of infeed enqueued at the host,
543      suitable for use within a replicated block.
544
545    Raises:
546      ValueError: if device is None and inputs do not all have the
547        same device specification.
548    """
549    full_name = "%s/%d" % (name_prefix, index)
550    shapes = [t.shape for t in inputs]
551    if device is None:
552      devices = [t.device for t in inputs]
553      for i in range(1, self.number_of_tuple_elements):
554        if devices[0] != devices[i]:
555          raise ValueError(
556              f"input devices for shard {index} are {str(devices)}, but should "
557              "all be the same")
558      with ops.colocate_with(inputs[0]):
559        return tpu_ops.infeed_enqueue_tuple(
560            inputs=inputs,
561            shapes=shapes,
562            name=full_name,
563            device_ordinal=tpu_ordinal)
564    else:
565      with ops.device(device):
566        return tpu_ops.infeed_enqueue_tuple(
567            inputs=inputs,
568            shapes=shapes,
569            name=full_name,
570            device_ordinal=tpu_ordinal)
571
572  def generate_enqueue_ops(self,
573                           sharded_inputs,
574                           tpu_ordinal_function=None,
575                           placement_function=None):
576    """Generates the host-side Ops to enqueue the shards of a tuple.
577
578    sharded_inputs is a list, one for each shard, of lists of
579    Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed
580    shard i of the queue. Returns the host-side Ops that must be run to
581    enqueue the sharded tuple. The Op for shard i is colocated with the inputs
582    for shard i.
583
584    Implicitly freezes the queue configuration if it is not already
585    frozen. If the configuration has already been frozen, and is not
586    compatible with the types and shapes of sharded_inputs, an error
587    will be raised.
588
589    Args:
590      sharded_inputs: a list of lists of Tensors. The length of the outer list
591        determines the number of shards. Each inner list indicates the types
592        and shapes of the tuples in the corresponding shard.
593      tpu_ordinal_function: if not None, a function that takes the
594        shard index as input and returns the ordinal of the TPU device
595        the shard's infeed should be placed on. tpu_ordinal_function must be
596        set if the inputs are placed on CPU devices.
597      placement_function: if not None, a function that takes the shard index as
598        input and returns the host device where the enqueue op should be placed
599        on.
600
601    Returns:
602      A list of host-side Ops, one for each shard, that when executed together
603      will enqueue a full-size element of infeed.
604
605    Raises:
606      ValueError: if the queue configuration has previously been frozen and the
607        shapes of the elements of sharded_inputs are not compatible with the
608        frozen configuration; or if the shapes of the elements of sharded_inputs
609        don't form a consistent unsharded tuple; or if the elements of a tuple
610        have different device constraints.
611      TypeError: if the queue configuration has previously been frozen and the
612        types of the elements of sharded_inputs are not compatible with the
613        frozen configuration; or if the types of the elements of sharded_inputs
614        don't form a consistent unsharded tuple.
615    """
616    self.set_configuration_from_sharded_input_tensors(sharded_inputs)
617    self.freeze()
618    if self._generated_enqueue_ops and not ops.inside_function():
619      raise ValueError("Can't generate two enqueue Ops from the same queue")
620    self._generated_enqueue_ops = True
621    if tpu_ordinal_function is None:
622      tpu_ordinal_function = lambda index: -1
623    name_prefix = "%s/enqueue" % self._name
624    return [
625        self._generate_enqueue_op(
626            shard,
627            name_prefix,
628            index,
629            tpu_ordinal=tpu_ordinal_function(index),
630            device=placement_function(index) if placement_function else None)
631        for (shard, index) in zip(sharded_inputs, range(self.number_of_shards))
632    ]
633
634  # TODO(misard) Generalize this to the case of systems that don't
635  # have 8 devices per host, and figure out what to do with
636  # model-parallelism.
637  def _default_placement_function(self, index):
638    return "/task:%d/device:CPU:0" % (index / 8)
639
640  def _default_ordinal_function(self, index):
641    return index % 8
642
643  # TODO(b/36470756) remove this from tutorials once we have a better story
644  # for automatic placement of input pipelines.
645  def split_inputs_and_generate_enqueue_ops(self,
646                                            inputs,
647                                            device_assignment=None,
648                                            placement_function=None,
649                                            tpu_ordinal_function=None):
650    """POORLY-PERFORMING ON MULTI-HOST SYSTEMS.
651
652    Generates the host-side Ops to enqueue a tuple.
653
654    This method performs poorly because it takes an entire input on a single
655    host, splits it, and distributes it to all of the cores. It is present only
656    to simplify tutorial examples.
657
658    inputs is a list of Tensors to use to feed the queue. Each input is split
659    into self.number_of_shards shards. Returns an Op for each shard to enqueue
660    the shard. The Op for shard i is placed on device placement_function(i).
661
662    Implicitly freezes the queue configuration if it is not already
663    frozen. If the configuration has already been frozen, and is not
664    compatible with the types and shapes of inputs, an error
665    will be raised.
666
667    Args:
668      inputs: a list of Tensors which indicates the types and shapes of the
669        queue tuple.
670     device_assignment: if not `None`, a TPU `DeviceAssignment`. If
671        device_assignment is not `None`, but `placement_function` and
672        `ordinal_function` are None, then `device_assignment` will be used to
673        place infeeds on the first k TPU shards, where k is the number of shards
674        in the queue. If all three are `None`, then default placement and
675        ordinal functions are used.
676      placement_function: if not None, a function that takes the shard
677        index as input and returns a device string indicating which
678        device the shard's infeed should be placed on. If placement_function
679        and tpu_ordinal_function are None, inputs are sharded round-robin
680        across the devices in the system.
681      tpu_ordinal_function: if not None, a function that takes the
682        shard index as input and returns the ordinal of the TPU device
683        the shard's infeed should be placed on. If placement_function
684        and tpu_ordinal_function are None, inputs are sharded round-robin
685        across the devices in the system.
686
687    Returns:
688      A list of host-side Ops, one for each shard, that when executed together
689      will enqueue a full-size element of infeed.
690
691    Raises:
692      ValueError: if the queue configuration has previously been frozen and the
693        shapes of the elements of inputs are not compatible with the frozen
694        configuration.
695      TypeError: if the queue configuration has previously been frozen and the
696        types of the elements of inputs are not compatible with the frozen
697        configuration.
698    """
699    if device_assignment is None:
700      if placement_function is None:
701        placement_function = self._default_placement_function
702      if tpu_ordinal_function is None:
703        tpu_ordinal_function = self._default_ordinal_function
704    else:
705
706      def _placement_function_from_map(index):
707        return device_assignment.host_device(replica=index)
708
709      def _ordinal_function_from_map(index):
710        return device_assignment.tpu_ordinal(replica=index)
711
712      if placement_function is None:
713        placement_function = _placement_function_from_map
714      if tpu_ordinal_function is None:
715        tpu_ordinal_function = _ordinal_function_from_map
716    self.set_configuration_from_input_tensors(inputs)
717    self.freeze()
718    if self._generated_enqueue_ops and not ops.inside_function():
719      raise ValueError("Can't generate two enqueue Ops from the same queue")
720    self._generated_enqueue_ops = True
721    split_name_prefix = "%s/split" % self._name
722    if self.number_of_shards == 1:
723      transposed_sharded_inputs = [[inp] for inp in inputs]
724    else:
725
726      def split_fn(inp, num_shards, axis, name):
727        with ops.colocate_with(inp):
728          return array_ops.split(inp, num_shards, axis=axis, name=name)
729
730      transposed_sharded_inputs = [
731          split_fn(
732              inp,
733              self.number_of_shards,
734              axis=policy.shard_dimension,
735              name="%s/%d" % (split_name_prefix, index))
736          for (inp, policy, index) in zip(inputs, self._sharding_policies,
737                                          range(self.number_of_tuple_elements))
738      ]
739    sharded_inputs = [[shard[i]
740                       for shard in transposed_sharded_inputs]
741                      for i in range(self.number_of_shards)]
742    name_prefix = "%s/enqueue" % self._name
743    return [
744        self._generate_enqueue_op(
745            shard,
746            name_prefix,
747            index,
748            device=placement_function(index),
749            tpu_ordinal=tpu_ordinal_function(index))
750        for (shard, index) in zip(sharded_inputs, range(self.number_of_shards))
751    ]
752
753
class _PartitionedInfeedQueue(InfeedQueue):
  """A helper object to build a device infeed queue with input partition.

  For every replica, each input tensor is partitioned (or replicated) on the
  host according to `input_partition_dims`, and the resulting pieces are
  enqueued, one per logical core, to the infeed of that replica's cores.

  Args:
    number_of_tuple_elements: the number of Tensors fed atomically through the
      queue, must be present unless it can be inferred from other arguments.
    device_assignment: A TPU `DeviceAssignment` which is used to place all the
      partitions to different TPU infeed queues.
    host_id: The id of the host machine.
    input_partition_dims: A nested list/tuple of integers. Each inner
      list/tuple describes how to partition the corresponding input tensor.
    tuple_types: If not None, a list of types of the elements of the queue.
    tuple_shapes: If not None, a list of shapes of the elements of the queue.
    name: The name of the queue.
  """

  def __init__(self,
               number_of_tuple_elements,
               device_assignment,
               host_id,
               input_partition_dims=None,
               tuple_types=None,
               tuple_shapes=None,
               name=None):
    """Creates the queue; see the class docstring for the arguments."""
    super(_PartitionedInfeedQueue, self).__init__(
        number_of_tuple_elements=number_of_tuple_elements,
        tuple_types=tuple_types,
        # Forward the caller's shapes. They used to be replaced by a
        # hard-coded `None`, which silently discarded a documented argument.
        tuple_shapes=tuple_shapes,
        shard_dimensions=None,
        name="PartitionedInfeedQueue" if name is None else name)
    self._input_partition_dims = input_partition_dims
    self._host_id = host_id
    self._device_assignment = device_assignment

  def generate_dequeue_op(self, tpu_device=0):
    """Generate TPU dequeue ops.

    Freezes the queue configuration as a side effect.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed.

    Returns:
      A list of Outputs corresponding to a partition of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    # Outside of a function, a queue may only produce a single dequeue op.
    if self._generated_dequeue_op and not ops.inside_function():
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True
    full_name = "%s/dequeue" % self._name
    # The dequeue op is given the per-shard shape of every tuple element, as
    # computed by that element's sharding policy.
    sharded_shapes = [
        policy.get_sharded_shape(shape)
        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
    ]
    with ops.device(tpu_name_util.core(tpu_device)):
      values = tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)
    return tag_sharding_attribute_for_dequeued_tensors(
        values, self._input_partition_dims)

  def generate_enqueue_ops(self, sharded_inputs):
    """Generates the host-side Ops to enqueue the partitioned inputs.

    sharded_inputs is a list, one for each replica, of lists of
    Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed
    replica i.
    sharded_inputs[i][j] is partitioned by self._input_partition_dims[j].

    For example, if sharded_inputs[i][j] is a 2-D Tensor:
    [[A, B, C, D],
     [E ,F, G, H]]
    self._input_partition_dims[j] is [2, 4].

    sharded_inputs[i][j] will be partitioned and flattened into:
    [A, B, C, D, E, F, G, H] and fed into the logical core ids:
    [0, 1, 2, 3, 4, 5, 6, 7] respectively.

    Args:
      sharded_inputs: a list of lists of Tensors. The length of the
        outer list determines the number of shards. Each inner list indicates
        the types and shapes of the tuples in the corresponding shard.

    Returns:
      A list of host-side Ops that when executed together will enqueue a
      full-size element of infeed. One op is created per (replica, logical
      core) pair that receives at least one input piece.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the shapes of the elements of sharded_inputs
        don't form a consistent unsharded tuple; or if the elements of a tuple
        have different device constraints; or if the partition dims are
        missing, do not match the number of tuple elements, or are otherwise
        invalid.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the types of the elements of sharded_inputs
        don't form a consistent unsharded tuple.
    """
    self.set_configuration_from_sharded_input_tensors(sharded_inputs)
    number_of_tuple_elements = len(sharded_inputs[0])

    # Validate with explicit exceptions instead of `assert`: assertions are
    # stripped under `python -O`, and `len(None)` would otherwise surface as
    # an opaque TypeError when no partition dims were given.
    if self._input_partition_dims is None:
      raise ValueError(
          "input_partition_dims must be set to generate enqueue ops for a "
          "partitioned infeed queue.")
    if len(self._input_partition_dims) != number_of_tuple_elements:
      raise ValueError(
          "The number of input partition dims ({}) must match the number of "
          "tuple elements ({}).".format(
              len(self._input_partition_dims), number_of_tuple_elements))

    enqueue_ops = []

    for replica_index, replica_inputs in enumerate(sharded_inputs):
      inputs_part_dims_flat = nest.flatten_up_to(replica_inputs,
                                                 self._input_partition_dims)
      # One iterator per tuple element; each `next()` yields the piece of that
      # element destined for the next logical core.
      inputs_parted_iters = [
          iter(self._check_dims_and_partition_or_replicate_on_host(x, dims))
          for x, dims in zip(replica_inputs, inputs_part_dims_flat)
      ]

      # Find the replica_id of the host's logical core 0.
      # The self._host_id is guaranteed to contain the logical core 0,
      # even when num_cores_per_replica > num_cores_per_host -- the caller
      # makes sure that this host_id is one that receives data (calls
      # input_fn).
      replica_id = self._device_assignment.lookup_replicas(
          task_id=self._host_id, logical_core=0)[replica_index]
      for logical_core in range(self._device_assignment.num_cores_per_replica):
        # Places different partitions to different logic cores.
        # Since there can be multiple hosts per replica, we need to find
        # the actual host (device) of this logical core.
        device = self._device_assignment.host_device(
            replica=replica_id, logical_core=logical_core)

        with ops.device(device):
          ordinal = self._device_assignment.tpu_ordinal(
              replica=replica_id, logical_core=logical_core)
          infeed_inputs = []
          for parted_iter in inputs_parted_iters:
            # An exhausted iterator means this tuple element has no piece for
            # the current logical core, so it is left out of the tuple.
            input_for_device = next(parted_iter, None)
            if input_for_device is not None:
              infeed_inputs.append(input_for_device)

          # Skip cores for which no tuple element produced a piece.
          if infeed_inputs:
            enqueue_ops.append(
                tpu_ops.infeed_enqueue_tuple(
                    inputs=infeed_inputs,
                    shapes=[x.shape for x in infeed_inputs],
                    name="enqueue/replica_{0}/input_{1}".format(
                        replica_index, logical_core),
                    device_ordinal=ordinal))
    return enqueue_ops

  def _check_input_partition_dims(self, tensor, dims):
    """Checks that input partition dims are valid for the `Tensor`.

    No checks beyond `dims >= 1` are made when `dims` is None or describes a
    single partition (product of 1).

    Args:
      tensor: Input tensor for partitioning.
      dims: A list of integer describes how to partition the input tensor.

    Raises:
      ValueError: If any entry of dims is smaller than 1, if the number of
        partitions (dims.prod()) doesn't match num_cores_per_replica, if dims
        doesn't have one entry per tensor dimension, or if the tensor shape
        is not fully defined.
    """
    # No partitioning specified, so don't perform further checks.
    if dims is None:
      return

    dims = np.array(dims)

    if (dims < 1).any():
      raise ValueError("All input partition dims must be >= 1.")

    num_partitions = dims.prod()

    # No partitioning, so don't perform further checks.
    if num_partitions == 1:
      return

    if num_partitions != self._device_assignment.num_cores_per_replica:
      raise ValueError(
          "The product of each input partition dim should equal to "
          "num_cores_per_replica. (dim = {}, num_cores_per_replica "
          "= {})".format(dims, self._device_assignment.num_cores_per_replica))
    if dims.shape[0] != tensor.shape.ndims:
      raise ValueError(
          "Input partition dims must have the same number of dimensions "
          "as the `Tensor` to be partitioned. (tensor shape = {}, input "
          "partition dims = {}).".format(tensor.shape.as_list(), dims))

    tensor.shape.assert_is_fully_defined()

  def _check_dims_and_partition_or_replicate_on_host(self, tensor, dims):
    """Checks dims and partitions or replicates the input tensor.

      The ops inside this function are placed on the host side.

    Args:
      tensor: The input tensor which will be partitioned or replicated.
      dims: A list of integer describes how to partition the input tensor.

    Returns:
      An iterator of `Tensor`s or a list of partitioned tensors.

    Raises:
      ValueError: If `dims` is not a valid partitioning for `tensor`.
    """
    self._check_input_partition_dims(tensor, dims)
    return partition_or_replicate_on_host(tensor, dims)
957