# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utility functions for training."""
16from tensorflow.python.eager import context
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import graph_io
19from tensorflow.python.framework import ops
20from tensorflow.python.ops import init_ops
21from tensorflow.python.ops import resource_variable_ops
22from tensorflow.python.ops import state_ops
23from tensorflow.python.ops import variable_scope
24from tensorflow.python.ops import variables
25from tensorflow.python.platform import tf_logging as logging
26from tensorflow.python.util.tf_export import tf_export
27
# Picked a long key value to minimize the chance of collision with
# user-defined collection keys.
GLOBAL_STEP_READ_KEY = 'global_step_read_op_cache'

# TODO(drpng): remove this after legacy uses are resolved.
write_graph = graph_io.write_graph


@tf_export(v1=['train.global_step'])
def global_step(sess, global_step_tensor):
  """Small helper to get the global step.

  ```python
  # Create a variable to hold the global_step.
  global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
  # Create a session.
  sess = tf.compat.v1.Session()
  # Initialize the variable
  sess.run(global_step_tensor.initializer)
  # Get the variable value.
  print('global_step: %s' %
        tf.compat.v1.train.global_step(sess, global_step_tensor))

  global_step: 10
  ```
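
  When run with eager execution enabled, the session argument is ignored and
  the variable is read directly. A minimal sketch (TF2 eager mode; `step_var`
  is a hypothetical name):

  ```python
  step_var = tf.Variable(10, trainable=False, name='global_step')
  # `sess` may be None here; the value is read via step_var.numpy().
  print(tf.compat.v1.train.global_step(None, step_var))  # 10
  ```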

  Args:
    sess: A TensorFlow `Session` object.
    global_step_tensor: `Tensor` or the `name` of the operation that contains
      the global step.

  Returns:
    The global step value.
  """
  if context.executing_eagerly():
    return int(global_step_tensor.numpy())
  return int(sess.run(global_step_tensor))


@tf_export(v1=['train.get_global_step'])
def get_global_step(graph=None):
  """Get the global step tensor.

  The global step tensor must be an integer variable. We first try to find it
  in the collection `GLOBAL_STEP`, then fall back to the tensor named
  `global_step:0`.

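  A rough sketch of the name-based fallback (illustrative; assumes TF1 graph
  mode with ref variables, so that `global_step:0` is the integer variable's
  output tensor):

  ```python
  with tf.Graph().as_default():
    # Created outside the GLOBAL_STEP collection; found by name instead.
    v = tf.compat.v1.get_variable(
        'global_step', shape=[], dtype=tf.int64,
        initializer=tf.compat.v1.zeros_initializer(), trainable=False,
        use_resource=False)
    step = tf.compat.v1.train.get_global_step()
  ```
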
  Args:
    graph: The graph to find the global step in. If missing, use default graph.

  Returns:
    The global step variable, or `None` if none was found.

  Raises:
    TypeError: If the global step tensor has a non-integer type, or if it is not
      a `Variable`.

  @compatibility(TF2)
  With the deprecation of global graphs, TF no longer tracks variables in
  collections. In other words, there are no global variables in TF2. Thus, the
  global step functions (`get_or_create_global_step`, `create_global_step`,
  `get_global_step`) have been removed. You have two options for migrating:

  1. Create a Keras optimizer, which generates an `iterations` variable. This
     variable is automatically incremented when calling `apply_gradients`.
  2. Manually create and increment a `tf.Variable` (a sketch of this option
     appears after the Keras example below).

  Below is an example of migrating away from using a global step to using a
  Keras optimizer:

  Define a dummy model and loss:

  >>> def compute_loss(x):
  ...   v = tf.Variable(3.0)
  ...   y = x * v
  ...   loss = x * 5 - x * v
  ...   return loss, [v]

  Before migrating:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   x = tf.compat.v1.placeholder(tf.float32, [])
  ...   loss, var_list = compute_loss(x)
  ...   global_step = tf.compat.v1.train.get_or_create_global_step()
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  ...   train_op = optimizer.minimize(loss, global_step, var_list)
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run(global_init)
  >>> print("before training:", sess.run(global_step))
  before training: 0
  >>> sess.run(train_op, feed_dict={x: 3})
  >>> print("after training:", sess.run(global_step))
  after training: 1

  Using `get_global_step`:

  >>> with g.as_default():
  ...   print(sess.run(tf.compat.v1.train.get_global_step()))
  1

  Migrating to a Keras optimizer:

  >>> optimizer = tf.keras.optimizers.SGD(.01)
  >>> print("before training:", optimizer.iterations.numpy())
  before training: 0
  >>> with tf.GradientTape() as tape:
  ...   loss, var_list = compute_loss(3)
  ...   grads = tape.gradient(loss, var_list)
  ...   optimizer.apply_gradients(zip(grads, var_list))
  >>> print("after training:", optimizer.iterations.numpy())
  after training: 1

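  Migrating by manually creating and incrementing a step variable instead (a
  minimal sketch of option 2 above; `my_step` is a hypothetical name):

  ```python
  my_step = tf.Variable(0, dtype=tf.int64, trainable=False)
  # ... after each training step:
  my_step.assign_add(1)
  ```
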
  @end_compatibility
  """
  graph = graph or ops.get_default_graph()
  global_step_tensor = None
  global_step_tensors = graph.get_collection(ops.GraphKeys.GLOBAL_STEP)
  if len(global_step_tensors) == 1:
    global_step_tensor = global_step_tensors[0]
  elif not global_step_tensors:
    try:
      global_step_tensor = graph.get_tensor_by_name('global_step:0')
    except KeyError:
      return None
  else:
    logging.error('Multiple tensors in global_step collection.')
    return None

  assert_global_step(global_step_tensor)
  return global_step_tensor


@tf_export(v1=['train.create_global_step'])
def create_global_step(graph=None):
163  """Create global step tensor in graph.

  Args:
    graph: The graph in which to create the global step tensor. If missing, use
      default graph.

  Returns:
    Global step tensor.

  Raises:
    ValueError: if global step tensor is already defined.

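  A minimal usage sketch (TF1 graph mode):

  ```python
  with tf.Graph().as_default():
    step = tf.compat.v1.train.create_global_step()  # int64 scalar, starts at 0
    # A second create_global_step() call in this graph raises ValueError.
  ```
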
  @compatibility(TF2)
  With the deprecation of global graphs, TF no longer tracks variables in
  collections. In other words, there are no global variables in TF2. Thus, the
  global step functions (`get_or_create_global_step`, `create_global_step`,
  `get_global_step`) have been removed. You have two options for migrating:

  1. Create a Keras optimizer, which generates an `iterations` variable. This
     variable is automatically incremented when calling `apply_gradients`.
  2. Manually create and increment a `tf.Variable`.

  Below is an example of migrating away from using a global step to using a
  Keras optimizer:

  Define a dummy model and loss:

  >>> def compute_loss(x):
  ...   v = tf.Variable(3.0)
  ...   y = x * v
  ...   loss = x * 5 - x * v
  ...   return loss, [v]

  Before migrating:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   x = tf.compat.v1.placeholder(tf.float32, [])
  ...   loss, var_list = compute_loss(x)
  ...   global_step = tf.compat.v1.train.create_global_step()
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  ...   train_op = optimizer.minimize(loss, global_step, var_list)
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run(global_init)
  >>> print("before training:", sess.run(global_step))
  before training: 0
  >>> sess.run(train_op, feed_dict={x: 3})
  >>> print("after training:", sess.run(global_step))
  after training: 1

  Migrating to a Keras optimizer:

  >>> optimizer = tf.keras.optimizers.SGD(.01)
  >>> print("before training:", optimizer.iterations.numpy())
  before training: 0
  >>> with tf.GradientTape() as tape:
  ...   loss, var_list = compute_loss(3)
  ...   grads = tape.gradient(loss, var_list)
  ...   optimizer.apply_gradients(zip(grads, var_list))
  >>> print("after training:", optimizer.iterations.numpy())
  after training: 1

  @end_compatibility
  """
  graph = graph or ops.get_default_graph()
  if get_global_step(graph) is not None:
    raise ValueError('"global_step" already exists.')
  if context.executing_eagerly():
    with ops.device('cpu:0'):
      return variable_scope.get_variable(
          ops.GraphKeys.GLOBAL_STEP,
          shape=[],
          dtype=dtypes.int64,
          initializer=init_ops.zeros_initializer(),
          trainable=False,
          aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
          collections=[
              ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP
          ])
  # Create in the proper graph and at the base name_scope.
  with graph.as_default() as g, g.name_scope(None):
    return variable_scope.get_variable(
        ops.GraphKeys.GLOBAL_STEP,
        shape=[],
        dtype=dtypes.int64,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])


@tf_export(v1=['train.get_or_create_global_step'])
def get_or_create_global_step(graph=None):
257  """Returns and create (if necessary) the global step tensor.

  Args:
    graph: The graph in which to create the global step tensor. If missing, use
      default graph.

  Returns:
    The global step tensor.

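  A minimal sketch (TF1 graph mode); repeated calls return the same variable:

  ```python
  with tf.Graph().as_default():
    a = tf.compat.v1.train.get_or_create_global_step()
    b = tf.compat.v1.train.get_or_create_global_step()
    assert a is b
  ```
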
  @compatibility(TF2)
  With the deprecation of global graphs, TF no longer tracks variables in
  collections. In other words, there are no global variables in TF2. Thus, the
  global step functions (`get_or_create_global_step`, `create_global_step`,
  `get_global_step`) have been removed. You have two options for migrating:

  1. Create a Keras optimizer, which generates an `iterations` variable. This
     variable is automatically incremented when calling `apply_gradients`.
  2. Manually create and increment a `tf.Variable`.

  Below is an example of migrating away from using a global step to using a
  Keras optimizer:

  Define a dummy model and loss:

  >>> def compute_loss(x):
  ...   v = tf.Variable(3.0)
  ...   y = x * v
  ...   loss = x * 5 - x * v
  ...   return loss, [v]

  Before migrating:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   x = tf.compat.v1.placeholder(tf.float32, [])
  ...   loss, var_list = compute_loss(x)
  ...   global_step = tf.compat.v1.train.get_or_create_global_step()
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  ...   train_op = optimizer.minimize(loss, global_step, var_list)
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run(global_init)
  >>> print("before training:", sess.run(global_step))
  before training: 0
  >>> sess.run(train_op, feed_dict={x: 3})
  >>> print("after training:", sess.run(global_step))
  after training: 1

  Migrating to a Keras optimizer:

  >>> optimizer = tf.keras.optimizers.SGD(.01)
  >>> print("before training:", optimizer.iterations.numpy())
  before training: 0
  >>> with tf.GradientTape() as tape:
  ...   loss, var_list = compute_loss(3)
  ...   grads = tape.gradient(loss, var_list)
  ...   optimizer.apply_gradients(zip(grads, var_list))
  >>> print("after training:", optimizer.iterations.numpy())
  after training: 1

  @end_compatibility
  """
  graph = graph or ops.get_default_graph()
  global_step_tensor = get_global_step(graph)
  if global_step_tensor is None:
    global_step_tensor = create_global_step(graph)
  return global_step_tensor


@tf_export(v1=['train.assert_global_step'])
def assert_global_step(global_step_tensor):
  """Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.

  Args:
    global_step_tensor: `Tensor` to test.

  Raises:
    TypeError: If `global_step_tensor` is not a `Variable` or `Tensor`, has a
      non-integer dtype, or has a fully-defined non-scalar shape.
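
  A minimal sketch of what passes and what fails (illustrative):

  ```python
  tf.compat.v1.train.assert_global_step(
      tf.Variable(0, dtype=tf.int64, trainable=False))  # passes: scalar int
  tf.compat.v1.train.assert_global_step(
      tf.Variable(0.5))  # raises TypeError: non-integer dtype
  ```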
332  """
333  if not (isinstance(global_step_tensor, variables.Variable) or
334          isinstance(global_step_tensor, ops.Tensor) or
335          resource_variable_ops.is_resource_variable(global_step_tensor)):
336    raise TypeError('Existing "global_step" must be a Variable or Tensor: %s.' %
337                    global_step_tensor)
338
339  if not global_step_tensor.dtype.base_dtype.is_integer:
340    raise TypeError('Existing "global_step" does not have integer type: %s' %
341                    global_step_tensor.dtype)
342
343  if (global_step_tensor.get_shape().ndims != 0 and
344      global_step_tensor.get_shape().is_fully_defined()):
345    raise TypeError('Existing "global_step" is not scalar: %s' %
346                    global_step_tensor.get_shape())
347
348
349def _get_global_step_read(graph=None):
350  """Gets global step read tensor in graph.
351
352  Args:
353    graph: The graph in which to create the global step read tensor. If missing,
354      use default graph.
355
356  Returns:
357    Global step read tensor.
358
359  Raises:
360    RuntimeError: if multiple items found in collection GLOBAL_STEP_READ_KEY.
361  """
  graph = graph or ops.get_default_graph()
  global_step_read_tensors = graph.get_collection(GLOBAL_STEP_READ_KEY)
  if len(global_step_read_tensors) > 1:
    raise RuntimeError('There are multiple items in collection {}. '
                       'There should be only one.'.format(GLOBAL_STEP_READ_KEY))

  if len(global_step_read_tensors) == 1:
    return global_step_read_tensors[0]
  return None


def _get_or_create_global_step_read(graph=None):
374  """Gets or creates global step read tensor in graph.
375
376  Args:
377    graph: The graph in which to create the global step read tensor. If missing,
378      use default graph.
379
380  Returns:
381    Global step read tensor if there is global_step_tensor else return None.
382  """
  graph = graph or ops.get_default_graph()
  global_step_read_tensor = _get_global_step_read(graph)
  if global_step_read_tensor is not None:
    return global_step_read_tensor
  global_step_tensor = get_global_step(graph)
  if global_step_tensor is None:
    return None
  # Adding zero creates a copy of the variable's value as a plain Tensor.
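  # (Illustrative: for an integer variable `v`, `v + 0` yields an ordinary
  # Tensor snapshot of `v`'s value, whereas `v` itself aliases the mutable
  # variable.)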
  with graph.as_default() as g, g.name_scope(None):
    with g.name_scope(global_step_tensor.op.name + '/'):
      # Use initialized_value to ensure that global_step is initialized before
      # this tensor is read. This is needed because, for example, Estimator
      # builds every model_fn under a dependency on global_step_read_tensor.
      global_step_value = global_step_tensor.initialized_value() if isinstance(
          global_step_tensor, variables.Variable) else global_step_tensor
      global_step_read_tensor = global_step_value + 0
      ops.add_to_collection(GLOBAL_STEP_READ_KEY, global_step_read_tensor)
  return _get_global_step_read(graph)


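# A rough usage sketch for `_increment_global_step` (illustrative; TF1 graph
# mode, hypothetical names):
#
#   step = tf.compat.v1.train.get_or_create_global_step()
#   inc_op = _increment_global_step(1)
#   # Each sess.run(inc_op) adds `1` to the global step.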
def _increment_global_step(increment, graph=None):
  """Creates an op that adds `increment` to the global step.

  A read of the current step is forced (via a control dependency on the global
  step read tensor) before the increment is applied.

  Args:
    increment: The amount to add to the global step tensor.
    graph: The graph in which to find the global step. If missing, use the
      default graph.

  Returns:
    An op that increments the global step tensor by `increment`.

  Raises:
    ValueError: if the global step tensor has not been created yet.
  """
  graph = graph or ops.get_default_graph()
  global_step_tensor = get_global_step(graph)
  if global_step_tensor is None:
    raise ValueError(
        'Global step tensor should be created by '
        'tf.train.get_or_create_global_step before calling increment.')
  global_step_read_tensor = _get_or_create_global_step_read(graph)
  with graph.as_default() as g, g.name_scope(None):
    with g.name_scope(global_step_tensor.op.name + '/'):
      with ops.control_dependencies([global_step_read_tensor]):
        return state_ops.assign_add(global_step_tensor, increment)
415