# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Demo of the tfdbg curses CLI: Locating the source of bad numerical values with TF v2.
16
17This demo contains a classical example of a neural network for the mnist
18dataset, but modifications are made so that problematic numerical values (infs
19and nans) appear in nodes of the graph during training.
20"""
import argparse
import sys

from absl import app
import tensorflow.compat.v2 as tf

IMAGE_SIZE = 28
HIDDEN_SIZE = 500
NUM_LABELS = 10

# If we set the weights randomly, the model will converge normally about half
# the time. We need a seed to ensure that the bad numerical values issue
# appears.
RAND_SEED = 42

tf.compat.v1.enable_v2_behavior()

FLAGS = None


def parse_args():
  """Parses command-line arguments.

  Returns:
    A tuple (parsed, unparsed) of the parsed object and a group of unparsed
      arguments that did not match the parser.
  """
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--max_steps",
      type=int,
      default=10,
      help="Number of steps to run the trainer.")
  parser.add_argument(
      "--train_batch_size",
      type=int,
      default=100,
      help="Batch size used during training.")
  parser.add_argument(
      "--learning_rate",
      type=float,
      default=0.025,
      help="Initial learning rate.")
  parser.add_argument(
      "--data_dir",
      type=str,
      default="/tmp/mnist_data",
      help="Directory for storing data.")
  parser.add_argument(
      "--fake_data",
      type="bool",
      nargs="?",
      const=True,
      default=False,
      help="Use fake MNIST data for unit testing.")
  parser.add_argument(
      "--check_numerics",
      type="bool",
      nargs="?",
      const=True,
      default=False,
      help="Use tfdbg to track down bad values during training. "
      "Mutually exclusive with the --dump_dir flag.")
  parser.add_argument(
      "--dump_dir",
      type=str,
      default=None,
      help="Dump TensorFlow program debug data to the specified directory. "
      "The dumped data contains information regarding tf.function building, "
      "execution of ops and tf.functions, as well as their stack traces and "
      "associated source-code snapshots. "
      "Mutually exclusive with the --check_numerics flag.")
  parser.add_argument(
      "--dump_tensor_debug_mode",
      type=str,
      default="FULL_HEALTH",
      help="Mode for dumping tensor values. Options: NO_TENSOR, CURT_HEALTH, "
      "CONCISE_HEALTH, SHAPE, FULL_HEALTH. This is relevant only when "
      "--dump_dir is set.")
  # TODO(cais): Add more tensor debug mode strings once they are supported.
  parser.add_argument(
      "--dump_circular_buffer_size",
      type=int,
      default=-1,
      help="Size of the circular buffer used to dump execution events. "
      "A value <= 0 disables the circular-buffer behavior and causes "
      "all instrumented tensor values to be dumped. "
      "This is relevant only when --dump_dir is set.")
  parser.add_argument(
      "--use_random_config_path",
      type="bool",
      nargs="?",
      const=True,
      default=False,
      help="If set, set the config file path to a random file in the "
      "temporary directory.")
  return parser.parse_known_args()
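

# Example invocations (a usage sketch; the dump directory below is just an
# illustrative path):
#
#   python debug_mnist_v2.py --check_numerics
#   python debug_mnist_v2.py --dump_dir /tmp/tfdbg2_logdir \
#       --dump_tensor_debug_mode FULL_HEALTH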


def main(_):
  if FLAGS.check_numerics and FLAGS.dump_dir:
    raise ValueError(
        "The --check_numerics and --dump_dir flags are mutually "
        "exclusive.")
  if FLAGS.check_numerics:
    tf.debugging.enable_check_numerics()
  elif FLAGS.dump_dir:
    tf.debugging.experimental.enable_dump_debug_info(
        FLAGS.dump_dir,
        tensor_debug_mode=FLAGS.dump_tensor_debug_mode,
        circular_buffer_size=FLAGS.dump_circular_buffer_size)
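    # The dumped data can be inspected after the run with the Debugger V2
    # plugin of TensorBoard (a usage sketch; the logdir is whatever was
    # passed as --dump_dir):
    #
    #   tensorboard --logdir /tmp/tfdbg2_logdir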

  # Import data.
  if FLAGS.fake_data:
    imgs = tf.random.uniform(
        maxval=256, shape=(1000, IMAGE_SIZE, IMAGE_SIZE), dtype=tf.int32)
    labels = tf.random.uniform(
        maxval=NUM_LABELS, shape=(1000,), dtype=tf.int32)
    mnist_train = imgs, labels
    mnist_test = imgs, labels
  else:
    mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()

  @tf.function
  def format_example(imgs, labels):
    """Formats each training and test example to work with our model."""
    imgs = tf.reshape(imgs, [-1, IMAGE_SIZE * IMAGE_SIZE])
    imgs = tf.cast(imgs, tf.float32) / 255.0
    labels = tf.one_hot(labels, depth=NUM_LABELS, dtype=tf.float32)
    return imgs, labels

  train_ds = tf.data.Dataset.from_tensor_slices(mnist_train).shuffle(
      FLAGS.train_batch_size * FLAGS.max_steps,
      seed=RAND_SEED).batch(FLAGS.train_batch_size)
  train_ds = train_ds.map(format_example)

  test_ds = tf.data.Dataset.from_tensor_slices(mnist_test).repeat().batch(
      len(mnist_test[0]))
  test_ds = test_ds.map(format_example)
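
  # At this point each dataset element is an (imgs, labels) pair: imgs has
  # shape (batch, IMAGE_SIZE * IMAGE_SIZE) with float32 values in [0, 1],
  # and labels has shape (batch, NUM_LABELS) as one-hot vectors. The test
  # dataset uses the entire test set as a single batch.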

  def get_dense_weights(input_dim, output_dim):
    """Initializes the parameters for a single dense layer."""
    initial_kernel = tf.keras.initializers.TruncatedNormal(
        mean=0.0, stddev=0.1, seed=RAND_SEED)
    kernel = tf.Variable(initial_kernel([input_dim, output_dim]))
    bias = tf.Variable(tf.constant(0.1, shape=[output_dim]))

    return kernel, bias

  @tf.function
  def dense_layer(weights, input_tensor, act=tf.nn.relu):
    """Runs the forward computation for a single dense layer."""
    kernel, bias = weights
    preactivate = tf.matmul(input_tensor, kernel) + bias

    activations = act(preactivate)
    return activations

  # Initialize the model's variables.
  hidden_weights = get_dense_weights(IMAGE_SIZE**2, HIDDEN_SIZE)
  output_weights = get_dense_weights(HIDDEN_SIZE, NUM_LABELS)
  variables = hidden_weights + output_weights

  @tf.function
  def model(x):
    """Feed-forward function of the model.

    Args:
      x: a (?, 28*28) tensor consisting of the feature inputs for a batch of
        examples.

    Returns:
      A (?, 10) tensor containing the class scores for each example.
    """
    hidden_act = dense_layer(hidden_weights, x)
    logits_act = dense_layer(output_weights, hidden_act, tf.identity)
    y = tf.nn.softmax(logits_act)
    return y

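  # Note: model() returns softmax probabilities rather than logits, and the
  # loss() function below takes tf.math.log of them. This log-of-softmax
  # pattern is what produces the infs/nans this demo is designed to exhibit.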
  @tf.function
  def loss(probs, labels):
    """Calculates cross entropy loss.

    Args:
      probs: Class probabilities predicted by the model. The shape is expected
        to be (?, 10).
      labels: Truth labels for the classes, as one-hot encoded vectors. The
        shape is expected to be the same as `probs`.

    Returns:
      A scalar loss tensor.
    """
    diff = -labels * tf.math.log(probs)
    loss = tf.reduce_mean(diff)
    return loss

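  # A numerically stable alternative (a sketch; deliberately NOT used here,
  # since the instability is the point of this demo) computes cross entropy
  # directly from logits:
  #
  #   loss_val = tf.reduce_mean(
  #       tf.nn.softmax_cross_entropy_with_logits(
  #           labels=labels, logits=logits))
  #
  # where `logits` would be the pre-softmax activations, i.e. model() would
  # return logits_act instead of tf.nn.softmax(logits_act).
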
  train_batches = iter(train_ds)
  test_batches = iter(test_ds)
  optimizer = tf.optimizers.Adam(learning_rate=FLAGS.learning_rate)
  for i in range(FLAGS.max_steps):
    x_train, y_train = next(train_batches)
    x_test, y_test = next(test_batches)

    # Train step.
    with tf.GradientTape() as tape:
      y = model(x_train)
      loss_val = loss(y, y_train)
    grads = tape.gradient(loss_val, variables)

    optimizer.apply_gradients(zip(grads, variables))

    # Evaluation step.
    y = model(x_test)
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_test, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Accuracy at step %d: %s" % (i, accuracy.numpy()))
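
  # With --check_numerics enabled, the first op that produces an inf or nan
  # raises an error identifying the op and its stack trace. With --dump_dir,
  # the dumped records can be browsed after the fact, e.g. with TensorBoard's
  # Debugger V2 plugin.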


if __name__ == "__main__":
  FLAGS, unparsed = parse_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)