# xref: /aosp_15_r20/external/tensorflow/tensorflow/tools/gcs_test/python/gcs_smoke.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Smoke test for reading records from GCS to TensorFlow."""
import random
import sys
import time

import numpy as np
import tensorflow as tf
from tensorflow.core.example import example_pb2
from tensorflow.python.lib.io import file_io
24
# Command-line flags configuring the smoke test.
flags = tf.compat.v1.app.flags
flags.DEFINE_string("gcs_bucket_url", "",
                    "The URL to the GCS bucket in which the temporary "
                    "tfrecord file is to be written and read, e.g., "
                    "gs://my-gcs-bucket/test-directory")
flags.DEFINE_integer("num_examples", 10, "Number of examples to generate")

# Parsed flag values; populated when tf.compat.v1.app.run() parses argv.
FLAGS = flags.FLAGS
33
34
def create_examples(num_examples, input_mean):
  """Create ExampleProtos containing synthetic regression data.

  Args:
    num_examples: Number of `Example` protos to generate.
    input_mean: Mean offset added to the random "inputs" values.

  Returns:
    A list of `example_pb2.Example` protos, each with an "id" bytes
    feature (the decimal row index), an "inputs" float feature, and a
    "target" float feature where target == inputs - input_mean.
  """
  ids = np.arange(num_examples).reshape([num_examples, 1])
  inputs = np.random.randn(num_examples, 1) + input_mean
  target = inputs - input_mean
  examples = []
  for row in range(num_examples):
    ex = example_pb2.Example()
    # Encode the id as its decimal string. The previous bytes(ids[row, 0])
    # was a Python 2 idiom: under Python 3, bytes() on a numpy integer
    # scalar yields the scalar's raw byte buffer rather than "0", "1", ...
    ex.features.feature["id"].bytes_list.value.append(
        str(ids[row, 0]).encode("utf-8"))
    ex.features.feature["target"].float_list.value.append(target[row, 0])
    ex.features.feature["inputs"].float_list.value.append(inputs[row, 0])
    examples.append(ex)
  return examples
48
49
def create_dir_test():
  """Verifies file_io directory handling methods."""

  def _now_ms():
    # Wall-clock time in whole milliseconds, used for timing reports and
    # to build a unique test-directory name.
    return int(round(time.time() * 1000))

  # Test directory creation.
  t0 = _now_ms()
  dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, t0)
  print("Creating dir %s" % dir_name)
  file_io.create_dir(dir_name)
  print("Created directory in: %d milliseconds" % (_now_ms() - t0))

  # The directory we just created must now be visible.
  dir_exists = file_io.is_directory(dir_name)
  assert dir_exists
  print("%s directory exists: %s" % (dir_name, dir_exists))

  # Test recursive directory creation.
  t0 = _now_ms()
  recursive_dir_name = "%s/%s/%s" % (dir_name, "nested_dir1", "nested_dir2")
  print("Creating recursive dir %s" % recursive_dir_name)
  file_io.recursive_create_dir(recursive_dir_name)
  print("Created directory recursively in: %d milliseconds" %
        (_now_ms() - t0))

  # The nested directory must now be visible.
  recursive_dir_exists = file_io.is_directory(recursive_dir_name)
  assert recursive_dir_exists
  print("%s directory exists: %s" % (recursive_dir_name, recursive_dir_exists))

  # Create some contents in the just created directory and list the contents.
  num_files = 10
  files_to_create = ["file_%d.txt" % n for n in range(num_files)]
  for base_name in files_to_create:
    file_name = "%s/%s" % (dir_name, base_name)
    print("Creating file %s." % file_name)
    file_io.write_string_to_file(file_name, "test file.")

  print("Listing directory %s." % dir_name)
  t0 = _now_ms()
  directory_contents = file_io.list_directory(dir_name)
  print(directory_contents)
  print("Listed directory %s in %s milliseconds" % (dir_name, _now_ms() - t0))
  # The listing must contain exactly the files plus the nested dir entry.
  assert set(directory_contents) == set(files_to_create + ["nested_dir1/"])

  # Test directory renaming.
  dir_to_rename = "%s/old_dir" % dir_name
  new_dir_name = "%s/new_dir" % dir_name
  file_io.create_dir(dir_to_rename)
  assert file_io.is_directory(dir_to_rename)
  assert not file_io.is_directory(new_dir_name)

  t0 = _now_ms()
  print("Will try renaming directory %s to %s" % (dir_to_rename, new_dir_name))
  file_io.rename(dir_to_rename, new_dir_name)
  print("Renamed directory %s to %s in %s milliseconds" % (
      dir_to_rename, new_dir_name, _now_ms() - t0))
  # The rename must have moved, not copied, the directory.
  assert not file_io.is_directory(dir_to_rename)
  assert file_io.is_directory(new_dir_name)

  # Test Delete directory recursively.
  print("Deleting directory recursively %s." % dir_name)
  t0 = _now_ms()
  file_io.delete_recursively(dir_name)
  elapsed = _now_ms() - t0
  assert not file_io.is_directory(dir_name)
  print("Deleted directory recursively %s in %s milliseconds" % (
      dir_name, elapsed))
122
123
def create_object_test():
  """Verifies file_io object (file) manipulation methods."""

  def _now_ms():
    # Wall-clock time in whole milliseconds, used for timing reports and
    # to build a unique test-directory name.
    return int(round(time.time() * 1000))

  t0 = _now_ms()
  dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, t0)
  print("Creating dir %s." % dir_name)
  file_io.create_dir(dir_name)

  # Create files of 2 different patterns in this directory.
  num_files = 5
  files_pattern_1 = ["%s/test_file_%d.txt" % (dir_name, n)
                     for n in range(num_files)]
  files_pattern_2 = ["%s/testfile%d.txt" % (dir_name, n)
                     for n in range(num_files)]
  files_to_create = files_pattern_1 + files_pattern_2

  t0 = _now_ms()
  for file_name in files_to_create:
    print("Creating file %s." % file_name)
    file_io.write_string_to_file(file_name, "test file creation.")
  print("Created %d files in %s milliseconds" % (
      len(files_to_create), _now_ms() - t0))

  # List each glob pattern and check it matches exactly its own files.
  for pattern_suffix, expected_files in (("test_file*.txt", files_pattern_1),
                                         ("testfile*.txt", files_pattern_2)):
    list_files_pattern = "%s/%s" % (dir_name, pattern_suffix)
    print("Getting files matching pattern %s." % list_files_pattern)
    t0 = _now_ms()
    files_list = file_io.get_matching_files(list_files_pattern)
    print("Listed files in %s milliseconds" % (_now_ms() - t0))
    print(files_list)
    assert set(files_list) == set(expected_files)

  # Test renaming file.
  file_to_rename = "%s/oldname.txt" % dir_name
  file_new_name = "%s/newname.txt" % dir_name
  file_io.write_string_to_file(file_to_rename, "test file.")
  assert file_io.file_exists(file_to_rename)
  assert not file_io.file_exists(file_new_name)

  print("Will try renaming file %s to %s" % (file_to_rename, file_new_name))
  t0 = _now_ms()
  file_io.rename(file_to_rename, file_new_name)
  print("File %s renamed to %s in %s milliseconds" % (
      file_to_rename, file_new_name, _now_ms() - t0))
  # The rename must have moved, not copied, the file.
  assert not file_io.file_exists(file_to_rename)
  assert file_io.file_exists(file_new_name)

  # Delete directory.
  print("Deleting directory %s." % dir_name)
  file_io.delete_recursively(dir_name)
187
def main(argv):
  """Runs the GCS smoke tests against the bucket in FLAGS.gcs_bucket_url.

  Writes a temporary tfrecord file to GCS, reads it back via
  tf_record_iterator and via a TFRecordReader op in a session, then
  exercises the file_io directory and object tests. Exits with status 1
  on any failure.
  """
  del argv  # Unused.

  # Sanity check on the GCS bucket URL.
  if not FLAGS.gcs_bucket_url or not FLAGS.gcs_bucket_url.startswith("gs://"):
    print("ERROR: Invalid GCS bucket URL: \"%s\"" % FLAGS.gcs_bucket_url)
    sys.exit(1)

  # Generate random tfrecord path name.
  input_path = FLAGS.gcs_bucket_url + "/"
  input_path += "".join(random.choice("0123456789ABCDEF") for _ in range(8))
  input_path += ".tfrecord"
  print("Using input path: %s" % input_path)

  # Verify that writing to the records file in GCS works.
  print("\n=== Testing writing and reading of GCS record file... ===")
  example_data = create_examples(FLAGS.num_examples, 5)
  with tf.io.TFRecordWriter(input_path) as hf:
    for e in example_data:
      hf.write(e.SerializeToString())
  # Report success only after the writer has closed and flushed to GCS.
  print("Data written to: %s" % input_path)

  # Verify that reading from the tfrecord file works and that
  # tf_record_iterator works.
  record_iter = tf.compat.v1.python_io.tf_record_iterator(input_path)
  read_count = 0
  for _ in record_iter:
    read_count += 1
  print("Read %d records using tf_record_iterator" % read_count)

  if read_count != FLAGS.num_examples:
    print("FAIL: The number of records read from tf_record_iterator (%d) "
          "differs from the expected number (%d)" % (read_count,
                                                     FLAGS.num_examples))
    sys.exit(1)

  # Verify that running the read op in a session works.
  print("\n=== Testing TFRecordReader.read op in a session... ===")
  with tf.Graph().as_default():
    # num_epochs=1 makes the queue raise OutOfRangeError once exhausted,
    # which is what the final over-read check below relies on.
    filename_queue = tf.compat.v1.train.string_input_producer([input_path],
                                                              num_epochs=1)
    reader = tf.compat.v1.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    with tf.compat.v1.Session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      sess.run(tf.compat.v1.local_variables_initializer())
      tf.compat.v1.train.start_queue_runners()
      index = 0
      for _ in range(FLAGS.num_examples):
        print("Read record: %d" % index)
        sess.run(serialized_example)
        index += 1

      # Reading one more record should trigger an exception.
      try:
        sess.run(serialized_example)
        print("FAIL: Failed to catch the expected OutOfRangeError while "
              "reading one more record than is available")
        sys.exit(1)
      except tf.errors.OutOfRangeError:
        print("Successfully caught the expected OutOfRangeError while "
              "reading one more record than is available")

  create_dir_test()
  create_object_test()
255
256
# Parse flags from sys.argv and dispatch to main() via the TF v1 app runner.
if __name__ == "__main__":
  tf.compat.v1.app.run(main)