# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Maintain moving averages of parameters."""
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls


@tf_export("__internal__.train.assign_moving_average", v1=[])
def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
  """Compute the moving average of a variable.

  The moving average of 'variable' updated with 'value' is:
    variable * decay + value * (1 - decay)

  The returned Operation sets 'variable' to the newly computed moving average,
  by performing this subtraction:
    variable -= (1 - decay) * (variable - value)

  Since variables that are initialized to a `0` value will be `0` biased,
  `zero_debias` optionally enables scaling by the mathematically correct
  debiasing factor of
    1 - decay ** num_updates
  See Section 3 of (Kingma et al., 2015) for more details.

  The names of the debias shadow variables, by default, include both the scope
  they were created in and the scope of the variables they debias. They are also
  given a uniquifying-suffix.

  E.g.:

  ```
    with tf.compat.v1.variable_scope('scope1'):
      with tf.compat.v1.variable_scope('scope2'):
        var = tf.compat.v1.get_variable('foo')
        update_1 = tf.assign_moving_average(var, 0.0, 1.0)
        update_2 = tf.assign_moving_average(var, 0.0, 0.9)

    # var.name: 'scope1/scope2/foo'
    # shadow var names: 'scope1/scope2/scope1/scope2/foo/biased'
    #                   'scope1/scope2/scope1/scope2/foo/biased_1'
  ```

  Args:
    variable: A Variable.
    value: A tensor with the same shape as 'variable'.
    decay: A float `Tensor` or float value. The moving average decay.
    zero_debias: A python bool. If true, assume the variable is 0-initialized
      and unbias it, as in (Kingma et al., 2015). See docstring in
        `_zero_debias` for more details.
    name: Optional name of the returned operation.

  Returns:
    A tensor which if evaluated will compute and return the new moving average.

  References:
    Adam - A Method for Stochastic Optimization:
      [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
      ([pdf](https://arxiv.org/pdf/1412.6980.pdf))
  """
  with ops.name_scope(name, "AssignMovingAvg",
                      [variable, value, decay]) as scope:
    # From here on `decay` holds `1 - decay`, the weight of the new value.
    decay = ops.convert_to_tensor(1.0 - decay, name="decay")
    if decay.dtype != variable.dtype.base_dtype:
      decay = math_ops.cast(decay, variable.dtype.base_dtype)

    def update_fn(v, value):
      return state_ops.assign_sub(v, (v - value) * decay, name=scope)

    def update(strategy, v, value):
      if zero_debias:
        return _zero_debias(strategy, v, value, decay)
      else:
        return _update(strategy, v, update_fn, args=(value,))

    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context:
      # In a replica context, we update variable using the mean of value across
      # replicas.
      def merge_fn(strategy, v, value):
        value = strategy.extended.reduce_to(ds_reduce_util.ReduceOp.MEAN, value,
                                            v)
        return update(strategy, v, value)

      return replica_context.merge_call(merge_fn, args=(variable, value))
    else:
      strategy = distribution_strategy_context.get_cross_replica_context()
      return update(strategy, variable, value)


def weighted_moving_average(value,
                            decay,
                            weight,
                            truediv=True,
                            collections=None,
                            name=None):
  """Compute the weighted moving average of `value`.

  Conceptually, the weighted moving average is:
    `moving_average(value * weight) / moving_average(weight)`,
  where a moving average updates by the rule
    `new_value = decay * old_value + (1 - decay) * update`
  Internally, this Op keeps moving average variables of both `value * weight`
  and `weight`.

  Args:
    value: A numeric `Tensor`.
    decay: A float `Tensor` or float value. The moving average decay.
    weight: `Tensor` that keeps the current value of a weight. Shape should be
      able to multiply `value`.
    truediv: Boolean, if `True`, dividing by `moving_average(weight)` is
      floating point division. If `False`, use division implied by dtypes.
    collections: List of graph collections keys to add the internal variables
      `value * weight` and `weight` to. Defaults to
      `[GraphKeys.GLOBAL_VARIABLES]`.
    name: Optional name of the returned operation. Defaults to
      "WeightedMovingAvg".

  Returns:
    A tensor which, when evaluated, updates the two internal moving averages
    (of `value * weight` and of `weight`) and returns their ratio.
  """
  # Unlike assign_moving_average, the weighted moving average doesn't modify
  # user-visible variables. It is the ratio of two internal variables, which are
  # moving averages of the updates.  Thus, the signature of this function is
  # quite different than assign_moving_average.
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  with variable_scope.variable_scope(name, "WeightedMovingAvg",
                                     [value, weight, decay]) as scope:
    value_x_weight_var = variable_scope.get_variable(
        "value_x_weight",
        shape=value.get_shape(),
        dtype=value.dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=collections)
    weight_var = variable_scope.get_variable(
        "weight",
        shape=weight.get_shape(),
        dtype=weight.dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=collections)
    # Both averages start at zero, but since only their ratio is returned the
    # common bias factor cancels, so no debiasing is needed.
    numerator = assign_moving_average(
        value_x_weight_var, value * weight, decay, zero_debias=False)
    denominator = assign_moving_average(
        weight_var, weight, decay, zero_debias=False)

    if truediv:
      return math_ops.truediv(numerator, denominator, name=scope.name)
    else:
      return math_ops.divide(numerator, denominator, name=scope.name)


def _update(strategy, var, update_fn, args):
  """Applies updates depending on the context."""
  assert distribution_strategy_context.in_cross_replica_context(), (
      "_update can only be called in cross-replica context")
  if distribute_lib.get_update_replica_id() is not None:
    # Call update_fn on var to delegate the implementation. We expect `var` will
    # do the right thing in update context, e.g, if `var` is a MirroredVariable,
    # it should pick its component variable based on `update_replica_id` and
    # only update that.
    return update_fn(var, *args)
  else:
    return strategy.extended.update(var, update_fn, args)


def _zero_debias(strategy, unbiased_var, value, decay):
  """Compute the delta required for a debiased Variable.

  All exponential moving averages initialized with Tensors are initialized to 0,
  and therefore are biased to 0. Variables initialized to 0 and used as EMAs are
  similarly biased. This function creates the debias updated amount according to
  a scale factor, as in (Kingma et al., 2015).

  To demonstrate the bias that results from 0-initialization, take an EMA that
  was initialized to `0` with decay `b`. After `t` timesteps of seeing the
  constant `c`, the variable has the following value:

  ```
    EMA = 0*b^(t) + c*(1 - b)*b^(t-1) + c*(1 - b)*b^(t-2) + ...
        = c*(1 - b^t)
  ```

  To have the true value `c`, we would divide by the scale factor `1 - b^t`.

  In order to perform debiasing, we use two shadow variables. One keeps track of
  the biased estimate, and the other keeps track of the number of updates that
  have occurred.

  Args:
    strategy: `Strategy` used to create and update variables.
    unbiased_var: A Variable representing the current value of the unbiased EMA.
    value: A Tensor representing the most recent value.
    decay: A Tensor representing `1-decay` for the EMA.

  Returns:
    The result of assigning the debiased estimate
    `biased / (1 - (1 - decay) ** local_step)` to `unbiased_var` (here `decay`
    is the argument above, i.e. `1-decay` of the EMA). Computing it also
    updates the `biased` and `local_step` shadow variables.

  References:
    Adam - A Method for Stochastic Optimization:
      [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
      ([pdf](https://arxiv.org/pdf/1412.6980.pdf))

  """
  with variable_scope.variable_scope(
      unbiased_var.name[:-len(":0")], values=[unbiased_var, value, decay]):
    with ops.init_scope():
      biased_initializer = init_ops.zeros_initializer()
      local_step_initializer = init_ops.zeros_initializer()

    def _maybe_get_unique(name):
      """Get name for a unique variable, if not `reuse=True`."""
      if variable_scope.get_variable_scope().reuse:
        return name
      # A set gives O(1) membership tests in the suffix-probing loop below.
      vs_vars = {
          x.op.name
          for x in variable_scope.get_variable_scope().global_variables()
      }
      full_name = variable_scope.get_variable_scope().name + "/" + name
      if full_name not in vs_vars:
        return name
      idx = 1
      while full_name + ("_%d" % idx) in vs_vars:
        idx += 1
      return name + ("_%d" % idx)

    with strategy.extended.colocate_vars_with(unbiased_var):
      biased_var = variable_scope.get_variable(
          _maybe_get_unique("biased"),
          initializer=biased_initializer,
          shape=unbiased_var.get_shape(),
          dtype=unbiased_var.dtype,
          trainable=False)
      local_step = variable_scope.get_variable(
          _maybe_get_unique("local_step"),
          shape=[],
          dtype=unbiased_var.dtype,
          initializer=local_step_initializer,
          trainable=False)

    def update_fn(v, value, biased_var, local_step):
      update_biased = state_ops.assign_sub(biased_var,
                                           (biased_var - value) * decay)
      update_local_step = local_step.assign_add(1)

      # This function gets `1 - decay`, so use `1.0 - decay` in the exponent.
      bias_factor = 1 - math_ops.pow(1.0 - decay, update_local_step)
      return state_ops.assign(
          v, update_biased / bias_factor, name=ops.get_name_scope() + "/")

    return _update(
        strategy, unbiased_var, update_fn, args=(value, biased_var, local_step))


@tf_export("train.ExponentialMovingAverage")
class ExponentialMovingAverage:
  """Maintains moving averages of variables by employing an exponential decay.

  When training a model, it is often beneficial to maintain moving averages of
  the trained parameters.  Evaluations that use averaged parameters sometimes
  produce significantly better results than the final trained values.

  The `apply()` method adds shadow copies of trained variables the first time
  it is called, and maintains a moving average of the trained variables in
  their shadow copies at every additional invocation.
  It should generally be called immediately after creating the model weights,
  and then after each training step.

  The `average()` method gives access to the shadow variables.
  It allows you to use the moving averages in place of the last trained values
  for evaluations, by loading the moving averages into your model via
  `var.assign(ema.average(var))`.
  Additionally, although `ExponentialMovingAverage`
  objects are not directly trackable by checkpoints,
  `average()` returns the moving average variables for your model weights,
  which you can then checkpoint. (There is an example
  of this near the bottom of this docstring).
  So, `average()` is useful when
  building an evaluation model, or when restoring a model from a checkpoint
  file.

  The moving averages are computed using exponential decay.  You specify the
  decay value (as a scalar float value, `Tensor`, or `Variable`) when creating
  the `ExponentialMovingAverage` object.  The shadow variables are initialized
  with the same initial values as the trained variables.  When you run `apply`
  to update the moving averages, each shadow variable is updated with the
  formula:

    `shadow_variable -= (1 - decay) * (shadow_variable - variable)`

  This is mathematically equivalent to the classic formula below, but the use
  of an `assign_sub` op (the `"-="` in the formula) allows concurrent lockless
  updates to the variables:

    `shadow_variable = decay * shadow_variable + (1 - decay) * variable`

  Reasonable values for `decay` are close to 1.0, typically in the
  multiple-nines range: 0.999, 0.9999, etc.

  To have fine-grained control over the value of the decay parameter during
  training, pass a scalar `tf.Variable` as the `decay` value to the constructor,
  and update the variable as needed.

  Example usage when creating a training model:

  ```python
  # Create variables.
  var0 = tf.Variable(...)
  var1 = tf.Variable(...)
  # ... use the variables to build a training model...

  # Create an ExponentialMovingAverage object
  ema = tf.train.ExponentialMovingAverage(decay=0.9999)

  # The first `apply` creates the shadow variables that hold the moving averages
  ema.apply([var0, var1])

  # grab the moving averages for checkpointing purposes or to be able to
  # load the moving averages into the model weights
  averages = [ema.average(var0), ema.average(var1)]

  ...
  def train_step(...):
  ...
    # Apply the optimizer.
    opt.minimize(my_loss, [var0, var1])

    # Update the moving averages
    # of var0 and var1 with additional calls to `apply`
    ema.apply([var0, var1])

  ...train the model by running train_step multiple times...
  ```

  There are several ways to use the moving averages for evaluations:

  1. Assign the values of the shadow variables to your model variables with
     `Variable.assign(...)` before evaluating your
     model. You can use the `average()`
     method to get the shadow variable for a given variable. To continue
     training after using this approach, make sure to record the unaveraged
     weights and restore them before continuing to train. You can see the
     tensorflow-addons' MovingAverage optimizer's `swap_weights` method for
     one example of how to swap variables efficiently in distributed settings:
     https://github.com/tensorflow/addons/blob/v0.13.0/tensorflow_addons/optimizers/moving_average.py#L151
  2. Make sure to checkpoint out your moving average variables in your
     `tf.train.Checkpoint`. At evaluation time, create your shadow variables and
     use `tf.train.Checkpoint` to restore the moving averages into the shadow
     variables. Then, load the moving averages into the actual model weights via
     `var.assign(moving_avg)`.
  3. Checkpoint out your moving average variables in your `tf.train.Checkpoint`.
     For evaluation, restore your model weights directly from the moving
     averages instead of from the non-averaged weights.
     Caution: If you choose this approach, include only the object-graph paths
     to the averaged path in your checkpoint restore.
     If you point both the unaveraged and averaged paths in a checkpoint
     restore to the same variables, it is hard to reason about whether your
     model will restore the averaged or non-averaged variables.

  Example of saving out then restoring the shadow variable values:

  ```python
  # Create variables.
  var0 = tf.Variable(...)
  var1 = tf.Variable(...)
  # ... use the variables to build a training model...

  # Create an ExponentialMovingAverage object, create the shadow variables,
  # and grab the moving averages for checkpointing purposes.
  # (The ExponentialMovingAverage object itself is not checkpointable)
  ema = tf.train.ExponentialMovingAverage(decay=0.9999)
  ema.apply([var0, var1])
  avg_var0 = ema.average(var0)
  avg_var1 = ema.average(var1)

  # Create a Checkpoint that will manage the model weights and the averages.
  checkpoint = tf.train.Checkpoint(model_weights=[var0, var1],
                                   averaged_weights=[avg_var0, avg_var1])
  ... # Do training

  # Save out the checkpoint including the model weights and the moving averages
  checkpoint.save(...)
  ```

  Restore option: restore all averaged & non-averaged weights, then load
  moving averages into the model via `var.assign()`
  ```python
  # Create variables.
  var0 = tf.Variable(...)
  var1 = tf.Variable(...)
  # ... use the variables to build a training model...

  # Create an ExponentialMovingAverage object, create the shadow variables,
  # and grab the moving averages for checkpoint restore purposes.
  # (The ExponentialMovingAverage object itself is not checkpointable)
  ema = tf.train.ExponentialMovingAverage(decay=0.9999)
  ema.apply([var0, var1])
  avg_var0 = ema.average(var0)
  avg_var1 = ema.average(var1)

  # Create a Checkpoint that will manage the model weights and the averages.
  checkpoint = tf.train.Checkpoint(model_weights=[var0, var1],
                                   averaged_weights=[avg_var0, avg_var1])
  checkpoint.restore(...)
  var0.assign(avg_var0)
  var1.assign(avg_var1)
  # var0 and var1 now hold the moving average values
  ```

  Restore option: Directly restore the moving averages into the model weights.
  ```python
  # Create variables.
  var0 = tf.Variable(...)
  var1 = tf.Variable(...)
  # ... use the variables to build a training model...

  # Create a Checkpoint that maps only the averaged-weights path onto the
  # model variables, so restoring writes the averages into var0 and var1.
  checkpoint = tf.train.Checkpoint(averaged_weights=[var0, var1])
  checkpoint.restore(...)
  # var0 and var1 now hold the moving average values
  ```
  """

  def __init__(self,
               decay,
               num_updates=None,
               zero_debias=False,
               name="ExponentialMovingAverage"):
    """Creates a new ExponentialMovingAverage object.

    The `apply()` method has to be called to create shadow variables.
    Follow-on calls to the `apply()` method will update the moving averages
    in the shadow variables.
    (In TF 1.x graphs `apply()` will return an update op to update
    the moving averages which must be explicitly run).

    The optional `num_updates` parameter allows one to tweak the decay rate
    dynamically. It is typical to pass the count of training steps, usually
    kept in a variable that is incremented at each step, in which case the
    decay rate is lower at the start of training.  This makes moving averages
    move faster.  If passed, the actual decay rate used is:

      `min(decay, (1 + num_updates) / (10 + num_updates))`

    Args:
      decay: A scalar float value, `Tensor`, or `Variable`. The decay parameter.
      num_updates: Optional count of number of updates applied to variables.
      zero_debias: If `True`, zero debias moving-averages that are initialized
        with tensors. (Note: moving averages may not be initialized with
        non-variable tensors when eager execution is enabled).
      name: String. Optional prefix name to use for the name of ops added in
        `apply()`.
    """
    self._decay = decay
    self._num_updates = num_updates
    self._zero_debias = zero_debias
    self._name = name
    # Maps `var.ref()` -> shadow variable holding the moving average of `var`.
    self._averages = {}

  @property
  def name(self):
    """The name of this ExponentialMovingAverage object."""
    return self._name

  def apply(self, var_list=None):
    """Maintains moving averages of variables.

    `var_list` must be a list of `Variable` objects.  This method
    creates shadow variables (holding the moving averages)
    for all elements of `var_list`, and
    updates the moving averages using the current `var_list` values. Shadow
    variables for `Variable` objects are initialized to the variable's initial
    value; shadow variables for non-Variable tensors are initialized to zeros.

    Shadow variables are created with `trainable=False`. To access them you
    can use the EMA object's `average` method. Note that `EMA` objects are
    not trackable by checkpoints, so if you want to checkpoint or restore the
    moving variables you will need to manually grab the shadow
    variables via `average()` and assign them as `tf.Module` properties or
    directly pass them to your `tf.train.Checkpoint`.

    Note that `apply()` can be called multiple times. When eager execution is
    enabled each call to apply will update the variables once, so this needs to
    be called in a loop.

    In legacy TF 1.x graphs, this method returns an op that updates all
    shadow variables from the current value of their associated variables. In
    TF 1.x graphs without automatic control dependencies this op needs to be
    manually run.

    Args:
      var_list: A list of Variable objects. The variables
        must be of types bfloat16, float16, float32, or float64.
        (In legacy TF 1.x graphs these may be tensors, but this is unsupported
        when eager execution is enabled.) Any iterable is accepted; it is
        materialized into a list before use.

    Returns:
      An Operation that updates the moving averages.

    Raises:
      TypeError: If the arguments are not an allowed type.
    """
    # TODO(touts): op_scope
    if var_list is None:
      var_list = variables.trainable_variables()
    # `var_list` is traversed three times below; materialize it once so that a
    # one-shot iterable (e.g. a generator) is not exhausted by the first pass,
    # which would otherwise silently create and update no averages at all.
    var_list = list(var_list)
    for v in var_list:
      if (isinstance(v, ops.Tensor)
          and ops.executing_eagerly_outside_functions()):
        raise TypeError(
            "tf.train.ExponentialMovingAverage does not support non-Variable"
            " tensors when eager execution is enabled.")
    zero_debias_true = set()  # set of vars to set `zero_debias=True`
    for var in var_list:
      if var.dtype.base_dtype not in [
          dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64
      ]:
        raise TypeError("The variables must be half, float, or double: %s" %
                        var.name)

      if var.ref() not in self._averages:
        # For variables: to lower communication bandwidth across devices we keep
        # the moving averages on the same device as the variables. For other
        # tensors, we rely on the existing device allocation mechanism.
        with ops.init_scope():
          if isinstance(var, variables.Variable):
            with ops.device(var.device):
              initialized_value = var.initialized_value()
            avg = slot_creator.create_slot(
                var,
                initialized_value,
                self.name,
                colocate_with_primary=True,
                copy_xla_sharding=True)
            # NOTE(mrry): We only add `tf.Variable` objects to the
            # `MOVING_AVERAGE_VARIABLES` collection.
            ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
          else:
            avg = slot_creator.create_zeros_slot(
                var,
                self.name,
                colocate_with_primary=(var.op.type in [
                    "Variable", "VariableV2", "VarHandleOp"
                ]),
                copy_xla_sharding=True)
            # Only zero-initialized (tensor) averages are biased towards 0,
            # so only they are eligible for zero debiasing.
            if self._zero_debias:
              zero_debias_true.add(avg.ref())
        self._averages[var.ref()] = avg

    with ops.name_scope(self.name) as scope:
      decay = ops.convert_to_tensor(
          self._decay, dtype=dtypes.float32, name="decay")
      if self._num_updates is not None:
        num_updates = math_ops.cast(
            self._num_updates, dtypes.float32, name="num_updates")
        decay = math_ops.minimum(decay,
                                 (1.0 + num_updates) / (10.0 + num_updates))
      updates = []
      for var in var_list:
        avg = self._averages[var.ref()]
        zero_debias = avg.ref() in zero_debias_true
        updates.append(assign_moving_average(avg, var, decay, zero_debias))
      return control_flow_ops.group(*updates, name=scope)

  def average(self, var):
    """Returns the `Variable` holding the average of `var`.

    Args:
      var: A `Variable` object.

    Returns:
      A `Variable` object or `None` if the moving average of `var`
      is not maintained.
    """
    return self._averages.get(var.ref(), None)

  @doc_controls.do_not_generate_docs
  def average_name(self, var):
    """[Meant for TF1] Returns name of `Variable` holding the average for `var`.

    (Designed to work with legacy `tf.compat.v1.train.Saver`, it is sensitive to
    specific variable names and not recommended for TF2)

    The typical scenario for `ExponentialMovingAverage` is to compute moving
    averages of variables during training, and restore the variables from the
    computed moving averages during evaluations.

    To restore variables, you have to know the name of the shadow variables.
    That name and the original variable can then be passed to a `Saver()` object
    to restore the variable from the moving average value with:
      `saver = tf.compat.v1.train.Saver({ema.average_name(var): var})`

    `average_name()` can be called whether or not `apply()` has been called.

    Args:
      var: A `Variable` object.

    Returns:
      A string: The name of the variable that will be used or was used
      by the `ExponentialMovingAverage` class to hold the moving average of
      `var`.
    """
    if var.ref() in self._averages:
      return self._averages[var.ref()].name[:-len(":0")]
    # Not created yet: predict the name without reserving it in the graph.
    return ops.get_default_graph().unique_name(
        var.name[:-len(":0")] + "/" + self.name, mark_as_used=False)

  @doc_controls.do_not_generate_docs
  def variables_to_restore(self, moving_avg_variables=None):
    """[Designed for TF 1.x] Returns a map of names to `Variables` to restore.

    (Designed to work with legacy `tf.compat.v1.train.Saver`, sensitive to
    specific variable names and not recommended for TF2)

    If a variable has a moving average, use the moving average variable name as
    the restore name; otherwise, use the variable name.

    For example,

    ```python
      variables_to_restore = ema.variables_to_restore()
      saver = tf.compat.v1.train.Saver(variables_to_restore)
    ```

    Below is an example of such mapping:

    ```
      conv/batchnorm/gamma/ExponentialMovingAverage: conv/batchnorm/gamma,
      conv_4/conv2d_params/ExponentialMovingAverage: conv_4/conv2d_params,
      global_step: global_step
    ```

    Args:
      moving_avg_variables: a list of variables that should be restored from
        their moving-average variable names. If None, it will default to
        variables.moving_average_variables() + variables.trainable_variables()

    Returns:
      A map from restore_names to variables. The restore_name is either the
      original or the moving average version of the variable name, depending
      on whether the variable name is in the `moving_avg_variables`.
    """
    name_map = {}
    if moving_avg_variables is None:
      # Include trainable variables and variables which have been explicitly
      # added to the moving_average_variables collection.
      moving_avg_variables = variables.trainable_variables()
      moving_avg_variables += variables.moving_average_variables()
    # Remove duplicates
    moving_avg_variables = set(v.ref() for v in moving_avg_variables)
    # Collect all the variables with moving average,
    for v in moving_avg_variables:
      name_map[self.average_name(v.deref())] = v.deref()
    # Make sure we restore variables without moving averages as well.
    moving_avg_variable_names = {v.deref().name for v in moving_avg_variables}
    # Iterate the collection directly instead of through `set(...)`: this keeps
    # the insertion order deterministic and avoids hashing `Variable` objects
    # (the rest of this class keys variables by `.ref()`). Repeated entries are
    # already skipped by the `v.op.name not in name_map` check.
    for v in variables.global_variables():
      if v.name not in moving_avg_variable_names and v.op.name not in name_map:
        name_map[v.op.name] = v
    return name_map