xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/single_loss_example.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A simple network to use in tests and examples."""
16
17from tensorflow.python.data.ops import dataset_ops
18from tensorflow.python.distribute import step_fn
19from tensorflow.python.distribute import strategy_test_lib
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import ops
22from tensorflow.python.layers import core
23from tensorflow.python.layers import normalization
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import math_ops
26
27
def single_loss_example(optimizer_fn, distribution, use_bias=False,
                        iterations_per_step=1):
  """Build a very simple network to use in tests and examples.

  Args:
    optimizer_fn: Zero-argument callable returning the optimizer to use.
    distribution: The `tf.distribute.Strategy` the step should run under.
    use_bias: Whether the single Dense layer has a bias term.
    iterations_per_step: How many iterations each step invocation runs.

  Returns:
    A `(single_loss_step, layer)` tuple; the layer is exposed so tests can
    inspect its kernel (and bias) variables.
  """

  def dataset_fn():
    # A constant [[1.]] input, repeated forever.
    return dataset_ops.Dataset.from_tensors([[1.]]).repeat()

  optimizer = optimizer_fn()
  layer = core.Dense(1, use_bias=use_bias)

  def loss_fn(ctx, x):
    # The step context is unused by this simple loss.
    del ctx
    # Squared error of the scalar prediction against a target of 1.
    error = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
    return error * error

  single_loss_step = step_fn.StandardSingleLossStep(
      dataset_fn, loss_fn, optimizer, distribution, iterations_per_step)

  # Layer is returned for inspecting the kernels in tests.
  return single_loss_step, layer
48
49
def minimize_loss_example(optimizer, use_bias=False, use_callable_loss=True):
  """Example of non-distribution-aware legacy code.

  Args:
    optimizer: The optimizer instance to minimize with (v1 or v2).
    use_bias: Whether the single Dense layer has a bias term.
    use_callable_loss: For v1 optimizers, whether to pass the loss as a
      callable rather than an already-evaluated tensor.

  Returns:
    A `(model_fn, dataset_fn, layer)` tuple; the layer is exposed so tests
    can inspect its variables.
  """

  def dataset_fn():
    # TODO(isaprykin): batch with drop_remainder causes shapes to be
    # fully defined for TPU.  Remove this when XLA supports dynamic shapes.
    return dataset_ops.Dataset.from_tensors([[1.]]).repeat().batch(
        1, drop_remainder=True)

  layer = core.Dense(1, use_bias=use_bias)

  def model_fn(x):
    """A very simple model written by the user."""

    def loss_fn():
      # Squared error of the scalar prediction against a target of 1.
      error = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
      return error * error

    # V2 optimizers require an explicit var_list; v1 optimizers accept
    # either a callable loss or a loss tensor.
    if strategy_test_lib.is_optimizer_v2_instance(optimizer):
      return optimizer.minimize(loss_fn, lambda: layer.trainable_variables)
    if use_callable_loss:
      return optimizer.minimize(loss_fn)
    return optimizer.minimize(loss_fn())

  return model_fn, dataset_fn, layer
76
77
def batchnorm_example(optimizer_fn,
                      batch_per_epoch=1,
                      momentum=0.9,
                      renorm=False,
                      update_ops_in_replica_mode=False):
  """Example of non-distribution-aware legacy code with batch normalization.

  Args:
    optimizer_fn: Zero-argument callable returning the optimizer to use.
    batch_per_epoch: Number of distinct [16, 8] batches in the dataset.
    momentum: Momentum passed to the BatchNormalization layer.
    renorm: Whether the BatchNormalization layer uses batch renormalization.
    update_ops_in_replica_mode: If True, the loss depends on the UPDATE_OPS
      collection so moving-statistics updates run inside the replica.

  Returns:
    A `(model_fn, dataset_fn, batchnorm)` tuple; the batchnorm layer is
    exposed so tests can inspect its moving statistics.
  """

  def dataset_fn():
    # input shape is [16, 8], input values are increasing in both dimensions.
    batches = []
    for z in range(batch_per_epoch):
      rows = []
      for x in range(16):
        rows.append([float(x * 8 + y + z * 100) for y in range(8)])
      batches.append(rows)
    return dataset_ops.Dataset.from_tensor_slices(batches).repeat()

  optimizer = optimizer_fn()
  batchnorm = normalization.BatchNormalization(
      renorm=renorm, momentum=momentum, fused=False)
  layer = core.Dense(1, use_bias=False)

  def model_fn(x):
    """A model that uses batchnorm."""

    def loss_fn():
      y = batchnorm(x, training=True)
      # Optionally force the batchnorm statistics updates to run before the
      # loss is computed, i.e. inside replica context.
      update_deps = (
          ops.get_collection(ops.GraphKeys.UPDATE_OPS)
          if update_ops_in_replica_mode else [])
      with ops.control_dependencies(update_deps):
        loss = math_ops.reduce_mean(
            math_ops.reduce_sum(layer(y)) - constant_op.constant(1.))
      # `x` and `y` will be fetched by the gradient computation, but not `loss`.
      return loss

    # V2 optimizers require an explicit var_list; v1 optimizers accept a
    # callable loss directly.
    if strategy_test_lib.is_optimizer_v2_instance(optimizer):
      return optimizer.minimize(loss_fn, lambda: layer.trainable_variables)

    # Callable loss.
    return optimizer.minimize(loss_fn)

  return model_fn, dataset_fn, batchnorm
118