xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/partitioned_variables.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Helper functions for creating partitioned variables.
17
18This is a convenient abstraction to partition a large variable across
19multiple smaller variables that can be assigned to different devices.
20
21The full variable can be reconstructed by concatenating the smaller variables.
22Using partitioned variables instead of a single variable is mostly a
23performance choice.  It however also has an impact on:
24
251. Random initialization, as the random number generator is called once per
26   slice
272. Updates, as they happen in parallel across slices
28
29A key design goal is to allow a different graph to repartition a variable
30with the same name but different slicings, including possibly no partitions.
31
32TODO(touts): If an initializer provides a seed, the seed must be changed
33deterministically for each slice, maybe by adding one to it, otherwise each
34slice will use the same values.  Maybe this can be done by passing the
35slice offsets to the initializer functions.
36
37Typical usage:
38
39```python
40# Create a list of partitioned variables with:
41vs = create_partitioned_variables(
42    <shape>, <slicing>, <initializer>, name=<optional-name>)
43
44# Pass the list as inputs to embedding_lookup for sharded, parallel lookup:
45y = embedding_lookup(vs, ids, partition_strategy="div")
46
47# Or fetch the variables in parallel to speed up large matmuls:
48z = matmul(x, concat(slice_dim, vs))
49```
50"""
51import math
52
53from tensorflow.python.framework import dtypes
54from tensorflow.python.framework import tensor_shape
55from tensorflow.python.ops import variable_scope
56from tensorflow.python.util import deprecation
57from tensorflow.python.util.tf_export import tf_export
58
59__all__ = [
60    "create_partitioned_variables",
61    "variable_axis_size_partitioner",
62    "min_max_variable_partitioner",
63    "fixed_size_partitioner",
64]
65
66
67@tf_export(v1=["variable_axis_size_partitioner"])
68def variable_axis_size_partitioner(
69    max_shard_bytes, axis=0, bytes_per_string_element=16, max_shards=None):
70  """Get a partitioner for VariableScope to keep shards below `max_shard_bytes`.
71
72  This partitioner will shard a Variable along one axis, attempting to keep
73  the maximum shard size below `max_shard_bytes`.  In practice, this is not
74  always possible when sharding along only one axis.  When this happens,
75  this axis is sharded as much as possible (i.e., every dimension becomes
76  a separate shard).
77
78  If the partitioner hits the `max_shards` limit, then each shard may end up
79  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
80  limit on the number of shards is enforced.
81
82  One reasonable value for `max_shard_bytes` is `(64 << 20) - 1`, or almost
83  `64MB`, to keep below the protobuf byte limit.
84
85  Args:
86    max_shard_bytes: The maximum size any given shard is allowed to be.
87    axis: The axis to partition along.  Default: outermost axis.
88    bytes_per_string_element: If the `Variable` is of type string, this provides
89      an estimate of how large each scalar in the `Variable` is.
90    max_shards: The maximum number of shards in int created taking precedence
91      over `max_shard_bytes`.
92
93  Returns:
94    A partition function usable as the `partitioner` argument to
95    `variable_scope` and `get_variable`.
96
97  Raises:
98    ValueError: If any of the byte counts are non-positive.
99  """
100  if max_shard_bytes < 1 or bytes_per_string_element < 1:
101    raise ValueError(
102        "Both max_shard_bytes and bytes_per_string_element must be positive. "
103        f"Currently, max_shard_bytes is {max_shard_bytes} and"
104        f"bytes_per_string_element is {bytes_per_string_element}")
105  if max_shards and max_shards < 1:
106    raise ValueError(
107        "max_shards must be positive.")
108
109  def _partitioner(shape, dtype):
110    """Partitioner that partitions shards to have max_shard_bytes total size.
111
112    Args:
113      shape: A `TensorShape`.
114      dtype: A `DType`.
115
116    Returns:
117      A tuple representing how much to slice each axis in shape.
118
119    Raises:
120      ValueError: If shape is not a fully defined `TensorShape` or dtype is not
121        a `DType`.
122    """
123    if not isinstance(shape, tensor_shape.TensorShape):
124      raise ValueError(f"shape is not a TensorShape: {shape}")
125    if not shape.is_fully_defined():
126      raise ValueError(f"shape is not fully defined: {shape}")
127    if not isinstance(dtype, dtypes.DType):
128      raise ValueError(f"dtype is not a DType: {dtype}")
129
130    if dtype.base_dtype == dtypes.string:
131      element_size = bytes_per_string_element
132    else:
133      element_size = dtype.size
134
135    partitions = [1] * shape.ndims
136    bytes_per_slice = 1.0 * (
137        shape.num_elements() / shape.dims[axis].value) * element_size
138    # How many slices can we fit on one shard of size at most max_shard_bytes?
139    # At least one slice is required.
140    slices_per_shard = max(1, math.floor(max_shard_bytes / bytes_per_slice))
141    # How many shards do we need for axis given that each shard fits
142    # slices_per_shard slices from a total of shape[axis] slices?
143    axis_shards = int(math.ceil(
144        1.0 * shape.dims[axis].value / slices_per_shard))
145    if max_shards:
146      axis_shards = min(max_shards, axis_shards)
147
148    partitions[axis] = axis_shards
149
150    return partitions
151
152  return _partitioner
153
154
155@tf_export(v1=["min_max_variable_partitioner"])
156def min_max_variable_partitioner(max_partitions=1, axis=0,
157                                 min_slice_size=256 << 10,
158                                 bytes_per_string_element=16):
159  """Partitioner to allocate minimum size per slice.
160
161  Returns a partitioner that partitions the variable of given shape and dtype
162  such that each partition has a minimum of `min_slice_size` slice of the
163  variable. The maximum number of such partitions (upper bound) is given by
164  `max_partitions`.
165
166  Args:
167    max_partitions: Upper bound on the number of partitions. Defaults to 1.
168    axis: Axis along which to partition the variable. Defaults to 0.
169    min_slice_size: Minimum size of the variable slice per partition. Defaults
170      to 256K.
171    bytes_per_string_element: If the `Variable` is of type string, this provides
172      an estimate of how large each scalar in the `Variable` is.
173
174  Returns:
175    A partition function usable as the `partitioner` argument to
176    `variable_scope` and `get_variable`.
177
178  """
179  def _partitioner(shape, dtype):
180    """Partitioner that partitions list for a variable of given shape and type.
181
182    Ex: Consider partitioning a variable of type float32 with
183      shape=[1024, 1024].
184      If `max_partitions` >= 16, this function would return
185        [(1024 * 1024 * 4) / (256 * 1024), 1] = [16, 1].
186      If `max_partitions` < 16, this function would return
187        [`max_partitions`, 1].
188
189    Args:
190      shape: Shape of the variable.
191      dtype: Type of the variable.
192
193    Returns:
194      List of partitions for each axis (currently only one axis can be
195      partitioned).
196
197    Raises:
198      ValueError: If axis to partition along does not exist for the variable.
199    """
200    if axis >= len(shape):
201      raise ValueError(
202          f"Cannot partition variable along axis {axis} when shape is "
203          f"only {shape}")
204    if dtype.base_dtype == dtypes.string:
205      bytes_per_element = bytes_per_string_element
206    else:
207      bytes_per_element = dtype.size
208    total_size_bytes = shape.num_elements() * bytes_per_element
209    partitions = total_size_bytes / min_slice_size
210    partitions_list = [1] * len(shape)
211    # We can not partition the variable beyond what its shape or
212    # `max_partitions` allows.
213    partitions_list[axis] = max(1, min(shape.dims[axis].value,
214                                       max_partitions,
215                                       int(math.ceil(partitions))))
216    return partitions_list
217  return _partitioner
218
219
220@tf_export(v1=["fixed_size_partitioner"])
221def fixed_size_partitioner(num_shards, axis=0):
222  """Partitioner to specify a fixed number of shards along given axis.
223
224  @compatibility(TF2)
225  This API is deprecated in TF2. In TF2, partitioner is no longer part of
226  the variable declaration via `tf.Variable`.
227  [ParameterServer Training]
228  (https://www.tensorflow.org/tutorials/distribute/parameter_server_training)
229  handles partitioning of variables. The corresponding TF2 partitioner class of
230  `fixed_size_partitioner` is
231  `tf.distribute.experimental.partitioners.FixedShardsPartitioner`.
232
233  Check the [migration guide]
234  (https://www.tensorflow.org/guide/migrate#2_use_python_objects_to_track_variables_and_losses)
235  on the differences in treatment of variables and losses between TF1 and TF2.
236
237  Before:
238
239    ```
240    x = tf.compat.v1.get_variable(
241      "x", shape=(2,), partitioner=tf.compat.v1.fixed_size_partitioner(2)
242    )
243    ```
244  After:
245
246    ```
247    partitioner = (
248        tf.distribute.experimental.partitioners.FixedShardsPartitioner(
249            num_shards=2)
250    )
251    strategy = tf.distribute.experimental.ParameterServerStrategy(
252                   cluster_resolver=cluster_resolver,
253                   variable_partitioner=partitioner)
254
255    with strategy.scope():
256      x = tf.Variable([1.0, 2.0])
257    ```
258  @end_compatibility
259
260  Args:
261    num_shards: `int`, number of shards to partition variable.
262    axis: `int`, axis to partition on.
263
264  Returns:
265    A partition function usable as the `partitioner` argument to
266    `variable_scope` and `get_variable`.
267  """
268  def _partitioner(shape, **unused_args):
269    partitions_list = [1] * len(shape)
270    partitions_list[axis] = min(num_shards, shape.dims[axis].value)
271    return partitions_list
272  return _partitioner
273
274
275@tf_export(v1=["create_partitioned_variables"])
276@deprecation.deprecated(
277    date=None,
278    instructions="Use `tf.get_variable` with a partitioner set.")
279def create_partitioned_variables(
280    shape, slicing, initializer, dtype=dtypes.float32,
281    trainable=True, collections=None, name=None, reuse=None):
282  """Create a list of partitioned variables according to the given `slicing`.
283
284  Currently only one dimension of the full variable can be sliced, and the
285  full variable can be reconstructed by the concatenation of the returned
286  list along that dimension.
287
288  Args:
289    shape: List of integers.  The shape of the full variable.
290    slicing: List of integers.  How to partition the variable.
291      Must be of the same length as `shape`.  Each value
292      indicate how many slices to create in the corresponding
293      dimension.  Presently only one of the values can be more than 1;
294      that is, the variable can only be sliced along one dimension.
295
296      For convenience, The requested number of partitions does not have to
297      divide the corresponding dimension evenly.  If it does not, the
298      shapes of the partitions are incremented by 1 starting from partition
299      0 until all slack is absorbed.  The adjustment rules may change in the
300      future, but as you can save/restore these variables with different
301      slicing specifications this should not be a problem.
302    initializer: A `Tensor` of shape `shape` or a variable initializer
303      function.  If a function, it will be called once for each slice,
304      passing the shape and data type of the slice as parameters.  The
305      function must return a tensor with the same shape as the slice.
306    dtype: Type of the variables. Ignored if `initializer` is a `Tensor`.
307    trainable: If True also add all the variables to the graph collection
308      `GraphKeys.TRAINABLE_VARIABLES`.
309    collections: List of graph collections keys to add the variables to.
310      Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
311    name: Optional name for the full variable.  Defaults to
312      `"PartitionedVariable"` and gets uniquified automatically.
313    reuse: Boolean or `None`; if `True` and name is set, it would reuse
314      previously created variables. if `False` it will create new variables.
315      if `None`, it would inherit the parent scope reuse.
316
317  Returns:
318    A list of Variables corresponding to the slicing.
319
320  Raises:
321    ValueError: If any of the arguments is malformed.
322  """
323  if len(shape) != len(slicing):
324    raise ValueError(
325        "The 'shape' and 'slicing' of a partitioned Variable "
326        f"must have the length: shape: {shape}, slicing: {slicing}")
327  if len(shape) < 1:
328    raise ValueError("A partitioned Variable must have rank at least 1: "
329                     f"shape: {shape}")
330
331  # Legacy: we are provided the slicing directly, so just pass it to
332  # the partitioner.
333  partitioner = lambda **unused_kwargs: slicing
334
335  with variable_scope.variable_scope(
336      name, "PartitionedVariable", reuse=reuse):
337    # pylint: disable=protected-access
338    partitioned_var = variable_scope._get_partitioned_variable(
339        name=None,
340        shape=shape,
341        dtype=dtype,
342        initializer=initializer,
343        trainable=trainable,
344        partitioner=partitioner,
345        collections=collections)
346    return list(partitioned_var)
347    # pylint: enable=protected-access
348