xref: /aosp_15_r20/external/tensorflow/tensorflow/python/training/optimizer.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Base class for optimizers."""
17# pylint: disable=g-bad-name
18
19import abc
20
21from tensorflow.python.distribute import distribute_lib
22from tensorflow.python.distribute import distribute_utils
23from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
24from tensorflow.python.distribute import reduce_util as ds_reduce_util
25from tensorflow.python.eager import backprop
26from tensorflow.python.eager import context
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import indexed_slices
29from tensorflow.python.framework import ops
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import control_flow_ops
32from tensorflow.python.ops import gradients
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops import resource_variable_ops
35from tensorflow.python.ops import state_ops
36from tensorflow.python.ops import variable_scope
37from tensorflow.python.ops import variables
38from tensorflow.python.trackable import base as trackable
39from tensorflow.python.training import slot_creator
40from tensorflow.python.util import nest
41from tensorflow.python.util.tf_export import tf_export
42
43
44def get_filtered_grad_fn(grad_fn):
45  # `distributed_context.join()` requires that its arguments are parallel
46  # across threads, and in particular that `grads_and_vars` has the same
47  # variables in the same order.
48
49  # When computing gradients in eager mode with multiple threads, you
50  # can get extra variables with a gradient of `None`. This happens when
51  # those variables are accessed in another thread during the gradient
52  # computation. To get a consistent set of variables, we filter out
53  # those with `None` gradients.
54  def filtered_grad_fn(*args, **kwargs):
55    return [(g, v) for g, v in grad_fn(*args, **kwargs) if g is not None]
56
57  return filtered_grad_fn
58
59
60def _deduplicate_indexed_slices(values, indices):
61  """Sums `values` associated with any non-unique `indices`.
62
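  For example, `values=[[1., 2.], [3., 4.]]` with `indices=[0, 0]` yields
  `summed_values=[[4., 6.]]` and `unique_indices=[0]`.
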
63  Args:
64    values: A `Tensor` with rank >= 1.
65    indices: A one-dimensional integer `Tensor`, indexing into the first
66      dimension of `values` (as in an IndexedSlices object).
67  Returns:
68    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
69    de-duplicated version of `indices` and `summed_values` contains the sum of
70    `values` slices associated with each unique index.
71  """
72  unique_indices, new_index_positions = array_ops.unique(indices)
73  summed_values = math_ops.unsorted_segment_sum(
74      values, new_index_positions,
75      array_ops.shape(unique_indices)[0])
76  return (summed_values, unique_indices)
77
78
79def _var_key(var):
80  """Returns slot key for `var`."""
81  # pylint: disable=protected-access
82  if hasattr(var, "_distributed_container"):
83    var = var._distributed_container()
84  if (distribute_utils.is_distributed_variable(var) and
85      not ops.executing_eagerly_outside_functions()):
86    return (var.graph, var._shared_name)
87  if hasattr(var, "op"):
88    return (var.op.graph, var.op.name)
89  return var._unique_id
90  # pylint: enable=protected-access
91
92
93class _OptimizableVariable(metaclass=abc.ABCMeta):
94  """Interface for abstracting over variables in the optimizers."""
95
96  @abc.abstractmethod
97  def target(self):
98    """Returns the optimization target for this variable."""
99    raise NotImplementedError("Calling an abstract method.")
100
101  @abc.abstractmethod
102  def update_op(self, optimizer, g):
103    """Returns the update ops for updating the variable."""
104    raise NotImplementedError("Calling an abstract method.")
105
106
107class _RefVariableProcessor(_OptimizableVariable):
108  """Processor for Variable."""
109
110  def __init__(self, v):
111    self._v = v
112
113  def __str__(self):
114    return "<_RefVariableProcessor(%s)>" % self._v
115
116  def target(self):
117    return self._v._ref()  # pylint: disable=protected-access
118
119  def update_op(self, optimizer, g):
120    if isinstance(g, ops.Tensor):
121      update_op = optimizer._apply_dense(g, self._v)  # pylint: disable=protected-access
122      if self._v.constraint is not None:
123        with ops.control_dependencies([update_op]):
124          return self._v.assign(self._v.constraint(self._v))
125      else:
126        return update_op
127    else:
128      assert isinstance(g, indexed_slices.IndexedSlices), (
129          "Gradient ", g, " is neither a tensor nor IndexedSlices.")
130      if self._v.constraint is not None:
131        raise RuntimeError(
132            "Cannot use a constraint function on a sparse variable.")
133      # pylint: disable=protected-access
134      return optimizer._apply_sparse_duplicate_indices(g, self._v)
135
136
137class _DenseReadResourceVariableProcessor(_OptimizableVariable):
138  """Processor for dense ResourceVariables."""
139
140  def __init__(self, v):
141    self._v = v
142
143  def target(self):
144    return self._v
145
146  def update_op(self, optimizer, g):
147    # pylint: disable=protected-access
148    update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0])
149    if self._v.constraint is not None:
150      with ops.control_dependencies([update_op]):
151        return self._v.assign(self._v.constraint(self._v))
152    else:
153      return update_op
154
155
156class _DenseResourceVariableProcessor(_OptimizableVariable):
157  """Processor for dense ResourceVariables."""
158
159  def __init__(self, v):
160    self._v = v
161
162  def target(self):
163    return self._v
164
165  def update_op(self, optimizer, g):
166    # pylint: disable=protected-access
167    if isinstance(g, indexed_slices.IndexedSlices):
168      if self._v.constraint is not None:
169        raise RuntimeError(
170            "Cannot use a constraint function on a sparse variable.")
171      return optimizer._resource_apply_sparse_duplicate_indices(
172          g.values, self._v, g.indices)
173    update_op = optimizer._resource_apply_dense(g, self._v)
174    if self._v.constraint is not None:
175      with ops.control_dependencies([update_op]):
176        return self._v.assign(self._v.constraint(self._v))
177    else:
178      return update_op
179
180
181class _TensorProcessor(_OptimizableVariable):
182  """Processor for ordinary Tensors.
183
184  Even though a Tensor can't really be updated, sometimes it is useful to
185  compute the gradients with respect to a Tensor using the optimizer. Updating
186  the Tensor is, of course, unsupported.
187  """
188
189  def __init__(self, v):
190    self._v = v
191
192  def target(self):
193    return self._v
194
195  def update_op(self, optimizer, g):
196    raise NotImplementedError("Trying to update a Tensor ", self._v)
197
198
199def _get_processor(v):
200  """The processor of v."""
201  if context.executing_eagerly():
202    if isinstance(v, ops.Tensor):
203      return _TensorProcessor(v)
204    else:
205      return _DenseResourceVariableProcessor(v)
206  if resource_variable_ops.is_resource_variable(v) and not v._in_graph_mode:  # pylint: disable=protected-access
207    # True if and only if `v` was initialized eagerly.
208    return _DenseResourceVariableProcessor(v)
209  if v.op.type == "VarHandleOp":
210    return _DenseResourceVariableProcessor(v)
211  if isinstance(v, variables.Variable):
212    return _RefVariableProcessor(v)
213  if isinstance(v, ops.Tensor):
214    return _TensorProcessor(v)
215  raise NotImplementedError("Trying to optimize unsupported type ", v)
216
217
218@tf_export(v1=["train.Optimizer"])
219class Optimizer(
220    # Optimizers inherit from Trackable rather than AutoTrackable
221    # since they do most of their dependency management themselves (slot
222    # variables are special-cased, and non-slot variables are keyed to graphs).
223    trackable.Trackable):
224  """Base class for optimizers.
225
226  This class defines the API to add Ops to train a model.  You never use this
227  class directly, but instead instantiate one of its subclasses such as
228  `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
229
230  ### Usage
231
232  ```python
233  # Create an optimizer with the desired parameters.
234  opt = GradientDescentOptimizer(learning_rate=0.1)
235  # Add Ops to the graph to minimize a cost by updating a list of variables.
236  # "cost" is a Tensor, and the list of variables contains tf.Variable
237  # objects.
238  opt_op = opt.minimize(cost, var_list=<list of variables>)
239  ```
240
241  In the training program you will just have to run the returned Op.
242
243  ```python
244  # Execute opt_op to do one step of training:
245  opt_op.run()
246  ```
247
248  ### Processing gradients before applying them.
249
250  Calling `minimize()` takes care of both computing the gradients and
251  applying them to the variables.  If you want to process the gradients
252  before applying them you can instead use the optimizer in three steps:
253
254  1.  Compute the gradients with `compute_gradients()`.
255  2.  Process the gradients as you wish.
256  3.  Apply the processed gradients with `apply_gradients()`.
257
258  Example:
259
260  ```python
261  # Create an optimizer.
262  opt = GradientDescentOptimizer(learning_rate=0.1)
263
264  # Compute the gradients for a list of variables.
265  grads_and_vars = opt.compute_gradients(loss, <list of variables>)
266
267  # grads_and_vars is a list of tuples (gradient, variable).  Do whatever you
268  # need to the 'gradient' part, for example cap them, etc.
269  capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
270
271  # Ask the optimizer to apply the capped gradients.
272  opt.apply_gradients(capped_grads_and_vars)
273  ```
274
275  ### Gating Gradients
276
277  Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
278  argument that controls the degree of parallelism during the application of
279  the gradients.
280
281  The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
282
283  <b>`GATE_NONE`</b>: Compute and apply gradients in parallel.  This provides
284  the maximum parallelism in execution, at the cost of some non-reproducibility
285  in the results.  For example the two gradients of `matmul` depend on the input
286  values: With `GATE_NONE` one of the gradients could be applied to one of the
287  inputs _before_ the other gradient is computed resulting in non-reproducible
288  results.
289
290  <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
291  they are used.  This prevents race conditions for Ops that generate gradients
292  for multiple inputs where the gradients depend on the inputs.
293
294  <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
295  before any one of them is used.  This provides the least parallelism but can
296  be useful if you want to process all gradients before applying any of them.
297
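  For example, here is a minimal sketch of requesting full gating when computing
  and applying gradients (`loss` and `var_list` are placeholders):

  ```python
  opt = GradientDescentOptimizer(learning_rate=0.1)
  # Gate the whole graph: no gradient is applied until all have been computed.
  grads_and_vars = opt.compute_gradients(
      loss, var_list, gate_gradients=GradientDescentOptimizer.GATE_GRAPH)
  train_op = opt.apply_gradients(grads_and_vars)
  ```
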
298  ### Slots
299
300  Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`,
301  allocate and manage additional variables associated with the variables to
302  train.  These are called <i>Slots</i>.  Slots have names and you can ask the
303  optimizer for the names of the slots that it uses.  Once you have a slot name
304  you can ask the optimizer for the variable it created to hold the slot value.
305
306  This can be useful if you want to log or debug a training algorithm, report
307  stats about the slots, etc.
308
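  For example, here is a minimal sketch of inspecting the slots created by a
  momentum optimizer (`loss` and `var_list` are placeholders):

  ```python
  opt = MomentumOptimizer(learning_rate=0.1, momentum=0.9)
  opt_op = opt.minimize(loss, var_list=var_list)
  print(opt.get_slot_names())  # e.g. ['momentum']
  # The accumulator the optimizer created for the first trained variable.
  momentum_slot = opt.get_slot(var_list[0], "momentum")
  ```
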
309  @compatibility(TF2)
310  `tf.compat.v1.train.Optimizer` can be used in eager mode and inside
311  `tf.function`, but it is not recommended. Please use the subclasses of
312  `tf.keras.optimizers.Optimizer` instead in TF2. Please see [Basic training
313  loops](https://www.tensorflow.org/guide/basic_training_loops) or
314  [Writing a training loop from
315  scratch](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch)
316  for examples.
317
318  If your TF1 code contains a `tf.compat.v1.train.Optimizer` symbol, whether it
319  is used with or without a `tf.estimator.Estimator`, you cannot simply replace
320  that with the corresponding `tf.keras.optimizers.Optimizer`s. To migrate to
321  TF2, it is advised that the whole training program used with `Estimator` be
322  migrated to Keras `Model.fit`-based training or to TF2 custom training loops.
323
324  #### Structural Mapping to Native TF2
325
326  Before:
327
328  ```python
329  sgd_op = tf.compat.v1.train.GradientDescentOptimizer(3.0)
330  opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
331  opt_op.run(session=session)
332  ```
333
334  After:
335
336  ```python
337  sgd = tf.keras.optimizers.SGD(3.0)
338  sgd.minimize(cost_fn, [var0, var1])
339  ```
340
341  #### How to Map Arguments
342
343  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
344  | :-------------------- | :-------------- | :------------------------- |
345  | `use_locking`         | Not supported   | -                          |
346  | `name`                | `name`          | -                          |
347
348  #### Before & After Usage Example
349
350  Before:
351
352  >>> g = tf.compat.v1.Graph()
353  >>> with g.as_default():
354  ...   var0 = tf.compat.v1.Variable([1.0, 2.0])
355  ...   var1 = tf.compat.v1.Variable([3.0, 4.0])
356  ...   cost = 5 * var0 + 3 * var1
357  ...   global_step = tf.compat.v1.Variable(
358  ...       tf.compat.v1.zeros([], tf.compat.v1.int64), name='global_step')
359  ...   init_op = tf.compat.v1.initialize_all_variables()
360  ...   sgd_op = tf.compat.v1.train.GradientDescentOptimizer(3.0)
361  ...   opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
362  >>> session = tf.compat.v1.Session(graph=g)
363  >>> session.run(init_op)
364  >>> opt_op.run(session=session)
365  >>> print(session.run(var0))
366  [-14. -13.]
367
368
369  After:
370  >>> var0 = tf.Variable([1.0, 2.0])
371  >>> var1 = tf.Variable([3.0, 4.0])
372  >>> cost_fn = lambda: 5 * var0 + 3 * var1
373  >>> sgd = tf.keras.optimizers.SGD(3.0)
374  >>> sgd.minimize(cost_fn, [var0, var1])
375  >>> print(var0.numpy())
376  [-14. -13.]
377
378  @end_compatibility
379
380
381  """
382
383  # Values for gate_gradients.
384  GATE_NONE = 0
385  GATE_OP = 1
386  GATE_GRAPH = 2
387
388  def __init__(self, use_locking, name):
389    """Create a new Optimizer.
390
391    This must be called by the constructors of subclasses.
392
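    A minimal sketch of a subclass constructor (`MyOptimizer` and its
    `learning_rate` argument are hypothetical):

    ```python
    class MyOptimizer(Optimizer):

      def __init__(self, learning_rate, use_locking=False, name="MyOptimizer"):
        super(MyOptimizer, self).__init__(use_locking, name)
        self._learning_rate = learning_rate
    ```
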
393    Args:
394      use_locking: Bool. If True, use locks to prevent concurrent updates
395        to variables.
396      name: A non-empty string.  The name to use for accumulators created
397        for the optimizer.
398
399    Raises:
400      ValueError: If name is malformed.
401    """
402    if not name:
403      raise ValueError("Must specify the optimizer name")
404    self._use_locking = use_locking
405    self._name = name
406    # Dictionary of slots.
407    #  {slot_name :
408    #      {_var_key(variable_to_train): slot_for_the_variable, ... },
409    #   ... }
410    self._slots = {}
411    self._non_slot_dict = {}
412    # For implementing Trackable. Stores information about how to restore
413    # slot variables which have not yet been created
414    # (trackable._CheckpointPosition objects).
415    #  {slot_name :
416    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
417    #   ... }
418    self._deferred_slot_restorations = {}
419
420    # TODO(isaprykin): When using a DistributionStrategy, and when an
421    # optimizer is created in each replica, it might be dangerous to
422    # rely on some Optimizer methods.  When such methods are called on a
423    # per-replica optimizer, an exception needs to be thrown.  We do
424    # allow creation per-replica optimizers however, because the
425    # compute_gradients()->apply_gradients() sequence is safe.
426
427  def get_name(self):
428    return self._name
429
430  def minimize(self, loss, global_step=None, var_list=None,
431               gate_gradients=GATE_OP, aggregation_method=None,
432               colocate_gradients_with_ops=False, name=None,
433               grad_loss=None):
434    """Add operations to minimize `loss` by updating `var_list`.
435
436    This method simply combines calls to `compute_gradients()` and
437    `apply_gradients()`. If you want to process the gradients before applying
438    them, call `compute_gradients()` and `apply_gradients()` explicitly instead
439    of using this function.
440
441    Args:
442      loss: A `Tensor` containing the value to minimize.
443      global_step: Optional `Variable` to increment by one after the
444        variables have been updated.
445      var_list: Optional list or tuple of `Variable` objects to update to
446        minimize `loss`.  Defaults to the list of variables collected in
447        the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
448      gate_gradients: How to gate the computation of gradients.  Can be
449        `GATE_NONE`, `GATE_OP`, or  `GATE_GRAPH`.
450      aggregation_method: Specifies the method used to combine gradient terms.
451        Valid values are defined in the class `AggregationMethod`.
452      colocate_gradients_with_ops: If True, try colocating gradients with
453        the corresponding op.
454      name: Optional name for the returned operation.
455      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
456
457    Returns:
458      An Operation that updates the variables in `var_list`.  If `global_step`
459      was not `None`, that operation also increments `global_step`.
460
461    Raises:
462      ValueError: If some of the variables are not `Variable` objects.
463
464    @compatibility(eager)
465    When eager execution is enabled, `loss` should be a Python function that
466    takes no arguments and computes the value to be minimized. Minimization (and
467    gradient computation) is done with respect to the elements of `var_list` if
468    not None, else with respect to any trainable variables created during the
469    execution of the `loss` function. `gate_gradients`, `aggregation_method`,
470    `colocate_gradients_with_ops` and `grad_loss` are ignored when eager
471    execution is enabled.
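
    A minimal sketch of eager usage (`var0` and `loss_fn` are hypothetical):

    ```python
    var0 = tf.Variable([1.0, 2.0])
    loss_fn = lambda: tf.reduce_sum(3.0 * var0)
    opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.1)
    opt.minimize(loss_fn, var_list=[var0])
    ```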
472    @end_compatibility
473    """
474    grads_and_vars = self.compute_gradients(
475        loss, var_list=var_list, gate_gradients=gate_gradients,
476        aggregation_method=aggregation_method,
477        colocate_gradients_with_ops=colocate_gradients_with_ops,
478        grad_loss=grad_loss)
479
480    vars_with_grad = [v for g, v in grads_and_vars if g is not None]
481    if not vars_with_grad:
482      raise ValueError(
483          "No gradients provided for any variable, check your graph for ops"
484          " that do not support gradients, between variables %s and loss %s." %
485          ([str(v) for _, v in grads_and_vars], loss))
486
487    return self.apply_gradients(grads_and_vars, global_step=global_step,
488                                name=name)
489
490  def compute_gradients(self, loss, var_list=None,
491                        gate_gradients=GATE_OP,
492                        aggregation_method=None,
493                        colocate_gradients_with_ops=False,
494                        grad_loss=None):
495    """Compute gradients of `loss` for the variables in `var_list`.
496
497    This is the first part of `minimize()`.  It returns a list
498    of (gradient, variable) pairs where "gradient" is the gradient
499    for "variable".  Note that "gradient" can be a `Tensor`, an
500    `IndexedSlices`, or `None` if there is no gradient for the
501    given variable.
502
503    @compatibility(TF2)
504    `tf.keras.optimizers.Optimizer` in TF2 does not provide a
505    `compute_gradients` method, and you should use a `tf.GradientTape` to
506    obtain the gradients:
507
508    ```python
509    @tf.function
510    def train_step(inputs):
511      batch_data, labels = inputs
512      with tf.GradientTape() as tape:
513        predictions = model(batch_data, training=True)
514        loss = tf.keras.losses.CategoricalCrossentropy(
515            reduction=tf.keras.losses.Reduction.NONE)(labels, predictions)
516      gradients = tape.gradient(loss, model.trainable_variables)
517      optimizer.apply_gradients(zip(gradients, model.trainable_variables))
518    ```
519
520    Args:
521      loss: A Tensor containing the value to minimize or a callable taking
522        no arguments which returns the value to minimize. When eager execution
523        is enabled it must be a callable.
524      var_list: Optional list or tuple of `tf.Variable` to update to minimize
525        `loss`.  Defaults to the list of variables collected in the graph
526        under the key `GraphKeys.TRAINABLE_VARIABLES`.
527      gate_gradients: How to gate the computation of gradients.  Can be
528        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
529      aggregation_method: Specifies the method used to combine gradient terms.
530        Valid values are defined in the class `AggregationMethod`.
531      colocate_gradients_with_ops: If True, try colocating gradients with
532        the corresponding op.
533      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
534
535    Returns:
536      A list of (gradient, variable) pairs. Variable is always present, but
537      gradient can be `None`.
538
539    Raises:
540      TypeError: If `var_list` contains anything other than `Variable` objects.
541      ValueError: If some arguments are invalid.
542      RuntimeError: If called with eager execution enabled and `loss` is
543        not callable.
544
545    @compatibility(eager)
546    When eager execution is enabled, `gate_gradients`, `aggregation_method`,
547    and `colocate_gradients_with_ops` are ignored.
548    @end_compatibility
549    """
550    if callable(loss):
551      with backprop.GradientTape() as tape:
552        if var_list is not None:
553          tape.watch(var_list)
554        loss_value = loss()
555
556        # Scale loss if using a "mean" loss reduction and multiple replicas.
557        # Have to be careful to call distribute_lib.get_loss_reduction()
558        # *after* loss() is evaluated, so we know what loss reduction it uses.
559        # TODO(josh11b): Test that we handle weight decay in a reasonable way.
560        loss_value = self._scale_loss(loss_value)
561
562      if var_list is None:
563        var_list = tape.watched_variables()
564      # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
565      # to be executed.
566      with ops.control_dependencies([loss_value]):
567        grads = tape.gradient(loss_value, var_list, grad_loss)
568      return list(zip(grads, var_list))
569
570    # Non-callable/Tensor loss case
571    if context.executing_eagerly():
572      raise RuntimeError(
573          "`loss` passed to Optimizer.compute_gradients should "
574          "be a function when eager execution is enabled.")
575
576    # Scale loss if using a "mean" loss reduction and multiple replicas.
577    loss = self._scale_loss(loss)
578
579    if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
580                              Optimizer.GATE_GRAPH]:
581      raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
582                       "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" %
583                       gate_gradients)
584    self._assert_valid_dtypes([loss])
585    if grad_loss is not None:
586      self._assert_valid_dtypes([grad_loss])
587    if var_list is None:
588      var_list = (
589          variables.trainable_variables() +
590          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
591    else:
592      var_list = nest.flatten(var_list)
593    # pylint: disable=protected-access
594    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
595    # pylint: enable=protected-access
596    processors = [_get_processor(v) for v in var_list]
597    if not var_list:
598      raise ValueError("No variables to optimize.")
599    var_refs = [p.target() for p in processors]
600    grads = gradients.gradients(
601        loss, var_refs, grad_ys=grad_loss,
602        gate_gradients=(gate_gradients == Optimizer.GATE_OP),
603        aggregation_method=aggregation_method,
604        colocate_gradients_with_ops=colocate_gradients_with_ops)
605    if gate_gradients == Optimizer.GATE_GRAPH:
606      grads = control_flow_ops.tuple(grads)
607    grads_and_vars = list(zip(grads, var_list))
608    self._assert_valid_dtypes(
609        [v for g, v in grads_and_vars
610         if g is not None and v.dtype != dtypes.resource])
611    return grads_and_vars
612
613  @staticmethod
614  def _scale_loss(loss_value):
615    ops.get_default_graph()._is_loss_scaled_by_optimizer = False  # pylint: disable=protected-access
616    if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
617      num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
618      if num_replicas > 1:
619        loss_value *= (1. / num_replicas)
620        ops.get_default_graph()._is_loss_scaled_by_optimizer = True  # pylint: disable=protected-access
621    return loss_value
622
623  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
624    """Apply gradients to variables.
625
626    This is the second part of `minimize()`. It returns an `Operation` that
627    applies gradients.
628
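    A minimal sketch of applying externally processed gradients together with a
    global step counter (`opt` and `processed_grads_and_vars` are hypothetical):

    ```python
    global_step = tf.compat.v1.train.get_or_create_global_step()
    train_op = opt.apply_gradients(processed_grads_and_vars,
                                   global_step=global_step)
    ```
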
629    @compatibility(TF2)
630    #### How to Map Arguments
631
632    | TF1 Arg Name          | TF2 Arg Name    | Note                       |
633    | :-------------------- | :-------------- | :------------------------- |
634    | `grads_and_vars`      | `grads_and_vars`| -                          |
635    | `global_step`         | Not supported.  | Use `optimizer.iterations` |
636    | `name`                | `name`          | -                          |
637
638    Args:
639      grads_and_vars: List of (gradient, variable) pairs as returned by
640        `compute_gradients()`.
641      global_step: Optional `Variable` to increment by one after the
642        variables have been updated.
643      name: Optional name for the returned operation.  Default to the
644        name passed to the `Optimizer` constructor.
645
646    Returns:
647      An `Operation` that applies the specified gradients. If `global_step`
648      was not None, that operation also increments `global_step`.
649
650    Raises:
651      TypeError: If `grads_and_vars` is malformed.
652      ValueError: If none of the variables have gradients.
653      RuntimeError: If you should use `_distributed_apply()` instead.
654    """
655    # This is a default implementation of apply_gradients() that can be shared
656    # by most optimizers.  It relies on the subclass implementing the following
657    # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
658
659    # TODO(isaprykin): Get rid of `has_strategy()` check by
660    # always calling _distributed_apply(), using the default distribution
661    # as needed.
662    if distribute_ctx.has_strategy():
663      # Handle DistributionStrategy case.
664      if distribute_ctx.in_cross_replica_context():
665        raise RuntimeError("Use `_distributed_apply()` instead of "
666                           "`apply_gradients()` in a cross-replica context.")
667
668      grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
669      return distribute_ctx.get_replica_context().merge_call(
670          self._distributed_apply, args=(grads_and_vars, global_step, name))
671
672    # No DistributionStrategy case.
673    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
674    if not grads_and_vars:
675      raise ValueError("No variables provided.")
676    converted_grads_and_vars = []
677    for g, v in grads_and_vars:
678      if g is not None:
679        try:
680          # Convert the grad to Tensor or IndexedSlices if necessary.
681          g = ops.convert_to_tensor_or_indexed_slices(g)
682        except TypeError:
683          raise TypeError(
684              "Gradient must be convertible to a Tensor"
685              " or IndexedSlices, or None: %s" % g)
686        if not isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices)):
687          raise TypeError(
688              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
689      p = _get_processor(v)
690      converted_grads_and_vars.append((g, v, p))
691
692    converted_grads_and_vars = tuple(converted_grads_and_vars)
693    var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
694    if not var_list:
695      raise ValueError("No gradients provided for any variable: %s." %
696                       ([str(v) for _, v, _ in converted_grads_and_vars],))
697    with ops.init_scope():
698      self._create_slots(var_list)
699    update_ops = []
700    with ops.name_scope(name, self._name, skip_on_eager=False) as name:
701      self._prepare()
702      for grad, var, processor in converted_grads_and_vars:
703        if grad is None:
704          continue
705        # We colocate all ops created in _apply_dense or _apply_sparse
706        # on the same device as the variable.
707        # TODO(apassos): figure out how to get the variable name here.
708        if (context.executing_eagerly() or
709            resource_variable_ops.is_resource_variable(var)
710            and not var._in_graph_mode):  # pylint: disable=protected-access
711          scope_name = ""
712        else:
713          scope_name = var.op.name
714        with ops.name_scope(
715            "update_" + scope_name,
716            skip_on_eager=False), ops.colocate_with(var):
717          update_ops.append(processor.update_op(self, grad))
718      if global_step is None:
719        apply_updates = self._finish(update_ops, name)
720      else:
721        with ops.control_dependencies([self._finish(update_ops, "update")]):
722          with ops.colocate_with(global_step):
723            if isinstance(
724                global_step, resource_variable_ops.BaseResourceVariable):
725              # TODO(apassos): the implicit read in assign_add is slow; consider
726              # making it less so.
727              apply_updates = resource_variable_ops.assign_add_variable_op(
728                  global_step.handle,
729                  ops.convert_to_tensor(1, dtype=global_step.dtype),
730                  name=name)
731            else:
732              apply_updates = state_ops.assign_add(global_step, 1, name=name)
733
734      if not context.executing_eagerly():
735        if isinstance(apply_updates, ops.Tensor):
736          apply_updates = apply_updates.op
737        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
738        if apply_updates not in train_op:
739          train_op.append(apply_updates)
740
741      return apply_updates
742
743  def _distributed_apply(self,
744                         distribution,
745                         grads_and_vars,
746                         global_step=None,
747                         name=None):
748    """A version of `apply_gradients` for cross-replica context.
749
750    This is a version of `apply_gradients()` for when you are using a
751    `DistributionStrategy` and are in a cross-replica context. If in a
752    replica context, use `apply_gradients()` as normal.
753
754    Args:
755      distribution: A `DistributionStrategy` object.
756      grads_and_vars: List of (gradient, variable) pairs as returned by
757        `compute_gradients()`, and then aggregated across replicas.
758      global_step: Optional (mirrored) `Variable` to increment by one
759        after the variables have been updated.
760      name: Optional name for the returned operation.  Default to the
761        name passed to the `Optimizer` constructor.
762
763    Returns:
764      An `Operation` that applies the specified gradients across all
765      replicas. If `global_step` was not None, that operation also
766      increments `global_step`.
767    """
768    reduced_grads = distribution.extended.batch_reduce_to(
769        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
770    var_list = [v for _, v in grads_and_vars]
771    grads_and_vars = zip(reduced_grads, var_list)
772
773    # Note that this is called in a cross-replica context.
774    with ops.init_scope():
775      self._create_slots(var_list)
776
777    def update(v, g):
778      """Apply gradients to a replica variable."""
779      assert v is not None
780
781      try:
782        # Convert the grad to Tensor or IndexedSlices if necessary.
783        g = ops.convert_to_tensor_or_indexed_slices(g)
784      except TypeError:
785        raise TypeError("Gradient must be convertible to a Tensor"
786                        " or IndexedSlices, or None: %s" % g)
787      if not isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices)):
788        raise TypeError(
789            "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
790      p = _get_processor(v)
791
792      if context.executing_eagerly() or (
793          resource_variable_ops.is_resource_variable(v) and
794          not v._in_graph_mode):  # pylint: disable=protected-access
795        scope_name = v.name.split(":")[0]
796      else:
797        scope_name = v.op.name
798
799      # device_policy is set because non-mirrored tensors will be read in
800      # `update_op`. `_resource_apply_dense`, `lr_t`, `beta1_t` and `beta2_t`
801      # are examples.
802      with ops.name_scope("update_" + scope_name):
803        return p.update_op(self, g)
804
805    with ops.name_scope(name, self._name) as name:
806      self._prepare()
807
808      update_ops = [
809          op
810          for grad, var in grads_and_vars
811          for op in distribution.extended.update(
812              var, update, args=(grad,), group=False)
813      ]
814
815      def finish(self, update_ops):
816        return self._finish(update_ops, "update")
817
818      non_slot_devices = distribution.extended.non_slot_devices(var_list)
819      finish_updates = distribution.extended.update_non_slot(
820          non_slot_devices, finish, args=(self, update_ops), group=False)
821      if global_step is None:
822        apply_updates = distribution.group(finish_updates, name=name)
823      else:
824        with ops.control_dependencies(finish_updates):
825          apply_updates = distribution.extended.update(
826              global_step, state_ops.assign_add, args=(1,),
827              kwargs={"name": name})
828
829      if not context.executing_eagerly():
830        if isinstance(apply_updates, ops.Tensor):
831          apply_updates = apply_updates.op
832        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
833        if apply_updates not in train_op:
834          train_op.append(apply_updates)
835
836      return apply_updates
837
838  def get_slot(self, var, name):
839    """Return a slot named `name` created for `var` by the Optimizer.
840
841    Some `Optimizer` subclasses use additional variables.  For example
842    `Momentum` and `Adagrad` use variables to accumulate updates.  This method
843    gives access to these `Variable` objects if for some reason you need them.
844
845    Use `get_slot_names()` to get the list of slot names created by the
846    `Optimizer`.
847
848    Args:
849      var: A variable passed to `minimize()` or `apply_gradients()`.
850      name: A string.
851
852    Returns:
853      The `Variable` for the slot if it was created, `None` otherwise.
854    """
855    named_slots = self._slots.get(name, None)
856    if not named_slots:
857      return None
858    slot = named_slots.get(_var_key(var), None)
859    if (distribute_utils.is_distributed_variable(slot) and
860        not distribute_utils.is_distributed_variable(var)):
861      # Make sure var and slot are either both DistributedVariable, or both
862      # per replica variables.
863      slot = slot._get_on_device_or_primary()  # pylint: disable=protected-access
864    return slot
865
866  def get_slot_names(self):
867    """Return a list of the names of slots created by the `Optimizer`.
868
869    See `get_slot()`.
870
871    Returns:
872      A list of strings.
873    """
874    return sorted(self._slots.keys())
875
876  def variables(self):
877    """A list of variables which encode the current state of `Optimizer`.
878
879    Includes slot variables and additional global variables created by the
880    optimizer in the current default graph.
881
882    Returns:
883      A list of variables.
884    """
885    current_graph = ops.get_default_graph()
886
887    def _from_current_graph(variable):
888      if variable._in_graph_mode:  # pylint: disable=protected-access
889        return variable.op.graph is current_graph
890      else:
891        # No variable.op in eager mode. We don't expect lots of eager graphs,
892        # but behavior should be consistent with graph mode.
893        return variable._graph_key == current_graph._graph_key  # pylint: disable=protected-access
894
895    optimizer_variables = [v for v in self._non_slot_variables()
896                           if _from_current_graph(v)]
897    for _, variable_dict in self._slots.items():
898      for _, slot_for_variable in variable_dict.items():
899        if _from_current_graph(slot_for_variable):
900          optimizer_variables.append(slot_for_variable)
901    # Sort variables by name so that the return is deterministic.
902    return sorted(optimizer_variables, key=lambda v: v.name)
903
904  def _create_non_slot_variable(self, initial_value, name, colocate_with):
905    """Add an extra variable, not associated with a slot."""
906    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
907    eager = ops.executing_eagerly_outside_functions()
908    graph = None if eager else colocate_with.graph
909
910    key = (name, graph)
911    v = self._non_slot_dict.get(key, None)
912    if v is None:
913      self._maybe_initialize_trackable()
914      distribution_strategy = distribute_ctx.get_strategy()
915      with distribution_strategy.extended.colocate_vars_with(colocate_with):
916        if eager:
917          restored_initial_value = self._preload_simple_restoration(
918              name=name)
919          if restored_initial_value is not None:
920            initial_value = restored_initial_value
921        v = variable_scope.variable(
922            initial_value, name=name, trainable=False,
923            use_resource=resource_variable_ops.is_resource_variable(
924                colocate_with))
925      # Restore this variable by name if necessary, but don't add a
926      # Trackable dependency. Optimizers return the current graph's
927      # non-slot variables from _checkpoint_dependencies explicitly rather
928      # than unconditionally adding dependencies (since there may be multiple
929      # non-slot variables with the same name in different graphs, trying to
930      # save all of them would result in errors).
931      self._handle_deferred_dependencies(name=name, trackable=v)
932      self._non_slot_dict[key] = v
933
934    return v
935
936  def _trackable_children(self,
937                          save_type=trackable.SaveType.CHECKPOINT,
938                          **kwargs):
939    """From Trackable. Gather graph-specific non-slot variables to save."""
940    current_graph_non_slot_variables = {}
941    current_graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
942    for (name, _), variable_object in sorted(self._non_slot_dict.items(),
943                                             # Avoid comparing graphs
944                                             key=lambda item: item[0][0]):
945      if variable_object._graph_key == current_graph_key:  # pylint: disable=protected-access
946        current_graph_non_slot_variables[name] = variable_object
947    current_graph_non_slot_variables.update(
948        super(Optimizer, self)._trackable_children(save_type, **kwargs))
949    return current_graph_non_slot_variables
950
951  def _lookup_dependency(self, name):
952    """From Trackable. Find a non-slot variable in the current graph."""
953    unconditional = super(Optimizer, self)._lookup_dependency(name)
954    if unconditional is not None:
955      return unconditional
956    graph = None if context.executing_eagerly() else ops.get_default_graph()
957    return self._get_non_slot_variable(name, graph=graph)
958
959  def _get_non_slot_variable(self, name, graph=None):
960    non_slot = self._non_slot_dict.get((name, graph), None)
961    if hasattr(non_slot, "_distributed_container"):
962      # This is a mirrored non-slot.  In order to enable code like `_finish`
963      # to assign to a non-slot, return the current context replica.
964      return non_slot.get()
965    else:
966      return non_slot
967
968  def _non_slot_variables(self):
969    """Additional variables created by the `Optimizer`.
970
971    Returns:
972      A list or tuple of variables.
973    """
974    return self._non_slot_dict.values()
975
976  def _assert_valid_dtypes(self, tensors):
977    """Asserts tensors are all valid types (see `_valid_dtypes`).
978
979    Args:
980      tensors: Tensors to check.
981
982    Raises:
983      ValueError: If any tensor is not a valid type.
984    """
985    valid_dtypes = self._valid_dtypes()
986    for t in tensors:
987      dtype = t.dtype.base_dtype
988      if dtype not in valid_dtypes:
989        raise ValueError(
990            "Invalid type %r for %s, expected: %s." % (
991                dtype, t.name, [v for v in valid_dtypes]))
992
993  # --------------
994  # Methods to be implemented by subclasses if they want to use the
995  # inherited implementation of apply_gradients() or compute_gradients().
996  # --------------
997  def _valid_dtypes(self):
998    """Valid types for loss, variables and gradients.
999
1000    Subclasses should override to allow other float types.
1001
1002    Returns:
1003      Valid types for loss, variables and gradients.
1004    """
1005    return set(
1006        [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64])
1007
1008  def _create_slots(self, var_list):
1009    """Create all slots needed by the variables.
1010
1011    Args:
1012      var_list: A list of `Variable` objects.
1013    """
1014    # No slots needed by default
1015    pass
1016
1017  def _prepare(self):
1018    """Create all needed tensors before applying gradients.
1019
1020    This is called with the name_scope using the "name" that
1021    users have chosen for the application of gradients.
1022    """
1023    pass
1024
1025  def _apply_dense(self, grad, var):
1026    """Add ops to apply dense gradients to `var`.
1027
1028    Args:
1029      grad: A `Tensor`.
1030      var: A `Variable` object.
1031
1032    Returns:
1033      An `Operation`.
1034    """
1035    raise NotImplementedError()
1036
1037  def _resource_apply_dense(self, grad, handle):
1038    """Add ops to apply dense gradients to the variable `handle`.
1039
1040    Args:
1041      grad: a `Tensor` representing the gradient.
1042      handle: a `Tensor` of dtype `resource` which points to the variable
1043       to be updated.
1044
1045    Returns:
1046      An `Operation` which updates the value of the variable.
1047    """
1048    raise NotImplementedError()
1049
1050  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
1051    """Add ops to apply sparse gradients to `handle`, with repeated indices.
1052
1053    Optimizers which override this method must deal with repeated indices. See
1054    the docstring of `_apply_sparse_duplicate_indices` for details. By default
1055    the correct behavior, to sum non-unique indices and their associated
1056    gradients, is enforced by first pre-processing `grad` and `indices` and
1057    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
1058    with duplicate indices may instead override this method to avoid the
1059    overhead of summing.
1060
1061    Args:
1062      grad: a `Tensor` representing the gradient for the affected indices.
1063      handle: a `Tensor` of dtype `resource` which points to the variable
1064       to be updated.
1065      indices: a `Tensor` of integral type representing the indices for
1066       which the gradient is nonzero. Indices may be repeated.
1067
1068    Returns:
1069      An `Operation` which updates the value of the variable.
1070    """
1071    summed_grad, unique_indices = _deduplicate_indexed_slices(
1072        values=grad, indices=indices)
1073    return self._resource_apply_sparse(summed_grad, handle, unique_indices)
1074
1075  def _resource_apply_sparse(self, grad, handle, indices):
1076    """Add ops to apply sparse gradients to the variable `handle`.
1077
1078    Similar to `_apply_sparse`, the `indices` argument to this method has been
1079    de-duplicated. Optimizers which deal correctly with non-unique indices may
1080    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
1081    overhead.
1082
1083    Args:
1084      grad: a `Tensor` representing the gradient for the affected indices.
1085      handle: a `Tensor` of dtype `resource` which points to the variable
1086       to be updated.
1087      indices: a `Tensor` of integral type representing the indices for
1088       which the gradient is nonzero. Indices are unique.
1089
1090    Returns:
1091      An `Operation` which updates the value of the variable.
1092    """
1093    raise NotImplementedError()
1094
1095  def _apply_sparse_duplicate_indices(self, grad, var):
1096    """Add ops to apply sparse gradients to `var`, with repeated sparse indices.
1097
1098    Optimizers which override this method must deal with IndexedSlices objects
1099    such as the following:
1100
1101      IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])
1102
1103    The correct interpretation is:
1104
1105      IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])
1106
1107    Many optimizers deal incorrectly with repeated indices when updating based
1108    on sparse gradients (e.g. summing squares rather than squaring the sum, or
1109    applying momentum terms multiple times). Adding first is always the correct
1110    behavior, so this is enforced here by reconstructing the IndexedSlices to
1111    have only unique indices, then calling _apply_sparse.
1112
1113    Optimizers which deal correctly with repeated indices may instead override
1114    this method to avoid the overhead of summing indices.
1115
1116    Args:
1117      grad: `IndexedSlices`.
1118      var: A `Variable` object.
1119
1120    Returns:
1121      An `Operation`.
1122    """
1123    summed_values, unique_indices = _deduplicate_indexed_slices(
1124        values=grad.values, indices=grad.indices)
1125    gradient_no_duplicate_indices = indexed_slices.IndexedSlices(
1126        indices=unique_indices,
1127        values=summed_values,
1128        dense_shape=grad.dense_shape)
1129    return self._apply_sparse(gradient_no_duplicate_indices, var)
1130
1131  def _apply_sparse(self, grad, var):
1132    """Add ops to apply sparse gradients to `var`.
1133
1134    The IndexedSlices object passed to `grad` in this function is by default
1135    pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
1136    indices (see its docstring for details). Optimizers which can tolerate or
1137    have correct special cases for duplicate sparse indices may override
1138    `_apply_sparse_duplicate_indices` instead of this function, avoiding that
1139    overhead.
1140
1141    Args:
1142      grad: `IndexedSlices`, with no repeated indices.
1143      var: A `Variable` object.
1144
1145    Returns:
1146      An `Operation`.
1147    """
1148    raise NotImplementedError()
1149
1150  def _finish(self, update_ops, name_scope):
1151    """Do what is needed to finish the update.
1152
1153    This is called with the `name_scope` using the "name" that
1154    users have chosen for the application of gradients.
1155
1156    Args:
1157      update_ops: List of `Operation` objects to update variables.  This list
1158        contains the values returned by the `_apply_dense()` and
1159        `_apply_sparse()` calls.
1160      name_scope: String.  Name to use for the returned operation.
1161
1162    Returns:
1163      The operation to apply updates.
1164    """
1165    return control_flow_ops.group(*update_ops, name=name_scope)
1166
1167  # --------------
1168  # Utility methods for subclasses.
1169  # --------------
1170
1171  def _slot_dict(self, slot_name):
1172    """Returns a dict for caching slots created under the given name.
1173
1174    Args:
1175      slot_name: Name for the slot.
1176
1177    Returns:
1178      A dict that maps primary `Variable` objects to the slot created
1179      for that variable, under the given slot name.
1180    """
1181    named_slots = self._slots.get(slot_name, None)
1182    if named_slots is None:
1183      named_slots = {}
1184      self._slots[slot_name] = named_slots
1185    return named_slots
1186
1187  def _get_or_make_slot(self, var, val, slot_name, op_name):
1188    """Find or create a slot for a variable.
1189
1190    Args:
1191      var: A `Variable` object.
1192      val: A `Tensor`.  The initial value of the slot.
1193      slot_name: Name for the slot.
1194      op_name: Name to use when scoping the Variable that
1195        needs to be created for the slot.
1196
1197    Returns:
1198      A `Variable` object.
1199    """
1200    named_slots = self._slot_dict(slot_name)
1201    if _var_key(var) not in named_slots:
1202      new_slot_variable = slot_creator.create_slot(
1203          var, val, op_name, copy_xla_sharding=True)
1204      self._restore_slot_variable(
1205          slot_name=slot_name, variable=var,
1206          slot_variable=new_slot_variable)
1207      named_slots[_var_key(var)] = new_slot_variable
1208    return named_slots[_var_key(var)]
1209
1210  def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
1211                                         slot_name, op_name):
1212    """Find or create a slot for a variable, using an Initializer.
1213
1214    Args:
1215      var: A `Variable` object.
1216      initializer: An `Initializer`.  The initial value of the slot.
1217      shape: Shape of the initial value of the slot.
1218      dtype: Type of the value of the slot.
1219      slot_name: Name for the slot.
1220      op_name: Name to use when scoping the Variable that
1221        needs to be created for the slot.
1222
1223    Returns:
1224      A `Variable` object.
1225    """
1226    named_slots = self._slot_dict(slot_name)
1227    if _var_key(var) not in named_slots:
1228      new_slot_variable = slot_creator.create_slot_with_initializer(
1229          var, initializer, shape, dtype, op_name, copy_xla_sharding=True)
1230      self._restore_slot_variable(
1231          slot_name=slot_name, variable=var,
1232          slot_variable=new_slot_variable)
1233      named_slots[_var_key(var)] = new_slot_variable
1234    return named_slots[_var_key(var)]
1235
1236  def _zeros_slot(self, var, slot_name, op_name):
1237    """Find or create a slot initialized with 0.0.
1238
1239    Args:
1240      var: A `Variable` object.
1241      slot_name: Name for the slot.
1242      op_name: Name to use when scoping the Variable that
1243        needs to be created for the slot.
1244
1245    Returns:
1246      A `Variable` object.
1247    """
1248    named_slots = self._slot_dict(slot_name)
1249    if _var_key(var) not in named_slots:
1250      new_slot_variable = slot_creator.create_zeros_slot(
1251          var, op_name, copy_xla_sharding=True)
1252      self._restore_slot_variable(
1253          slot_name=slot_name, variable=var,
1254          slot_variable=new_slot_variable)
1255      named_slots[_var_key(var)] = new_slot_variable
1256    return named_slots[_var_key(var)]
1257
1258  # --------------
1259  # For implementing the Trackable interface.
1260  # --------------
1261
1262  def _restore_slot_variable(self, slot_name, variable, slot_variable):
1263    """Restore a newly created slot variable's value."""
1264    variable_key = _var_key(variable)
1265    deferred_restorations = self._deferred_slot_restorations.get(
1266        slot_name, {}).pop(variable_key, [])
1267    # Iterate over restores, highest restore UID first to minimize the number
1268    # of assignments.
1269    deferred_restorations.sort(key=lambda position: position.restore_uid,
1270                               reverse=True)
1271    for checkpoint_position in deferred_restorations:
1272      checkpoint_position.restore(slot_variable)
1273
1274  def _create_or_restore_slot_variable(
1275      self, slot_variable_position, slot_name, variable):
1276    """Restore a slot variable's value, possibly creating it.
1277
1278    Called when a variable which has an associated slot variable is created or
1279    restored. When executing eagerly, we create the slot variable with a
1280    restoring initializer.
1281
1282    No new variables are created when graph building. Instead,
1283    _restore_slot_variable catches these after normal creation and adds restore
1284    ops to the graph. This method is nonetheless important when graph building
1285    for the case when a slot variable has already been created but `variable`
1286    has just been added to a dependency graph (causing us to realize that the
1287    slot variable needs to be restored).
1288
1289    Args:
1290      slot_variable_position: A `trackable._CheckpointPosition` object
1291        indicating the slot variable `Trackable` object to be restored.
1292      slot_name: The name of this `Optimizer`'s slot to restore into.
1293      variable: The variable object this slot is being created for.
1294    """
1295    named_slots = self._slot_dict(slot_name)
1296    variable_key = _var_key(variable)
1297    slot_variable = named_slots.get(variable_key, None)
1298    if (slot_variable is None and context.executing_eagerly() and
1299        slot_variable_position.is_simple_variable()
1300        # Defer slot variable creation if there is an active variable creator
1301        # scope. Generally we'd like to eagerly create/restore slot variables
1302        # when possible, but this may mean that scopes intended to catch
1303        # `variable` also catch its eagerly created slot variable
1304        # unintentionally (specifically make_template would add a dependency on
1305        # a slot variable if not for this case). Deferring is mostly harmless
1306        # (aside from double initialization), and makes variable creator scopes
1307        # behave the same way they do when graph building.
1308        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
1309      initializer = trackable.CheckpointInitialValueCallable(
1310          checkpoint_position=slot_variable_position)
1311      # CheckpointInitialValueCallable will ignore the shape and dtype
1312      # parameters but they must be passed.
1313      slot_variable = self._get_or_make_slot_with_initializer(
1314          var=variable,
1315          initializer=initializer,
1316          shape=variable.shape,
1317          dtype=variable.dtype,
1318          slot_name=slot_name,
1319          op_name=self._name)
1320      # Slot variables are not owned by any one object (because we don't want to
1321      # save the slot variable if the optimizer is saved without the non-slot
1322      # variable, or if the non-slot variable is saved without the optimizer;
1323      # it's a dependency hypergraph with edges of the form (optimizer, non-slot
1324      # variable, variable)). So we don't _track_ slot variables anywhere, and
1325      # instead special-case this dependency and otherwise pretend it's a normal
1326      # graph.
1327    if slot_variable is not None:
1328      # If we've either made this slot variable, or if we've pulled out an
1329      # existing slot variable, we should restore it.
1330      slot_variable_position.restore(slot_variable)
1331    else:
1332      # We didn't make the slot variable. Defer restoring until it gets created
1333      # normally. We keep a list rather than the one with the highest restore
1334      # UID in case slot variables have their own dependencies, in which case
1335      # those could differ between restores.
1336      self._deferred_slot_restorations.setdefault(
1337          slot_name, {}).setdefault(variable_key, []).append(
1338              slot_variable_position)
1339
1340  def _call_if_callable(self, param):
1341    """Call the function if param is callable."""
1342    return param() if callable(param) else param
1343