# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

16"""Standard functions for creating slots.
17
18A slot is a `Variable` created with the same first m-dimension as a primary
19variable or `Tensor`. A slot is always scoped in the namespace of the primary
20object and typically has the same device and type.
21
22Slots are typically used as accumulators to track values associated with
23the primary object:
24
25```python
26# Optimizers can create a slot for each variable to track accumulators
27accumulators = {var : create_zeros_slot(var, "momentum") for var in vs}
28for var in vs:
29  apply_momentum(var, accumulators[var], lr, grad, momentum_tensor)
30
31# Slots can also be used for moving averages
32mavg = create_slot(var, var.initialized_value(), "exponential_moving_avg")
33update_mavg = mavg.assign_sub((mavg - var) * (1 - decay))
34```
35"""
36# pylint: disable=g-bad-name

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables


def _create_slot_var(primary,
                     val,
                     scope,
                     validate_shape,
                     shape,
                     dtype,
                     *,
                     copy_xla_sharding=False):
  """Helper function for creating a slot variable."""

  # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
  # scope.
  current_partitioner = variable_scope.get_variable_scope().partitioner
  variable_scope.get_variable_scope().set_partitioner(None)
  # When initializing from a value instead of a callable initializer, the
  # shape is expected to be None, not <unknown> or any fully defined shape.
  shape = shape if callable(val) else None
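  # Match the slot's variable kind to the primary's: a resource-variable
  # primary gets a resource slot, a ref-variable primary gets a ref slot, and
  # any other primary (e.g. a plain `Tensor`) defers to the scope default.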
  if resource_variable_ops.is_resource_variable(primary):
    use_resource = True
  elif isinstance(primary, variables.RefVariable):
    use_resource = False
  else:
    use_resource = None
  slot = variable_scope.get_variable(
      scope,
      initializer=val,
      trainable=False,
      use_resource=use_resource,
      shape=shape,
      dtype=dtype,
      validate_shape=validate_shape)
  variable_scope.get_variable_scope().set_partitioner(current_partitioner)

  # pylint: disable=protected-access
  if isinstance(primary, variables.Variable) and primary._save_slice_info:
    # Primary is a partitioned variable, so we need to also indicate that
    # the slot is a partitioned variable.  Slots have the same partitioning
    # as their primaries.
    # For example, when using AdamOptimizer in a linear model, slot.name
    # here can be "linear//weights/Adam:0", while primary.op.name is
    # "linear//weights". We want "Adam" as real_slot_name, so we strip the
    # "'linear//weights' + '/'" prefix and the ":0" suffix.
    real_slot_name = slot.name[len(primary.op.name + "/"):-2]
    slice_info = primary._save_slice_info
    # Support a slot whose shape differs from the primary's shape. For
    # example, if the primary's shape is [10, 20, 30], the slot's shape may
    # be None, [], [10], [10, 20] or [10, 20, 30]:
    #   - None or [10, 20, 30]: give the slot the same slice_info as the
    #     primary;
    #   - []: leave the slot's slice_info unset;
    #   - [10] or [10, 20]: truncate the slice_info to the slot's ndims.
    n = slot.shape.ndims
    if n is None or n > 0:
      slot._set_save_slice_info(
          variables.Variable.SaveSliceInfo(
              slice_info.full_name + "/" + real_slot_name,
              slice_info.full_shape[:n], slice_info.var_offset[:n],
              slice_info.var_shape[:n]))
  # pylint: enable=protected-access

  # Copy XLA sharding attributes from the primary if the slot variable has the
  # same rank as the primary.
  def _has_same_rank(primary_shape, slot_shape):
    return (primary_shape.rank is not None and slot_shape.rank is not None and
            primary_shape.rank == slot_shape.rank)

  if copy_xla_sharding and _has_same_rank(primary.shape, slot.shape):
    slot = xla_sharding.copy_sharding(primary, slot, use_sharding_op=False)
  return slot


def create_slot(primary,
                val,
                name,
                colocate_with_primary=True,
                *,
                copy_xla_sharding=False):
  """Create a slot initialized to the given value.

  The type of the slot is determined by the given value.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.
    copy_xla_sharding: Boolean. If True, also copies the XLA sharding
      from `primary`.

  Returns:
    A `Variable` object.
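
  For example (an illustrative sketch; `var` stands for an existing
  variable and is not defined in this module):

  ```python
  # Track an accumulator initialized to zeros, shaped like `var`.
  accum = create_slot(var, array_ops.zeros_like(var), "accumulator")
  ```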
  """
  # Scope the slot name in the namespace of the primary variable.
  # Use primary's name + '/' + name as the default name, so the optimizer's
  # scope name can be shared when reuse is True. When reuse is False and the
  # same name has already been used, a '_N' suffix is appended to the scope
  # name to keep it unique.
  validate_shape = val.get_shape().is_fully_defined()
  if isinstance(primary, variables.Variable):
    prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    prefix = primary.op.name
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if colocate_with_primary:
      distribution_strategy = distribution_strategy_context.get_strategy()
      with distribution_strategy.extended.colocate_vars_with(primary):
        return _create_slot_var(
            primary,
            val,
            "",
            validate_shape,
            None,
            None,
            copy_xla_sharding=copy_xla_sharding)
    else:
      return _create_slot_var(
          primary,
          val,
          "",
          validate_shape,
          None,
          None,
          copy_xla_sharding=copy_xla_sharding)


def create_slot_with_initializer(primary,
                                 initializer,
                                 shape,
                                 dtype,
                                 name,
                                 colocate_with_primary=True,
                                 *,
                                 copy_xla_sharding=False):
  """Creates a slot initialized using an `Initializer`.

  The type of the slot is determined by the given `dtype`.

  Args:
    primary: The primary `Variable` or `Tensor`.
    initializer: An `Initializer`.  The initial value of the slot.
    shape: Shape of the initial value of the slot.
    dtype: Type of the value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.
    copy_xla_sharding: Boolean. If True, also copies the XLA sharding
      from `primary`.

  Returns:
    A `Variable` object.
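
  For example (an illustrative sketch; `var` stands for an existing
  variable and is not defined in this module):

  ```python
  # Create a ones-filled slot matching `var`'s shape and dtype.
  slot = create_slot_with_initializer(
      var, init_ops.ones_initializer(), var.shape, var.dtype.base_dtype,
      "ones_slot")
  ```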
  """
  # Scope the slot name in the namespace of the primary variable.
  # Use "primary.op.name + '/' + name" as the default name, so the optimizer's
  # scope name can be shared when reuse is True. When reuse is False and the
  # same name has already been used, a '_N' suffix is appended to the scope
  # name to keep it unique.
  validate_shape = shape.is_fully_defined()
  if isinstance(primary, variables.Variable):
    prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    prefix = primary.op.name
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if colocate_with_primary:
      distribution_strategy = distribution_strategy_context.get_strategy()
      with distribution_strategy.extended.colocate_vars_with(primary):
        return _create_slot_var(
            primary,
            initializer,
            "",
            validate_shape,
            shape,
            dtype,
            copy_xla_sharding=copy_xla_sharding)
    else:
      return _create_slot_var(
          primary,
          initializer,
          "",
          validate_shape,
          shape,
          dtype,
          copy_xla_sharding=copy_xla_sharding)


def create_zeros_slot(primary,
                      name,
                      dtype=None,
                      colocate_with_primary=True,
                      *,
                      copy_xla_sharding=False):
  """Create a slot initialized to 0 with same shape as the primary object.

  Args:
    primary: The primary `Variable` or `Tensor`.
    name: Name to use for the slot variable.
    dtype: Type of the slot variable.  Defaults to the type of `primary`.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.
    copy_xla_sharding: Boolean. If True, also copies the XLA sharding
      from `primary`.

  Returns:
    A `Variable` object.
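
  For example (an illustrative sketch; `var` stands for an existing
  variable and is not defined in this module):

  ```python
  # The canonical optimizer pattern: one zero-initialized accumulator
  # per variable, e.g. for momentum.
  momentum = create_zeros_slot(var, "momentum")
  ```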
  """
  if dtype is None:
    dtype = primary.dtype
  slot_shape = primary.get_shape()
  if slot_shape.is_fully_defined():
    initializer = init_ops.zeros_initializer()
    return create_slot_with_initializer(
        primary,
        initializer,
        slot_shape,
        dtype,
        name,
        colocate_with_primary=colocate_with_primary,
        copy_xla_sharding=copy_xla_sharding)
  else:
    if isinstance(primary, variables.Variable):
      slot_shape = array_ops.shape(primary.initialized_value())
    else:
      slot_shape = array_ops.shape(primary)
    val = array_ops.zeros(slot_shape, dtype=dtype)
    return create_slot(
        primary,
        val,
        name,
        colocate_with_primary=colocate_with_primary,
        copy_xla_sharding=copy_xla_sharding)