# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for training."""
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export

# Picked a long key value to minimize the chance of collision with user-defined
# collection keys.
GLOBAL_STEP_READ_KEY = 'global_step_read_op_cache'

# TODO(drpng): remove this after legacy uses are resolved.
write_graph = graph_io.write_graph


@tf_export(v1=['train.global_step'])
def global_step(sess, global_step_tensor):
  """Small helper to get the global step.

  ```python
  # Create a variable to hold the global_step.
  global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
  # Create a session.
  sess = tf.compat.v1.Session()
  # Initialize the variable.
  sess.run(global_step_tensor.initializer)
  # Get the variable value.
  print('global_step: %s' % tf.compat.v1.train.global_step(sess,
                                                           global_step_tensor))

  global_step: 10
  ```

  Args:
    sess: A TensorFlow `Session` object.
    global_step_tensor: `Tensor` or the `name` of the operation that contains
      the global step.

  Returns:
    The global step value.
  """
  if context.executing_eagerly():
    return int(global_step_tensor.numpy())
  return int(sess.run(global_step_tensor))


@tf_export(v1=['train.get_global_step'])
def get_global_step(graph=None):
  """Get the global step tensor.

  The global step tensor must be an integer variable. We first try to find it
  in the collection `GLOBAL_STEP`, or by name `global_step:0`.

  Args:
    graph: The graph to find the global step in. If missing, use default graph.

  Returns:
    The global step variable, or `None` if none was found.

  Raises:
    TypeError: If the global step tensor has a non-integer type, or if it is
      not a `Variable`.

  @compatibility(TF2)
  With the deprecation of global graphs, TF no longer tracks variables in
  collections. In other words, there are no global variables in TF2. Thus, the
  global step functions have been removed (`get_or_create_global_step`,
  `create_global_step`, `get_global_step`). You have two options for
  migrating:

  1. Create a Keras optimizer, which generates an `iterations` variable. This
     variable is automatically incremented when calling `apply_gradients`.
  2. Manually create and increment a `tf.Variable` (see the sketch below).
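
  A minimal sketch of option 2, assuming TF2 eager execution (the variable
  name `step` is illustrative, not an API name):

  >>> step = tf.Variable(0, trainable=False, dtype=tf.int64)
  >>> _ = step.assign_add(1)  # Increment once per training step.
  >>> print("current step:", int(step.numpy()))
  current step: 1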

  Below is an example of migrating away from using a global step to using a
  Keras optimizer:

  Define a dummy model and loss:

  >>> def compute_loss(x):
  ...   v = tf.Variable(3.0)
  ...   y = x * v
  ...   loss = x * 5 - x * v
  ...   return loss, [v]

  Before migrating:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   x = tf.compat.v1.placeholder(tf.float32, [])
  ...   loss, var_list = compute_loss(x)
  ...   global_step = tf.compat.v1.train.get_or_create_global_step()
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  ...   train_op = optimizer.minimize(loss, global_step, var_list)
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run(global_init)
  >>> print("before training:", sess.run(global_step))
  before training: 0
  >>> sess.run(train_op, feed_dict={x: 3})
  >>> print("after training:", sess.run(global_step))
  after training: 1

  Using `get_global_step`:

  >>> with g.as_default():
  ...   print(sess.run(tf.compat.v1.train.get_global_step()))
  1

  Migrating to a Keras optimizer:

  >>> optimizer = tf.keras.optimizers.SGD(.01)
  >>> print("before training:", optimizer.iterations.numpy())
  before training: 0
  >>> with tf.GradientTape() as tape:
  ...   loss, var_list = compute_loss(3)
  ...   grads = tape.gradient(loss, var_list)
  ...   optimizer.apply_gradients(zip(grads, var_list))
  >>> print("after training:", optimizer.iterations.numpy())
  after training: 1

  @end_compatibility
  """
  graph = graph or ops.get_default_graph()
  global_step_tensor = None
  global_step_tensors = graph.get_collection(ops.GraphKeys.GLOBAL_STEP)
  if len(global_step_tensors) == 1:
    global_step_tensor = global_step_tensors[0]
  elif not global_step_tensors:
    try:
      global_step_tensor = graph.get_tensor_by_name('global_step:0')
    except KeyError:
      return None
  else:
    logging.error('Multiple tensors in global_step collection.')
    return None

  assert_global_step(global_step_tensor)
  return global_step_tensor


@tf_export(v1=['train.create_global_step'])
def create_global_step(graph=None):
  """Create global step tensor in graph.

  Args:
    graph: The graph in which to create the global step tensor. If missing,
      use default graph.

  Returns:
    Global step tensor.

  Raises:
    ValueError: if global step tensor is already defined.

  @compatibility(TF2)
  With the deprecation of global graphs, TF no longer tracks variables in
  collections. In other words, there are no global variables in TF2. Thus, the
  global step functions have been removed (`get_or_create_global_step`,
  `create_global_step`, `get_global_step`). You have two options for
  migrating:

  1. Create a Keras optimizer, which generates an `iterations` variable. This
     variable is automatically incremented when calling `apply_gradients`.
  2. Manually create and increment a `tf.Variable`.

  Below is an example of migrating away from using a global step to using a
  Keras optimizer:

  Define a dummy model and loss:

  >>> def compute_loss(x):
  ...   v = tf.Variable(3.0)
  ...   y = x * v
  ...   loss = x * 5 - x * v
  ...   return loss, [v]

  Before migrating:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   x = tf.compat.v1.placeholder(tf.float32, [])
  ...   loss, var_list = compute_loss(x)
  ...   global_step = tf.compat.v1.train.create_global_step()
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  ...   train_op = optimizer.minimize(loss, global_step, var_list)
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run(global_init)
  >>> print("before training:", sess.run(global_step))
  before training: 0
  >>> sess.run(train_op, feed_dict={x: 3})
  >>> print("after training:", sess.run(global_step))
  after training: 1

  Migrating to a Keras optimizer:

  >>> optimizer = tf.keras.optimizers.SGD(.01)
  >>> print("before training:", optimizer.iterations.numpy())
  before training: 0
  >>> with tf.GradientTape() as tape:
  ...   loss, var_list = compute_loss(3)
  ...   grads = tape.gradient(loss, var_list)
  ...   optimizer.apply_gradients(zip(grads, var_list))
  >>> print("after training:", optimizer.iterations.numpy())
  after training: 1

  @end_compatibility
  """
  graph = graph or ops.get_default_graph()
  if get_global_step(graph) is not None:
    raise ValueError('"global_step" already exists.')
  if context.executing_eagerly():
    with ops.device('cpu:0'):
      return variable_scope.get_variable(
          ops.GraphKeys.GLOBAL_STEP,
          shape=[],
          dtype=dtypes.int64,
          initializer=init_ops.zeros_initializer(),
          trainable=False,
          aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
          collections=[
              ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP
          ])
  # Create in proper graph and base name_scope.
  with graph.as_default() as g, g.name_scope(None):
    return variable_scope.get_variable(
        ops.GraphKeys.GLOBAL_STEP,
        shape=[],
        dtype=dtypes.int64,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])


@tf_export(v1=['train.get_or_create_global_step'])
def get_or_create_global_step(graph=None):
  """Returns and creates (if necessary) the global step tensor.

  Args:
    graph: The graph in which to create the global step tensor. If missing,
      use default graph.

  Returns:
    The global step tensor.

  @compatibility(TF2)
  With the deprecation of global graphs, TF no longer tracks variables in
  collections. In other words, there are no global variables in TF2. Thus, the
  global step functions have been removed (`get_or_create_global_step`,
  `create_global_step`, `get_global_step`). You have two options for
  migrating:

  1. Create a Keras optimizer, which generates an `iterations` variable. This
     variable is automatically incremented when calling `apply_gradients`.
  2. Manually create and increment a `tf.Variable`.

  Below is an example of migrating away from using a global step to using a
  Keras optimizer:

  Define a dummy model and loss:

  >>> def compute_loss(x):
  ...   v = tf.Variable(3.0)
  ...   y = x * v
  ...   loss = x * 5 - x * v
  ...   return loss, [v]

  Before migrating:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   x = tf.compat.v1.placeholder(tf.float32, [])
  ...   loss, var_list = compute_loss(x)
  ...   global_step = tf.compat.v1.train.get_or_create_global_step()
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  ...   train_op = optimizer.minimize(loss, global_step, var_list)
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run(global_init)
  >>> print("before training:", sess.run(global_step))
  before training: 0
  >>> sess.run(train_op, feed_dict={x: 3})
  >>> print("after training:", sess.run(global_step))
  after training: 1

  Migrating to a Keras optimizer:

  >>> optimizer = tf.keras.optimizers.SGD(.01)
  >>> print("before training:", optimizer.iterations.numpy())
  before training: 0
  >>> with tf.GradientTape() as tape:
  ...   loss, var_list = compute_loss(3)
  ...   grads = tape.gradient(loss, var_list)
  ...   optimizer.apply_gradients(zip(grads, var_list))
  >>> print("after training:", optimizer.iterations.numpy())
  after training: 1

  @end_compatibility
  """
  graph = graph or ops.get_default_graph()
  global_step_tensor = get_global_step(graph)
  if global_step_tensor is None:
    global_step_tensor = create_global_step(graph)
  return global_step_tensor


@tf_export(v1=['train.assert_global_step'])
def assert_global_step(global_step_tensor):
  """Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.

  Args:
    global_step_tensor: `Tensor` to test.

  Raises:
    TypeError: If `global_step_tensor` is not a `Variable` or `Tensor`, does
      not have an integer dtype, or has a fully defined, non-scalar shape.
  """
  if not (isinstance(global_step_tensor, variables.Variable) or
          isinstance(global_step_tensor, ops.Tensor) or
          resource_variable_ops.is_resource_variable(global_step_tensor)):
    raise TypeError('Existing "global_step" must be a Variable or Tensor: %s.' %
                    global_step_tensor)

  if not global_step_tensor.dtype.base_dtype.is_integer:
    raise TypeError('Existing "global_step" does not have integer type: %s' %
                    global_step_tensor.dtype)

  if (global_step_tensor.get_shape().ndims != 0 and
      global_step_tensor.get_shape().is_fully_defined()):
    raise TypeError('Existing "global_step" is not scalar: %s' %
                    global_step_tensor.get_shape())


def _get_global_step_read(graph=None):
  """Gets global step read tensor in graph.

  Args:
    graph: The graph in which to find the global step read tensor. If
      missing, use default graph.

  Returns:
    Global step read tensor.

  Raises:
    RuntimeError: if multiple items found in collection GLOBAL_STEP_READ_KEY.
  """
  graph = graph or ops.get_default_graph()
  global_step_read_tensors = graph.get_collection(GLOBAL_STEP_READ_KEY)
  if len(global_step_read_tensors) > 1:
    raise RuntimeError('There are multiple items in collection {}. '
                       'There should be only one.'.format(GLOBAL_STEP_READ_KEY))

  if len(global_step_read_tensors) == 1:
    return global_step_read_tensors[0]
  return None


def _get_or_create_global_step_read(graph=None):
  """Gets or creates global step read tensor in graph.

  Args:
    graph: The graph in which to create the global step read tensor. If
      missing, use default graph.

  Returns:
    The global step read tensor if a global step tensor exists in the graph,
    else `None`.
  """
  graph = graph or ops.get_default_graph()
  global_step_read_tensor = _get_global_step_read(graph)
  if global_step_read_tensor is not None:
    return global_step_read_tensor
  global_step_tensor = get_global_step(graph)
  if global_step_tensor is None:
    return None
  # Add 'zero' so that it will create a copy of the variable as a Tensor.
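  # (Adding 0 forces a read of the variable's current value into a fresh
  # Tensor instead of returning the variable object itself; the resulting
  # read tensor is cached in the GLOBAL_STEP_READ_KEY collection below.)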
  with graph.as_default() as g, g.name_scope(None):
    with g.name_scope(global_step_tensor.op.name + '/'):
      # Using initialized_value ensures that global_step is initialized before
      # this read runs. This is needed because, for example, Estimator builds
      # every model_fn under a dependency on global_step_read_tensor.
      global_step_value = global_step_tensor.initialized_value() if isinstance(
          global_step_tensor, variables.Variable) else global_step_tensor
      global_step_read_tensor = global_step_value + 0
      ops.add_to_collection(GLOBAL_STEP_READ_KEY, global_step_read_tensor)
  return _get_global_step_read(graph)


def _increment_global_step(increment, graph=None):
  """Creates an op that increments the global step by `increment`."""
  graph = graph or ops.get_default_graph()
  global_step_tensor = get_global_step(graph)
  if global_step_tensor is None:
    raise ValueError(
        'Global step tensor should be created by '
        'tf.train.get_or_create_global_step before calling increment.')
  global_step_read_tensor = _get_or_create_global_step_read(graph)
  with graph.as_default() as g, g.name_scope(None):
    with g.name_scope(global_step_tensor.op.name + '/'):
      with ops.control_dependencies([global_step_read_tensor]):
        return state_ops.assign_add(global_step_tensor, increment)
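

# A minimal usage sketch of the helpers above (illustrative only, not part of
# the module; assumes TF1-style graph mode, with `tf` being the public
# `tensorflow` package, which this module does not itself import):
#
#   step = get_or_create_global_step()
#   incr = _increment_global_step(1)
#   with tf.compat.v1.Session() as sess:
#     sess.run(tf.compat.v1.global_variables_initializer())
#     sess.run(incr)
#     print(global_step(sess, step))  # -> 1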