# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Base class for optimizers."""
# pylint: disable=g-bad-name

import abc

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.trackable import base as trackable
from tensorflow.python.training import slot_creator
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


def get_filtered_grad_fn(grad_fn):
  # `distributed_context.join()` requires that its arguments are parallel
  # across threads, and in particular that `grads_and_vars` has the same
  # variables in the same order.

  # When computing gradients in eager mode with multiple threads, you
  # can get extra variables with a gradient of `None`. This happens when
  # those variables are accessed in another thread during the gradient
  # computation. To get a consistent set of variables, we filter out
  # those with `None` gradients.
  def filtered_grad_fn(*args, **kwargs):
    return [(g, v) for g, v in grad_fn(*args, **kwargs) if g is not None]

  return filtered_grad_fn


def _deduplicate_indexed_slices(values, indices):
  """Sums `values` associated with any non-unique `indices`.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor`, indexing into the first
      dimension of `values` (as in an IndexedSlices object).

  Returns:
    A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
    de-duplicated version of `indices` and `summed_values` contains the sum of
    `values` slices associated with each unique index.
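
  For example (illustrative): `values=[[1.], [2.], [3.]]` with
  `indices=[0, 0, 2]` produces `summed_values=[[3.], [3.]]` and
  `unique_indices=[0, 2]`.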
71 """ 72 unique_indices, new_index_positions = array_ops.unique(indices) 73 summed_values = math_ops.unsorted_segment_sum( 74 values, new_index_positions, 75 array_ops.shape(unique_indices)[0]) 76 return (summed_values, unique_indices) 77 78 79def _var_key(var): 80 """Returns slot key for `var`.""" 81 # pylint: disable=protected-access 82 if hasattr(var, "_distributed_container"): 83 var = var._distributed_container() 84 if (distribute_utils.is_distributed_variable(var) and 85 not ops.executing_eagerly_outside_functions()): 86 return (var.graph, var._shared_name) 87 if hasattr(var, "op"): 88 return (var.op.graph, var.op.name) 89 return var._unique_id 90 # pylint: enable=protected-access 91 92 93class _OptimizableVariable(metaclass=abc.ABCMeta): 94 """Interface for abstracting over variables in the optimizers.""" 95 96 @abc.abstractmethod 97 def target(self): 98 """Returns the optimization target for this variable.""" 99 raise NotImplementedError("Calling an abstract method.") 100 101 @abc.abstractmethod 102 def update_op(self, optimizer, g): 103 """Returns the update ops for updating the variable.""" 104 raise NotImplementedError("Calling an abstract method.") 105 106 107class _RefVariableProcessor(_OptimizableVariable): 108 """Processor for Variable.""" 109 110 def __init__(self, v): 111 self._v = v 112 113 def __str__(self): 114 return "<_RefVariableProcessor(%s)>" % self._v 115 116 def target(self): 117 return self._v._ref() # pylint: disable=protected-access 118 119 def update_op(self, optimizer, g): 120 if isinstance(g, ops.Tensor): 121 update_op = optimizer._apply_dense(g, self._v) # pylint: disable=protected-access 122 if self._v.constraint is not None: 123 with ops.control_dependencies([update_op]): 124 return self._v.assign(self._v.constraint(self._v)) 125 else: 126 return update_op 127 else: 128 assert isinstance(g, indexed_slices.IndexedSlices), ( 129 "Gradient ", g, " is neither a tensor nor IndexedSlices.") 130 if self._v.constraint is not None: 131 raise RuntimeError( 132 "Cannot use a constraint function on a sparse variable.") 133 # pylint: disable=protected-access 134 return optimizer._apply_sparse_duplicate_indices(g, self._v) 135 136 137class _DenseReadResourceVariableProcessor(_OptimizableVariable): 138 """Processor for dense ResourceVariables.""" 139 140 def __init__(self, v): 141 self._v = v 142 143 def target(self): 144 return self._v 145 146 def update_op(self, optimizer, g): 147 # pylint: disable=protected-access 148 update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0]) 149 if self._v.constraint is not None: 150 with ops.control_dependencies([update_op]): 151 return self._v.assign(self._v.constraint(self._v)) 152 else: 153 return update_op 154 155 156class _DenseResourceVariableProcessor(_OptimizableVariable): 157 """Processor for dense ResourceVariables.""" 158 159 def __init__(self, v): 160 self._v = v 161 162 def target(self): 163 return self._v 164 165 def update_op(self, optimizer, g): 166 # pylint: disable=protected-access 167 if isinstance(g, indexed_slices.IndexedSlices): 168 if self._v.constraint is not None: 169 raise RuntimeError( 170 "Cannot use a constraint function on a sparse variable.") 171 return optimizer._resource_apply_sparse_duplicate_indices( 172 g.values, self._v, g.indices) 173 update_op = optimizer._resource_apply_dense(g, self._v) 174 if self._v.constraint is not None: 175 with ops.control_dependencies([update_op]): 176 return self._v.assign(self._v.constraint(self._v)) 177 else: 178 return update_op 179 180 181class 


class _TensorProcessor(_OptimizableVariable):
  """Processor for ordinary Tensors.

  Even though a Tensor can't really be updated, sometimes it is useful to
  compute the gradients with respect to a Tensor using the optimizer. Updating
  the Tensor is, of course, unsupported.
  """

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g):
    raise NotImplementedError("Trying to update a Tensor ", self._v)


def _get_processor(v):
  """The processor of v."""
  if context.executing_eagerly():
    if isinstance(v, ops.Tensor):
      return _TensorProcessor(v)
    else:
      return _DenseResourceVariableProcessor(v)
  if resource_variable_ops.is_resource_variable(v) and not v._in_graph_mode:  # pylint: disable=protected-access
    # True if and only if `v` was initialized eagerly.
    return _DenseResourceVariableProcessor(v)
  if v.op.type == "VarHandleOp":
    return _DenseResourceVariableProcessor(v)
  if isinstance(v, variables.Variable):
    return _RefVariableProcessor(v)
  if isinstance(v, ops.Tensor):
    return _TensorProcessor(v)
  raise NotImplementedError("Trying to optimize unsupported type ", v)


@tf_export(v1=["train.Optimizer"])
class Optimizer(
    # Optimizers inherit from Trackable rather than AutoTrackable
    # since they do most of their dependency management themselves (slot
    # variables are special-cased, and non-slot variables are keyed to graphs).
    trackable.Trackable):
  """Base class for optimizers.

  This class defines the API to add Ops to train a model. You never use this
  class directly, but instead instantiate one of its subclasses such as
  `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.

  ### Usage

  ```python
  # Create an optimizer with the desired parameters.
  opt = GradientDescentOptimizer(learning_rate=0.1)
  # Add Ops to the graph to minimize a cost by updating a list of variables.
  # "cost" is a Tensor, and the list of variables contains tf.Variable
  # objects.
  opt_op = opt.minimize(cost, var_list=<list of variables>)
  ```

  In the training program you will just have to run the returned Op.

  ```python
  # Execute opt_op to do one step of training:
  opt_op.run()
  ```

  ### Processing gradients before applying them.

  Calling `minimize()` takes care of both computing the gradients and
  applying them to the variables. If you want to process the gradients
  before applying them you can instead use the optimizer in three steps:

  1. Compute the gradients with `compute_gradients()`.
  2. Process the gradients as you wish.
  3. Apply the processed gradients with `apply_gradients()`.

  Example:

  ```python
  # Create an optimizer.
  opt = GradientDescentOptimizer(learning_rate=0.1)

  # Compute the gradients for a list of variables.
  grads_and_vars = opt.compute_gradients(loss, <list of variables>)

  # grads_and_vars is a list of tuples (gradient, variable). Do whatever you
  # need to the 'gradient' part, for example cap them, etc.
  capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
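  # (`MyCapper` above stands for any user-defined transform; for instance, it
  # could be `lambda g: tf.clip_by_norm(g, 5.0)`.)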

  # Ask the optimizer to apply the capped gradients.
  opt.apply_gradients(capped_grads_and_vars)
  ```

  ### Gating Gradients

  Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
  argument that controls the degree of parallelism during the application of
  the gradients.

  The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.

  <b>`GATE_NONE`</b>: Compute and apply gradients in parallel. This provides
  the maximum parallelism in execution, at the cost of some non-reproducibility
  in the results. For example, the two gradients of `matmul` depend on the
  input values: with `GATE_NONE` one of the gradients could be applied to one
  of the inputs _before_ the other gradient is computed, resulting in
  non-reproducible results.

  <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
  they are used. This prevents race conditions for Ops that generate gradients
  for multiple inputs where the gradients depend on the inputs.

  <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
  before any one of them is used. This provides the least parallelism but can
  be useful if you want to process all gradients before applying any of them.

  ### Slots

  Some optimizer subclasses, such as `MomentumOptimizer` and
  `AdagradOptimizer`, allocate and manage additional variables associated with
  the variables to train. These are called <i>Slots</i>. Slots have names and
  you can ask the optimizer for the names of the slots that it uses. Once you
  have a slot name you can ask the optimizer for the variable it created to
  hold the slot value.

  This can be useful if you want to log debugging information about a training
  algorithm, report stats about the slots, etc.
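
  For example (a minimal sketch; `cost` and the variables are as in the
  examples above, and `"momentum"` is the slot name used by
  `MomentumOptimizer`):

  ```python
  opt = MomentumOptimizer(learning_rate=0.1, momentum=0.9)
  opt_op = opt.minimize(cost, var_list=<list of variables>)
  print(opt.get_slot_names())            # e.g. ['momentum']
  momentum_slot = opt.get_slot(<a variable>, "momentum")
  ```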

  @compatibility(TF2)
  `tf.compat.v1.train.Optimizer` can be used in eager mode and `tf.function`,
  but it is not recommended. Please use the subclasses of
  `tf.keras.optimizers.Optimizer` instead in TF2. Please see [Basic training
  loops](https://www.tensorflow.org/guide/basic_training_loops) or
  [Writing a training loop from scratch]
  (https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch)
  for examples.

  If your TF1 code contains a `tf.compat.v1.train.Optimizer` symbol, whether it
  is used with or without a `tf.estimator.Estimator`, you cannot simply replace
  it with the corresponding `tf.keras.optimizers.Optimizer`. To migrate to TF2,
  it is advised that the whole training program used with `Estimator` be
  migrated to Keras `Model.fit`-based training or to TF2 custom training loops.

  #### Structural Mapping to Native TF2

  Before:

  ```python
  sgd_op = tf.compat.v1.train.GradientDescentOptimizer(3.0)
  opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
  opt_op.run(session=session)
  ```

  After:

  ```python
  sgd = tf.keras.optimizers.SGD(3.0)
  sgd.minimize(cost_fn, [var0, var1])
  ```

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `use_locking`         | Not supported   | -                          |
  | `name`                | `name`          | -                          |

  #### Before & After Usage Example

  Before:

  >>> g = tf.compat.v1.Graph()
  >>> with g.as_default():
  ...   var0 = tf.compat.v1.Variable([1.0, 2.0])
  ...   var1 = tf.compat.v1.Variable([3.0, 4.0])
  ...   cost = 5 * var0 + 3 * var1
  ...   global_step = tf.compat.v1.Variable(
  ...       tf.compat.v1.zeros([], tf.compat.v1.int64), name='global_step')
  ...   init_op = tf.compat.v1.initialize_all_variables()
  ...   sgd_op = tf.compat.v1.train.GradientDescentOptimizer(3.0)
  ...   opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
  >>> session = tf.compat.v1.Session(graph=g)
  >>> session.run(init_op)
  >>> opt_op.run(session=session)
  >>> print(session.run(var0))
  [-14. -13.]


  After:

  >>> var0 = tf.Variable([1.0, 2.0])
  >>> var1 = tf.Variable([3.0, 4.0])
  >>> cost_fn = lambda: 5 * var0 + 3 * var1
  >>> sgd = tf.keras.optimizers.SGD(3.0)
  >>> sgd.minimize(cost_fn, [var0, var1])
  >>> print(var0.numpy())
  [-14. -13.]

  @end_compatibility

  """

  # Values for gate_gradients.
  GATE_NONE = 0
  GATE_OP = 1
  GATE_GRAPH = 2

  def __init__(self, use_locking, name):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.

    Args:
      use_locking: Bool. If True, use locks to prevent concurrent updates
        to variables.
      name: A non-empty string. The name to use for accumulators created
        for the optimizer.

    Raises:
      ValueError: If name is malformed.
    """
    if not name:
      raise ValueError("Must specify the optimizer name")
    self._use_locking = use_locking
    self._name = name
    # Dictionary of slots.
    #  {slot_name :
    #      {_var_key(variable_to_train): slot_for_the_variable, ... },
    #   ... }
    self._slots = {}
    self._non_slot_dict = {}
    # For implementing Trackable. Stores information about how to restore
    # slot variables which have not yet been created
    # (trackable._CheckpointPosition objects).
    #  {slot_name :
    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
    #   ... }
    self._deferred_slot_restorations = {}

    # TODO(isaprykin): When using a DistributionStrategy, and when an
    # optimizer is created in each replica, it might be dangerous to
    # rely on some Optimizer methods. When such methods are called on a
    # per-replica optimizer, an exception needs to be thrown. We do
    # allow creation of per-replica optimizers, however, because the
    # compute_gradients()->apply_gradients() sequence is safe.

  def get_name(self):
    return self._name

  def minimize(self, loss, global_step=None, var_list=None,
               gate_gradients=GATE_OP, aggregation_method=None,
               colocate_gradients_with_ops=False, name=None,
               grad_loss=None):
    """Add operations to minimize `loss` by updating `var_list`.

    This method simply combines calls to `compute_gradients()` and
    `apply_gradients()`. If you want to process the gradients before applying
    them, call `compute_gradients()` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: A `Tensor` containing the value to minimize.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      var_list: Optional list or tuple of `Variable` objects to update to
        minimize `loss`. Defaults to the list of variables collected in
        the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      name: Optional name for the returned operation.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      An Operation that updates the variables in `var_list`. If `global_step`
      was not `None`, that operation also increments `global_step`.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    @compatibility(eager)
    When eager execution is enabled, `loss` should be a Python function that
    takes no arguments and computes the value to be minimized. Minimization
    (and gradient computation) is done with respect to the elements of
    `var_list` if not None, else with respect to any trainable variables
    created during the execution of the `loss` function. `gate_gradients`,
    `aggregation_method`, `colocate_gradients_with_ops` and `grad_loss` are
    ignored when eager execution is enabled.
    @end_compatibility
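
    For example, when eager execution is enabled (a minimal sketch; `var0` is
    assumed to be an existing `tf.Variable`):

    ```python
    opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.1)
    opt.minimize(lambda: tf.reduce_sum(var0 ** 2), var_list=[var0])
    ```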
    """
    grads_and_vars = self.compute_gradients(
        loss, var_list=var_list, gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)

    vars_with_grad = [v for g, v in grads_and_vars if g is not None]
    if not vars_with_grad:
      raise ValueError(
          "No gradients provided for any variable, check your graph for ops"
          " that do not support gradients, between variables %s and loss %s." %
          ([str(v) for _, v in grads_and_vars], loss))

    return self.apply_gradients(grads_and_vars, global_step=global_step,
                                name=name)

  def compute_gradients(self, loss, var_list=None,
                        gate_gradients=GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`. It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable". Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    @compatibility(TF2)
    `tf.keras.optimizers.Optimizer` in TF2 does not provide a
    `compute_gradients` method, and you should use a `tf.GradientTape` to
    obtain the gradients:

    ```python
    @tf.function
    def train_step(inputs):
      batch_data, labels = inputs
      with tf.GradientTape() as tape:
        predictions = model(batch_data, training=True)
        loss = tf.keras.losses.CategoricalCrossentropy(
            reduction=tf.keras.losses.Reduction.NONE)(labels, predictions)
      gradients = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    ```

    Args:
      loss: A Tensor containing the value to minimize or a callable taking
        no arguments which returns the value to minimize. When eager execution
        is enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`. Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable` objects.
      ValueError: If some arguments are invalid.
      RuntimeError: If called with eager execution enabled and `loss` is
        not callable.

    @compatibility(eager)
    When eager execution is enabled, `gate_gradients`, `aggregation_method`,
    and `colocate_gradients_with_ops` are ignored.
    @end_compatibility
    """
    if callable(loss):
      with backprop.GradientTape() as tape:
        if var_list is not None:
          tape.watch(var_list)
        loss_value = loss()

        # Scale loss if using a "mean" loss reduction and multiple replicas.
        # Have to be careful to call distribute_lib.get_loss_reduction()
        # *after* loss() is evaluated, so we know what loss reduction it uses.
        # TODO(josh11b): Test that we handle weight decay in a reasonable way.
        loss_value = self._scale_loss(loss_value)

      if var_list is None:
        var_list = tape.watched_variables()
      # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
      # to be executed.
      with ops.control_dependencies([loss_value]):
        grads = tape.gradient(loss_value, var_list, grad_loss)
      return list(zip(grads, var_list))

    # Non-callable/Tensor loss case
    if context.executing_eagerly():
      raise RuntimeError(
          "`loss` passed to Optimizer.compute_gradients should "
          "be a function when eager execution is enabled.")

    # Scale loss if using a "mean" loss reduction and multiple replicas.
    loss = self._scale_loss(loss)

    if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                              Optimizer.GATE_GRAPH]:
      raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                       "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
                       gate_gradients)
    self._assert_valid_dtypes([loss])
    if grad_loss is not None:
      self._assert_valid_dtypes([grad_loss])
    if var_list is None:
      var_list = (
          variables.trainable_variables() +
          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    else:
      var_list = nest.flatten(var_list)
    # pylint: disable=protected-access
    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
    # pylint: enable=protected-access
    processors = [_get_processor(v) for v in var_list]
    if not var_list:
      raise ValueError("No variables to optimize.")
    var_refs = [p.target() for p in processors]
    grads = gradients.gradients(
        loss, var_refs, grad_ys=grad_loss,
        gate_gradients=(gate_gradients == Optimizer.GATE_OP),
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops)
    if gate_gradients == Optimizer.GATE_GRAPH:
      grads = control_flow_ops.tuple(grads)
    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes(
        [v for g, v in grads_and_vars
         if g is not None and v.dtype != dtypes.resource])
    return grads_and_vars

  @staticmethod
  def _scale_loss(loss_value):
    ops.get_default_graph()._is_loss_scaled_by_optimizer = False  # pylint: disable=protected-access
    if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
      num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
      if num_replicas > 1:
        loss_value *= (1. / num_replicas)
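        # For example, with 8 replicas and a MEAN loss reduction, each replica
        # now contributes loss / 8, so the cross-replica SUM of the resulting
        # gradients matches the gradient of the mean loss.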
        ops.get_default_graph()._is_loss_scaled_by_optimizer = True  # pylint: disable=protected-access
    return loss_value

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.
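
    For example (a minimal sketch, reusing the `grads_and_vars` returned by
    `compute_gradients()`):

    ```python
    train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)
    # Running `train_op` applies the gradients and increments `global_step`.
    ```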

    @compatibility(TF2)
    #### How to Map Arguments

    | TF1 Arg Name          | TF2 Arg Name    | Note                       |
    | :-------------------- | :-------------- | :------------------------- |
    | `grads_and_vars`      | `grads_and_vars`| -                          |
    | `global_step`         | Not supported.  | Use `optimizer.iterations` |
    | `name`                | `name`          | -                          |

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Defaults to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
      RuntimeError: If you should use `_distributed_apply()` instead.
    """
    # This is a default implementation of apply_gradients() that can be shared
    # by most optimizers. It relies on the subclass implementing the following
    # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

    # TODO(isaprykin): Get rid of `has_strategy()` check by
    # always calling _distributed_apply(), using the default distribution
    # as needed.
    if distribute_ctx.has_strategy():
      # Handle DistributionStrategy case.
      if distribute_ctx.in_cross_replica_context():
        raise RuntimeError("Use `_distributed_apply()` instead of "
                           "`apply_gradients()` in a cross-replica context.")

      grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
      return distribute_ctx.get_replica_context().merge_call(
          self._distributed_apply, args=(grads_and_vars, global_step, name))

    # No DistributionStrategy case.
    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    if not grads_and_vars:
      raise ValueError("No variables provided.")
    converted_grads_and_vars = []
    for g, v in grads_and_vars:
      if g is not None:
        try:
          # Convert the grad to Tensor or IndexedSlices if necessary.
          g = ops.convert_to_tensor_or_indexed_slices(g)
        except TypeError:
          raise TypeError(
              "Gradient must be convertible to a Tensor"
              " or IndexedSlices, or None: %s" % g)
        if not isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices)):
          raise TypeError(
              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)
      converted_grads_and_vars.append((g, v, p))

    converted_grads_and_vars = tuple(converted_grads_and_vars)
    var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
    if not var_list:
      raise ValueError("No gradients provided for any variable: %s." %
                       ([str(v) for _, v, _ in converted_grads_and_vars],))
    with ops.init_scope():
      self._create_slots(var_list)
    update_ops = []
    with ops.name_scope(name, self._name, skip_on_eager=False) as name:
      self._prepare()
      for grad, var, processor in converted_grads_and_vars:
        if grad is None:
          continue
        # We colocate all ops created in _apply_dense or _apply_sparse
        # on the same device as the variable.
        # TODO(apassos): figure out how to get the variable name here.
        if (context.executing_eagerly() or
            resource_variable_ops.is_resource_variable(var)
            and not var._in_graph_mode):  # pylint: disable=protected-access
          scope_name = ""
        else:
          scope_name = var.op.name
        with ops.name_scope(
            "update_" + scope_name,
            skip_on_eager=False), ops.colocate_with(var):
          update_ops.append(processor.update_op(self, grad))
      if global_step is None:
        apply_updates = self._finish(update_ops, name)
      else:
        with ops.control_dependencies([self._finish(update_ops, "update")]):
          with ops.colocate_with(global_step):
            if isinstance(
                global_step, resource_variable_ops.BaseResourceVariable):
              # TODO(apassos): the implicit read in assign_add is slow; consider
              # making it less so.
              apply_updates = resource_variable_ops.assign_add_variable_op(
                  global_step.handle,
                  ops.convert_to_tensor(1, dtype=global_step.dtype),
                  name=name)
            else:
              apply_updates = state_ops.assign_add(global_step, 1, name=name)

      if not context.executing_eagerly():
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates

  def _distributed_apply(self,
                         distribution,
                         grads_and_vars,
                         global_step=None,
                         name=None):
    """A version of `apply_gradients` for cross-replica context.

    This is a version of `apply_gradients()` for when you are using a
    `DistributionStrategy` and are in a cross-replica context. If in a
    replica context, use `apply_gradients()` as normal.

    Args:
      distribution: A `DistributionStrategy` object.
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`, and then aggregated across replicas.
      global_step: Optional (mirrored) `Variable` to increment by one
        after the variables have been updated.
      name: Optional name for the returned operation. Defaults to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients across all
      replicas. If `global_step` was not None, that operation also
      increments `global_step`.
    """
    reduced_grads = distribution.extended.batch_reduce_to(
        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
    var_list = [v for _, v in grads_and_vars]
    grads_and_vars = zip(reduced_grads, var_list)

    # Note that this is called in a cross-replica context.
    with ops.init_scope():
      self._create_slots(var_list)

    def update(v, g):
      """Apply gradients to a replica variable."""
      assert v is not None

      try:
        # Convert the grad to Tensor or IndexedSlices if necessary.
        g = ops.convert_to_tensor_or_indexed_slices(g)
      except TypeError:
        raise TypeError("Gradient must be convertible to a Tensor"
                        " or IndexedSlices, or None: %s" % g)
      if not isinstance(g, (ops.Tensor, indexed_slices.IndexedSlices)):
        raise TypeError(
            "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)

      if context.executing_eagerly() or (
          resource_variable_ops.is_resource_variable(v) and
          not v._in_graph_mode):  # pylint: disable=protected-access
        scope_name = v.name.split(":")[0]
      else:
        scope_name = v.op.name

      # device_policy is set because non-mirrored tensors will be read in
      # `update_op`. `_resource_apply_dense`, `lr_t`, `beta1_t` and `beta2_t`
      # are examples.
      with ops.name_scope("update_" + scope_name):
        return p.update_op(self, g)

    with ops.name_scope(name, self._name) as name:
      self._prepare()

      update_ops = [
          op
          for grad, var in grads_and_vars
          for op in distribution.extended.update(
              var, update, args=(grad,), group=False)
      ]

      def finish(self, update_ops):
        return self._finish(update_ops, "update")

      non_slot_devices = distribution.extended.non_slot_devices(var_list)
      finish_updates = distribution.extended.update_non_slot(
          non_slot_devices, finish, args=(self, update_ops), group=False)
      if global_step is None:
        apply_updates = distribution.group(finish_updates, name=name)
      else:
        with ops.control_dependencies(finish_updates):
          apply_updates = distribution.extended.update(
              global_step, state_ops.assign_add, args=(1,),
              kwargs={"name": name})

      if not context.executing_eagerly():
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates

  def get_slot(self, var, name):
    """Return a slot named `name` created for `var` by the Optimizer.

    Some `Optimizer` subclasses use additional variables. For example,
    `Momentum` and `Adagrad` use variables to accumulate updates. This method
    gives access to these `Variable` objects if for some reason you need them.

    Use `get_slot_names()` to get the list of slot names created by the
    `Optimizer`.

    Args:
      var: A variable passed to `minimize()` or `apply_gradients()`.
      name: A string.

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
    """
    named_slots = self._slots.get(name, None)
    if not named_slots:
      return None
    slot = named_slots.get(_var_key(var), None)
    if (distribute_utils.is_distributed_variable(slot) and
        not distribute_utils.is_distributed_variable(var)):
      # Make sure var and slot are either both DistributedVariable, or both
      # per replica variables.
      slot = slot._get_on_device_or_primary()  # pylint: disable=protected-access
    return slot

  def get_slot_names(self):
    """Return a list of the names of slots created by the `Optimizer`.

    See `get_slot()`.

    Returns:
      A list of strings.
    """
    return sorted(self._slots.keys())

  def variables(self):
    """A list of variables which encode the current state of `Optimizer`.

    Includes slot variables and additional global variables created by the
    optimizer in the current default graph.

    Returns:
      A list of variables.
884 """ 885 current_graph = ops.get_default_graph() 886 887 def _from_current_graph(variable): 888 if variable._in_graph_mode: # pylint: disable=protected-access 889 return variable.op.graph is current_graph 890 else: 891 # No variable.op in eager mode. We don't expect lots of eager graphs, 892 # but behavior should be consistent with graph mode. 893 return variable._graph_key == current_graph._graph_key # pylint: disable=protected-access 894 895 optimizer_variables = [v for v in self._non_slot_variables() 896 if _from_current_graph(v)] 897 for _, variable_dict in self._slots.items(): 898 for _, slot_for_variable in variable_dict.items(): 899 if _from_current_graph(slot_for_variable): 900 optimizer_variables.append(slot_for_variable) 901 # Sort variables by name so that the return is deterministic. 902 return sorted(optimizer_variables, key=lambda v: v.name) 903 904 def _create_non_slot_variable(self, initial_value, name, colocate_with): 905 """Add an extra variable, not associated with a slot.""" 906 # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables. 907 eager = ops.executing_eagerly_outside_functions() 908 graph = None if eager else colocate_with.graph 909 910 key = (name, graph) 911 v = self._non_slot_dict.get(key, None) 912 if v is None: 913 self._maybe_initialize_trackable() 914 distribution_strategy = distribute_ctx.get_strategy() 915 with distribution_strategy.extended.colocate_vars_with(colocate_with): 916 if eager: 917 restored_initial_value = self._preload_simple_restoration( 918 name=name) 919 if restored_initial_value is not None: 920 initial_value = restored_initial_value 921 v = variable_scope.variable( 922 initial_value, name=name, trainable=False, 923 use_resource=resource_variable_ops.is_resource_variable( 924 colocate_with)) 925 # Restore this variable by name if necessary, but don't add a 926 # Trackable dependency. Optimizers return the current graph's 927 # non-slot variables from _checkpoint_dependencies explicitly rather 928 # than unconditionally adding dependencies (since there may be multiple 929 # non-slot variables with the same name in different graphs, trying to 930 # save all of them would result in errors). 931 self._handle_deferred_dependencies(name=name, trackable=v) 932 self._non_slot_dict[key] = v 933 934 return v 935 936 def _trackable_children(self, 937 save_type=trackable.SaveType.CHECKPOINT, 938 **kwargs): 939 """From Trackable. Gather graph-specific non-slot variables to save.""" 940 current_graph_non_slot_variables = {} 941 current_graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access 942 for (name, _), variable_object in sorted(self._non_slot_dict.items(), 943 # Avoid comparing graphs 944 key=lambda item: item[0][0]): 945 if variable_object._graph_key == current_graph_key: # pylint: disable=protected-access 946 current_graph_non_slot_variables[name] = variable_object 947 current_graph_non_slot_variables.update( 948 super(Optimizer, self)._trackable_children(save_type, **kwargs)) 949 return current_graph_non_slot_variables 950 951 def _lookup_dependency(self, name): 952 """From Trackable. 
    """From Trackable. Find a non-slot variable in the current graph."""
    unconditional = super(Optimizer, self)._lookup_dependency(name)
    if unconditional is not None:
      return unconditional
    graph = None if context.executing_eagerly() else ops.get_default_graph()
    return self._get_non_slot_variable(name, graph=graph)

  def _get_non_slot_variable(self, name, graph=None):
    non_slot = self._non_slot_dict.get((name, graph), None)
    if hasattr(non_slot, "_distributed_container"):
      # This is a mirrored non-slot. In order to enable code like `_finish`
      # to assign to a non-slot, return the current context replica.
      return non_slot.get()
    else:
      return non_slot

  def _non_slot_variables(self):
    """Additional variables created by the `Optimizer`.

    Returns:
      A list or tuple of variables.
    """
    return self._non_slot_dict.values()

  def _assert_valid_dtypes(self, tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    valid_dtypes = self._valid_dtypes()
    for t in tensors:
      dtype = t.dtype.base_dtype
      if dtype not in valid_dtypes:
        raise ValueError(
            "Invalid type %r for %s, expected: %s." % (
                dtype, t.name, [v for v in valid_dtypes]))

  # --------------
  # Methods to be implemented by subclasses if they want to use the
  # inherited implementation of apply_gradients() or compute_gradients().
  # --------------
  def _valid_dtypes(self):
    """Valid types for loss, variables and gradients.

    Subclasses should override to allow other float types.

    Returns:
      Valid types for loss, variables and gradients.
    """
    return set(
        [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64])

  def _create_slots(self, var_list):
    """Create all slots needed by the variables.

    Args:
      var_list: A list of `Variable` objects.
    """
    # No slots needed by default
    pass

  def _prepare(self):
    """Create all needed tensors before applying gradients.

    This is called with the name_scope using the "name" that
    users have chosen for the application of gradients.
    """
    pass

  def _apply_dense(self, grad, var):
    """Add ops to apply dense gradients to `var`.

    Args:
      grad: A `Tensor`.
      var: A `Variable` object.

    Returns:
      An `Operation`.
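
    For example (a rough sketch only, not the actual implementation of any
    TensorFlow optimizer; `self._learning_rate_tensor` is assumed to have been
    created in `_prepare()`), a plain gradient-descent subclass could write:

    ```python
    def _apply_dense(self, grad, var):
      lr = math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype)
      return state_ops.assign_sub(var, lr * grad,
                                  use_locking=self._use_locking)
    ```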
    """
    raise NotImplementedError()

  def _resource_apply_dense(self, grad, handle):
    """Add ops to apply dense gradients to the variable `handle`.

    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable
        to be updated.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
    """Add ops to apply sparse gradients to `handle`, with repeated indices.

    Optimizers which override this method must deal with repeated indices. See
    the docstring of `_apply_sparse_duplicate_indices` for details. By default
    the correct behavior, to sum non-unique indices and their associated
    gradients, is enforced by first pre-processing `grad` and `indices` and
    passing them on to `_resource_apply_sparse`. Optimizers which deal
    correctly with duplicate indices may instead override this method to avoid
    the overhead of summing.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable
        to be updated.
      indices: a `Tensor` of integral type representing the indices for
        which the gradient is nonzero. Indices may be repeated.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    summed_grad, unique_indices = _deduplicate_indexed_slices(
        values=grad, indices=indices)
    return self._resource_apply_sparse(summed_grad, handle, unique_indices)

  def _resource_apply_sparse(self, grad, handle, indices):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable
        to be updated.
      indices: a `Tensor` of integral type representing the indices for
        which the gradient is nonzero. Indices are unique.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _apply_sparse_duplicate_indices(self, grad, var):
    """Add ops to apply sparse gradients to `var`, with repeated sparse indices.

    Optimizers which override this method must deal with IndexedSlices objects
    such as the following:

      IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])

    The correct interpretation is:

      IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])

    Many optimizers deal incorrectly with repeated indices when updating based
    on sparse gradients (e.g. summing squares rather than squaring the sum, or
    applying momentum terms multiple times). Adding first is always the correct
    behavior, so this is enforced here by reconstructing the IndexedSlices to
    have only unique indices, then calling _apply_sparse.

    Optimizers which deal correctly with repeated indices may instead override
    this method to avoid the overhead of summing indices.

    Args:
      grad: `IndexedSlices`.
      var: A `Variable` object.

    Returns:
      An `Operation`.
    """
    summed_values, unique_indices = _deduplicate_indexed_slices(
        values=grad.values, indices=grad.indices)
    gradient_no_duplicate_indices = indexed_slices.IndexedSlices(
        indices=unique_indices,
        values=summed_values,
        dense_shape=grad.dense_shape)
    return self._apply_sparse(gradient_no_duplicate_indices, var)

  def _apply_sparse(self, grad, var):
    """Add ops to apply sparse gradients to `var`.

    The IndexedSlices object passed to `grad` in this function is by default
    pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
    indices (see its docstring for details). Optimizers which can tolerate or
    have correct special cases for duplicate sparse indices may override
    `_apply_sparse_duplicate_indices` instead of this function, avoiding that
    overhead.

    Args:
      grad: `IndexedSlices`, with no repeated indices.
      var: A `Variable` object.

    Returns:
      An `Operation`.
    """
    raise NotImplementedError()

  def _finish(self, update_ops, name_scope):
    """Do what is needed to finish the update.

    This is called with the `name_scope` using the "name" that
    users have chosen for the application of gradients.

    Args:
      update_ops: List of `Operation` objects to update variables. This list
        contains the values returned by the `_apply_dense()` and
        `_apply_sparse()` calls.
      name_scope: String. Name to use for the returned operation.

    Returns:
      The operation to apply updates.
    """
    return control_flow_ops.group(*update_ops, name=name_scope)

  # --------------
  # Utility methods for subclasses.
  # --------------

  def _slot_dict(self, slot_name):
    """Returns a dict for caching slots created under the given name.

    Args:
      slot_name: Name for the slot.

    Returns:
      A dict that maps primary `Variable` objects to the slot created
      for that variable, under the given slot name.
    """
    named_slots = self._slots.get(slot_name, None)
    if named_slots is None:
      named_slots = {}
      self._slots[slot_name] = named_slots
    return named_slots

  def _get_or_make_slot(self, var, val, slot_name, op_name):
    """Find or create a slot for a variable.

    Args:
      var: A `Variable` object.
      val: A `Tensor`. The initial value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_slot(
          var, val, op_name, copy_xla_sharding=True)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]

  def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
                                         slot_name, op_name):
    """Find or create a slot for a variable, using an Initializer.

    Args:
      var: A `Variable` object.
      initializer: An `Initializer`. The initial value of the slot.
      shape: Shape of the initial value of the slot.
      dtype: Type of the value of the slot.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    if _var_key(var) not in named_slots:
      new_slot_variable = slot_creator.create_slot_with_initializer(
          var, initializer, shape, dtype, op_name, copy_xla_sharding=True)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[_var_key(var)] = new_slot_variable
    return named_slots[_var_key(var)]

  def _zeros_slot(self, var, slot_name, op_name):
    """Find or create a slot initialized with 0.0.

    Args:
      var: A `Variable` object.
      slot_name: Name for the slot.
      op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
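
    For example (illustrative; the slot name "accumulator" is arbitrary), a
    subclass could create one zero-initialized slot per trainable variable
    from its `_create_slots()`:

    ```python
    def _create_slots(self, var_list):
      for v in var_list:
        self._zeros_slot(v, "accumulator", self._name)
    ```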
1247 """ 1248 named_slots = self._slot_dict(slot_name) 1249 if _var_key(var) not in named_slots: 1250 new_slot_variable = slot_creator.create_zeros_slot( 1251 var, op_name, copy_xla_sharding=True) 1252 self._restore_slot_variable( 1253 slot_name=slot_name, variable=var, 1254 slot_variable=new_slot_variable) 1255 named_slots[_var_key(var)] = new_slot_variable 1256 return named_slots[_var_key(var)] 1257 1258 # -------------- 1259 # For implementing the Trackable interface. 1260 # -------------- 1261 1262 def _restore_slot_variable(self, slot_name, variable, slot_variable): 1263 """Restore a newly created slot variable's value.""" 1264 variable_key = _var_key(variable) 1265 deferred_restorations = self._deferred_slot_restorations.get( 1266 slot_name, {}).pop(variable_key, []) 1267 # Iterate over restores, highest restore UID first to minimize the number 1268 # of assignments. 1269 deferred_restorations.sort(key=lambda position: position.restore_uid, 1270 reverse=True) 1271 for checkpoint_position in deferred_restorations: 1272 checkpoint_position.restore(slot_variable) 1273 1274 def _create_or_restore_slot_variable( 1275 self, slot_variable_position, slot_name, variable): 1276 """Restore a slot variable's value, possibly creating it. 1277 1278 Called when a variable which has an associated slot variable is created or 1279 restored. When executing eagerly, we create the slot variable with a 1280 restoring initializer. 1281 1282 No new variables are created when graph building. Instead, 1283 _restore_slot_variable catches these after normal creation and adds restore 1284 ops to the graph. This method is nonetheless important when graph building 1285 for the case when a slot variable has already been created but `variable` 1286 has just been added to a dependency graph (causing us to realize that the 1287 slot variable needs to be restored). 1288 1289 Args: 1290 slot_variable_position: A `trackable._CheckpointPosition` object 1291 indicating the slot variable `Trackable` object to be restored. 1292 slot_name: The name of this `Optimizer`'s slot to restore into. 1293 variable: The variable object this slot is being created for. 1294 """ 1295 named_slots = self._slot_dict(slot_name) 1296 variable_key = _var_key(variable) 1297 slot_variable = named_slots.get(variable_key, None) 1298 if (slot_variable is None and context.executing_eagerly() and 1299 slot_variable_position.is_simple_variable() 1300 # Defer slot variable creation if there is an active variable creator 1301 # scope. Generally we'd like to eagerly create/restore slot variables 1302 # when possible, but this may mean that scopes intended to catch 1303 # `variable` also catch its eagerly created slot variable 1304 # unintentionally (specifically make_template would add a dependency on 1305 # a slot variable if not for this case). Deferring is mostly harmless 1306 # (aside from double initialization), and makes variable creator scopes 1307 # behave the same way they do when graph building. 1308 and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access 1309 initializer = trackable.CheckpointInitialValueCallable( 1310 checkpoint_position=slot_variable_position) 1311 # CheckpointInitialValueCallable will ignore the shape and dtype 1312 # parameters but they must be passed. 
      slot_variable = self._get_or_make_slot_with_initializer(
          var=variable,
          initializer=initializer,
          shape=variable.shape,
          dtype=variable.dtype,
          slot_name=slot_name,
          op_name=self._name)
    # Slot variables are not owned by any one object (because we don't want to
    # save the slot variable if the optimizer is saved without the non-slot
    # variable, or if the non-slot variable is saved without the optimizer;
    # it's a dependency hypergraph with edges of the form (optimizer, non-slot
    # variable, variable)). So we don't _track_ slot variables anywhere, and
    # instead special-case this dependency and otherwise pretend it's a normal
    # graph.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)

  def _call_if_callable(self, param):
    """Call the function if param is callable."""
    return param() if callable(param) else param