xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/examples/v1/debug_mnist_v1.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Demo of the tfdbg curses CLI: Locating the source of bad numerical values.
16
The neural network in this demo is largely based on the tutorial at:
18  tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
19
20But modifications are made so that problematic numerical values (infs and nans)
21appear in nodes of the graph during training.
22"""
23import argparse
24import sys
25import tempfile
26
27import tensorflow
28
29from tensorflow.python import debug as tf_debug
30
# Alias the TF1-compatible API surface; the rest of this file uses v1-only
# symbols (placeholders, sessions, initializable iterators) through `tf`.
tf = tensorflow.compat.v1

# Side length of the square input images; the flattened input dimension used
# below is IMAGE_SIZE * IMAGE_SIZE.
IMAGE_SIZE = 28
# Number of units produced by the "hidden" layer.
HIDDEN_SIZE = 500
# Number of units produced by the "output" layer (one per class label).
NUM_LABELS = 10
# Seed passed to dataset shuffling and weight initialization.
RAND_SEED = 42

# Parsed command-line flags. Populated in the __main__ block from parse_args()
# and read by main().
FLAGS = None
39
40
def parse_args(argv=None):
  """Parses commandline arguments.

  Args:
    argv: Optional list of argument strings to parse, excluding the program
      name. If None (the default), argparse falls back to sys.argv[1:], which
      matches the previous zero-argument behavior.

  Returns:
    A tuple (parsed, unparsed) of the parsed object and a group of unparsed
      arguments that did not match the parser.
  """
  parser = argparse.ArgumentParser()
  # Register a custom "bool" type so boolean flags accept an explicit value
  # (e.g. --debug=true). Only the case-insensitive string "true" maps to True;
  # any other value maps to False.
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--max_steps",
      type=int,
      default=10,
      help="Number of steps to run trainer.")
  parser.add_argument(
      "--train_batch_size",
      type=int,
      default=100,
      help="Batch size used during training.")
  parser.add_argument(
      "--learning_rate",
      type=float,
      default=0.025,
      help="Initial learning rate.")
  parser.add_argument(
      "--data_dir",
      type=str,
      default="/tmp/mnist_data",
      help="Directory for storing data")
  parser.add_argument(
      "--ui_type",
      type=str,
      default="curses",
      help="Command-line user interface type (curses | readline)")
  # For the boolean flags below, nargs="?" with const=True means that passing
  # the bare flag (e.g. --fake_data) yields True, while omitting it yields the
  # default of False.
  parser.add_argument(
      "--fake_data",
      type="bool",
      nargs="?",
      const=True,
      default=False,
      help="Use fake MNIST data for unit testing")
  parser.add_argument(
      "--debug",
      type="bool",
      nargs="?",
      const=True,
      default=False,
      help="Use debugger to track down bad values during training. "
      "Mutually exclusive with the --tensorboard_debug_address flag.")
  parser.add_argument(
      "--tensorboard_debug_address",
      type=str,
      default=None,
      help="Connect to the TensorBoard Debugger Plugin backend specified by "
      "the gRPC address (e.g., localhost:1234). Mutually exclusive with the "
      "--debug flag.")
  parser.add_argument(
      "--use_random_config_path",
      type="bool",
      nargs="?",
      const=True,
      default=False,
      help="""If set, set config file path to a random file in the temporary
      directory.""")
  # parse_known_args() tolerates unrecognized arguments and returns them as the
  # second tuple element, so the caller can forward them elsewhere.
  return parser.parse_known_args(argv)
106
107
def main(_):
  """Builds and trains the MNIST demo graph, optionally under a tfdbg wrapper.

  Configuration is read from the module-level FLAGS object. The graph
  deliberately uses a numerically unstable cross-entropy so that infs and nans
  show up during training and can be located with the debugger.

  Args:
    _: Unused positional argument (tf.app.run() invokes main with argv).

  Raises:
    ValueError: If both --debug and --tensorboard_debug_address are set.
  """
  # Validate the flag combination before doing any work, so that an invalid
  # invocation fails fast instead of after the data has been loaded and a
  # session has been created.
  if FLAGS.debug and FLAGS.tensorboard_debug_address:
    raise ValueError(
        "The --debug and --tensorboard_debug_address flags are mutually "
        "exclusive.")

  # Import data
  if FLAGS.fake_data:
    # Ten random examples; the same tensors serve as both train and test data.
    imgs = tf.random.uniform(
        maxval=256, shape=(10, IMAGE_SIZE, IMAGE_SIZE), dtype=tf.int32)
    labels = tf.random.uniform(maxval=NUM_LABELS, shape=(10,), dtype=tf.int32)
    mnist_train = imgs, labels
    mnist_test = imgs, labels
  else:
    mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()

  def format_example(imgs, labels):
    """Flattens and rescales images to [0, 1]; one-hot encodes the labels."""
    imgs = tf.reshape(imgs, [-1, IMAGE_SIZE * IMAGE_SIZE])
    imgs = tf.cast(imgs, tf.float32) / 255.0
    labels = tf.one_hot(labels, depth=NUM_LABELS, dtype=tf.float32)
    return imgs, labels

  ds_train = tf.data.Dataset.from_tensor_slices(mnist_train)
  ds_train = ds_train.shuffle(
      1000, seed=RAND_SEED).repeat().batch(FLAGS.train_batch_size)
  ds_train = ds_train.map(format_example)
  it_train = ds_train.make_initializable_iterator()

  # The whole test set is emitted as a single element, repeated indefinitely.
  ds_test = tf.data.Dataset.from_tensors(mnist_test).repeat()
  ds_test = ds_test.map(format_example)
  it_test = ds_test.make_initializable_iterator()

  sess = tf.InteractiveSession()

  # Create the MNIST neural network graph.

  # Input placeholders.
  with tf.name_scope("input"):
    # A string handle selects, per run() call, whether the graph reads from
    # the training iterator or the test iterator.
    handle = tf.placeholder(tf.string, shape=())

    iterator = tf.data.Iterator.from_string_handle(
        handle, (tf.float32, tf.float32),
        ((None, IMAGE_SIZE * IMAGE_SIZE), (None, NUM_LABELS)))

    x, y_ = iterator.get_next()

  def weight_variable(shape):
    """Create a weight variable with appropriate initialization."""
    initial = tf.truncated_normal(shape, stddev=0.1, seed=RAND_SEED)
    return tf.Variable(initial)

  def bias_variable(shape):
    """Create a bias variable with appropriate initialization."""
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

  def nn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu):
    """Reusable code for making a simple neural net layer."""
    # Adding a name scope ensures logical grouping of the layers in the graph.
    with tf.name_scope(layer_name):
      # This Variable will hold the state of the weights for the layer
      with tf.name_scope("weights"):
        weights = weight_variable([input_dim, output_dim])
      with tf.name_scope("biases"):
        biases = bias_variable([output_dim])
      with tf.name_scope("Wx_plus_b"):
        preactivate = tf.matmul(input_tensor, weights) + biases

      activations = act(preactivate)
      return activations

  hidden = nn_layer(x, IMAGE_SIZE**2, HIDDEN_SIZE, "hidden")
  logits = nn_layer(hidden, HIDDEN_SIZE, NUM_LABELS, "output", tf.identity)
  y = tf.nn.softmax(logits)

  with tf.name_scope("cross_entropy"):
    # The following line is the culprit of the bad numerical values that appear
    # during training of this graph. Log of zero gives inf, which is first seen
    # in the intermediate tensor "cross_entropy/Log:0" during the 4th run()
    # call. A multiplication of the inf values with zeros leads to nans,
    # which is first seen in "cross_entropy/mul:0".
    #
    # You can use the built-in, numerically-stable implementation to fix this
    # issue:
    #   diff = tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits)

    diff = -(y_ * tf.log(y))
    with tf.name_scope("total"):
      cross_entropy = tf.reduce_mean(diff)

  with tf.name_scope("train"):
    train_step = tf.train.AdamOptimizer(
        FLAGS.learning_rate).minimize(cross_entropy)

  with tf.name_scope("accuracy"):
    with tf.name_scope("correct_prediction"):
      correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    with tf.name_scope("accuracy"):
      accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  sess.run(tf.global_variables_initializer())
  sess.run(it_train.initializer)
  sess.run(it_test.initializer)
  train_handle = sess.run(it_train.string_handle())
  test_handle = sess.run(it_test.string_handle())

  if FLAGS.debug:
    if FLAGS.use_random_config_path:
      # Create a uniquely named, persistent temp file for the tfdbg config.
      # The with-block closes the underlying file descriptor right away, so no
      # descriptor is leaked; delete=False keeps the file itself on disk.
      with tempfile.NamedTemporaryFile(
          suffix=".tfdbg_config", delete=False) as config_file:
        config_file_path = config_file.name
    else:
      config_file_path = None
    sess = tf_debug.LocalCLIDebugWrapperSession(
        sess, ui_type=FLAGS.ui_type, config_file_path=config_file_path)
  elif FLAGS.tensorboard_debug_address:
    sess = tf_debug.TensorBoardDebugWrapperSession(
        sess, FLAGS.tensorboard_debug_address)

  # At this point, sess is a debug wrapper around the actual Session if
  # FLAGS.debug or FLAGS.tensorboard_debug_address is set. With FLAGS.debug,
  # calling run() will launch the CLI.
  for i in range(FLAGS.max_steps):
    acc = sess.run(accuracy, feed_dict={handle: test_handle})
    print("Accuracy at step %d: %s" % (i, acc))

    sess.run(train_step, feed_dict={handle: train_handle})
230
231
if __name__ == "__main__":
  # Populate the module-level FLAGS that main() reads; anything the parser did
  # not recognize is handed on to tf.app.run() behind the program name.
  FLAGS, remaining_argv = parse_args()
  app_argv = [sys.argv[0], *remaining_argv]
  # Run inside a fresh default graph so all ops created by main() land in it.
  with tf.Graph().as_default():
    tf.app.run(main=main, argv=app_argv)
236