# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
15"""Helper library for sharding during TPU compilation."""
16
17
18from tensorflow.python.framework import tensor_shape
19
20_DEFAULT_NUMBER_OF_SHARDS = 1
21_DEFAULT_SHARD_DIMENSION = 0
22

# TODO(b/36777903) change other parts of tpu.py to use this class.
class ShardingPolicy(object):
  """An object used to hold the sharding policy for a Tensor."""

  def __init__(self):
    self._number_of_shards = None
    self._number_of_partitions = 1
    self._shard_dimension = None
    self._frozen = False

  def __str__(self):
    if self.number_of_shards is None or self.shard_dimension is None:
      return "ShardingPolicy(unset)"
    else:
      return ("ShardingPolicy(%d shards dimension %d)" %
              (self.number_of_shards, self.shard_dimension))

  def _fill_default_values(self):
    if self._number_of_shards is None:
      self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS
    if self._shard_dimension is None:
      self._shard_dimension = tensor_shape.as_dimension(
          _DEFAULT_SHARD_DIMENSION)

  def freeze(self):
    """Prevents further modification to the sharding policy.

    Any values that have not been set when freeze is called are set to
    defaults. If the ShardingPolicy is already frozen, this is a no-op.
    """
    if not self._frozen:
      self._fill_default_values()
      self._frozen = True

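  # Illustrative sketch (added for exposition, not in the original source):
  # freezing an unconfigured policy fills in the defaults, so
  #
  #   policy = ShardingPolicy()
  #   policy.freeze()
  #   policy.number_of_shards  # -> 1
  #   policy.shard_dimension   # -> Dimension(0)
  #
  # and any later set_* call must match these frozen values.
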
  @property
  def number_of_shards(self):
    """Returns the number of shards in the policy or None if unspecified."""
    return self._number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Sets the number of shards for the current policy.

    If the policy has been frozen then number_of_shards must match the
    existing setting.

    Args:
      number_of_shards: The number of shards to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and number_of_shards
        differs from the frozen value; or number_of_shards <= 0.
    """
    if self._frozen:
      if self._number_of_shards != number_of_shards:
        raise ValueError(
            f"Can't set sharding policy to use {number_of_shards} shards since "
            f"it has been frozen to use {self._number_of_shards}")
    else:
      if number_of_shards > 0:
        self._number_of_shards = number_of_shards
      else:
        raise ValueError(
            f"Can't set sharding policy to use {number_of_shards} shards; "
            "value must be > 0")

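  # Hedged usage sketch (not part of the original file): once frozen, only
  # the matching value is accepted, e.g.
  #
  #   policy = ShardingPolicy()
  #   policy.set_number_of_shards(2)
  #   policy.freeze()
  #   policy.set_number_of_shards(2)  # OK: matches the frozen value.
  #   policy.set_number_of_shards(4)  # Raises ValueError.
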
  @property
  def number_of_partitions(self):
    """Returns the number of partitions in the policy (defaults to 1)."""
    return self._number_of_partitions

  def set_number_of_partitions(self, number_of_partitions):
    """Sets the number of partitions for the current policy.

    If the policy has been frozen then number_of_partitions must match the
    existing setting.

    Args:
      number_of_partitions: The number of partitions to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and number_of_partitions
        differs from the frozen value.
    """
    if self._frozen:
      if self._number_of_partitions != number_of_partitions:
        raise ValueError(
            f"Can't set number_of_partitions to {number_of_partitions} since "
            f"it has been frozen to use {self._number_of_partitions}.")
    else:
      self._number_of_partitions = number_of_partitions

  @property
  def shard_dimension(self):
    """Returns the shard dimension of the policy or None if unspecified."""
    return self._shard_dimension

  def set_shard_dimension(self, shard_dimension):
    """Sets the shard dimension for the current policy.

    If the policy has been frozen then shard_dimension must match the
    existing setting.

    Args:
      shard_dimension: The shard dimension to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and shard_dimension
        differs from the frozen value, or shard_dimension can't be
        interpreted as a Dimension.
    """
    if self._frozen:
      if self._shard_dimension != shard_dimension:
        raise ValueError(
            f"Can't set shard dimension to {shard_dimension} since it has "
            f"been frozen to use {self._shard_dimension}.")
    else:
      self._shard_dimension = tensor_shape.as_dimension(shard_dimension)

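  # Illustrative note (an addition, not original code): the dimension is
  # stored as a tensor_shape.Dimension, so a plain int is accepted, e.g.
  #
  #   policy = ShardingPolicy()
  #   policy.set_shard_dimension(1)
  #   policy.shard_dimension  # -> Dimension(1)
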
  def merge(self, other):
    """Merges the settings of another policy into the current policy.

    Note that number_of_partitions is not merged.

    Args:
      other: The policy to merge into this one.

    Raises:
      ValueError: If this policy has been frozen and the merge conflicts with
        the frozen policy.
    """
    if other.number_of_shards is not None:
      self.set_number_of_shards(other.number_of_shards)
    if other.shard_dimension is not None:
      self.set_shard_dimension(other.shard_dimension)

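  # Sketch of a merge (illustrative only): settings from `other` fill in or
  # must agree with this policy's values.
  #
  #   a = ShardingPolicy()
  #   a.set_number_of_shards(2)
  #   b = ShardingPolicy()
  #   b.set_shard_dimension(0)
  #   a.merge(b)
  #   str(a)  # -> "ShardingPolicy(2 shards dimension 0)"
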
  def get_unpartitioned_shape(self, shape):
    """Returns the shape of an unpartitioned Tensor.

    When given the shape of a partitioned Tensor, returns the full
    shape of the corresponding unpartitioned Tensor.

    Args:
      shape: The shape of the partitioned Tensor.

    Returns:
      The shape of the unpartitioned version of the Tensor, or None if the
      policy's shard dimension is unset.

    Raises:
      ValueError: If shape has an unknown size along the shard dimension.
    """
    shape = tensor_shape.as_shape(shape)
    dims = shape.as_list()
    if (self._shard_dimension is None or self._number_of_partitions is None or
        not dims):
      return None
    if dims[self._shard_dimension] is None:
      raise ValueError(f"Shape {shape.as_list()} must have a fixed size for "
                       f"dimension {self._shard_dimension} that is known.")
    if self._number_of_partitions > 1:
      dims[self._shard_dimension] *= self._number_of_partitions
    return tensor_shape.as_shape(dims)

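  # Illustrative sketch: with 2 partitions along dimension 0, a per-partition
  # shape of [4, 8] maps back to the full shape [8, 8].
  #
  #   policy = ShardingPolicy()
  #   policy.set_shard_dimension(0)
  #   policy.set_number_of_partitions(2)
  #   policy.get_unpartitioned_shape([4, 8])  # -> TensorShape([8, 8])
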
  def get_sharded_shape(self, shape, shard_index=None):
    """Returns the shape of a shard of a full Tensor.

    When given the shape of a 'full-size' Tensor, returns the shape of
    the sub-Tensor after it has been sharded. Returns None if the
    policy's number of shards or shard dimension is unset.

    Args:
      shape: The shape of the full-size Tensor to be sharded.
      shard_index: The index of the shard whose shape should be returned.
        shard_index can be None for sharding policies that use the same shape
        for every shard.

    Returns:
      The shape of the sharded version of the Tensor.

    Raises:
      ValueError: If shard_index is None when shards are of different
        shapes; or shard_index is not None and not
        0 <= shard_index < number_of_shards; or shape does not have at
        least self.shard_dimension+1 dimensions; or the value of
        shape's shard dimension is not a multiple of
        self.number_of_shards.
    """
    if self._shard_dimension is None or self._number_of_shards is None:
      # Don't raise an error if the config is unset.
      return None
    if shard_index is not None:
      if shard_index < 0 or shard_index >= self.number_of_shards:
        raise ValueError(
            f"Requested shard_index {shard_index}, but shard_index must be in "
            f"[0,{self._number_of_shards}).")
    shape = tensor_shape.as_shape(shape)
    if self._number_of_shards == 1:
      # Don't do anything when there's only one shard.
      return shape
    ndims = shape.ndims
    if ndims is None:
      raise ValueError(f"Shape {shape} must be a known shape.")
    if ndims <= self._shard_dimension:
      raise ValueError(
          f"Shape {shape.as_list()} does not contain shard_dimension "
          f"{self._shard_dimension}")
    dims = shape.as_list()
    if dims[self._shard_dimension] is None:
      raise ValueError(
          f"Shape {shape.as_list()} must have a fixed size for dimension "
          f"{self._shard_dimension} that is known at construction time.")
    if (dims[self._shard_dimension] % self._number_of_shards) != 0:
      raise ValueError(
          f"Shape {shape.as_list()} cannot be sharded {self._number_of_shards} "
          f"ways along dimension {self._shard_dimension}")
    dims[self._shard_dimension] //= self._number_of_shards
    return tensor_shape.TensorShape(dims)

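  # Hedged example (added for exposition): an [8, 16] tensor sharded 4 ways
  # along dimension 0 yields a per-shard shape of [2, 16].
  #
  #   policy = ShardingPolicy()
  #   policy.set_number_of_shards(4)
  #   policy.set_shard_dimension(0)
  #   policy.get_sharded_shape([8, 16])  # -> TensorShape([2, 16])
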
  def _unshard_shape(self, shape):
    """Return the unsharded shape that would generate a given sharded shape.

    Args:
      shape: the sharded shape to unshard

    Returns:
      The unsharded shape.

    Raises:
      ValueError: if shape is unknown or does not contain
        self.shard_dimension
      TypeError: if shape is not convertible to a TensorShape
    """
    shape = tensor_shape.as_shape(shape)
    if self._number_of_shards == 1:
      # Don't do anything when there's only one shard.
      return shape
    ndims = shape.ndims
    if ndims is None:
      raise ValueError(f"Shape {shape} must be statically known.")
    if ndims <= self._shard_dimension:
      raise ValueError(f"Shape {shape.as_list()} does not contain "
                       f"shard_dimension {self._shard_dimension}. "
                       f"Rank is too small.")
    dims = shape.as_list()
    dims[self._shard_dimension] *= self._number_of_shards
    return tensor_shape.TensorShape(dims)

  def get_unsharded_shape(self, shapes):
    """Returns the shape of an unsharded Tensor given a list of shards.

    When given a list of shapes of shards, returns the shape of the
    unsharded Tensor that would generate the shards. Sets defaults for the
    policy if number_of_shards or shard_dimension is None.

    Args:
      shapes: The shapes of the Tensor shards to be combined.

    Returns:
      The shape of the unsharded version of the Tensor.

    Raises:
      ValueError: if shapes is not a list of length
        self.number_of_shards; or any element of shapes is not a valid
        shape consistent with the sharding policy; or the list of
        shapes is not a valid sharding of a full shape.
      TypeError: if an element of shapes is not convertible to a
        TensorShape
    """
    self._fill_default_values()
    if len(shapes) != self.number_of_shards:
      raise ValueError(
          f"Shapes {shapes} is length {len(shapes)} but must be a list of "
          f"length number_of_shards={self.number_of_shards}")
    unsharded_shapes = [self._unshard_shape(s) for s in shapes]
    for i in range(self.number_of_shards - 1):
      if not unsharded_shapes[i].is_compatible_with(
          unsharded_shapes[self.number_of_shards - 1]):
        raise ValueError(
            f"Sharded shapes {shapes} are not consistent shards of a full shape "
            f"sharded {self.number_of_shards} ways along "
            f"dimension {self.shard_dimension}.")
    return unsharded_shapes[0]
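
# Illustrative round trip (a sketch, not part of the original module):
# sharding a shape and recombining the shards recovers the full shape.
#
#   policy = ShardingPolicy()
#   policy.set_number_of_shards(2)
#   policy.set_shard_dimension(0)
#   shard = policy.get_sharded_shape([4, 3])    # TensorShape([2, 3])
#   policy.get_unsharded_shape([shard, shard])  # TensorShape([4, 3])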