# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Smoke test for reading records from GCS to TensorFlow."""
import random
import sys
import time

import numpy as np
import tensorflow as tf
from tensorflow.core.example import example_pb2
from tensorflow.python.lib.io import file_io

flags = tf.compat.v1.app.flags
flags.DEFINE_string("gcs_bucket_url", "",
                    "The URL to the GCS bucket in which the temporary "
                    "tfrecord file is to be written and read, e.g., "
                    "gs://my-gcs-bucket/test-directory")
flags.DEFINE_integer("num_examples", 10, "Number of examples to generate")

FLAGS = flags.FLAGS


def _now_ms():
  """Return the current wall-clock time as integer milliseconds."""
  return int(round(time.time() * 1000))


def create_examples(num_examples, input_mean):
  """Create ExampleProto's containing data.

  Args:
    num_examples: Number of Example protos to generate.
    input_mean: Mean added to the random inputs; the "target" feature is
      the input minus this mean.

  Returns:
    A list of `example_pb2.Example` protos, each carrying an "id"
    (bytes), "target" (float) and "inputs" (float) feature.
  """
  ids = np.arange(num_examples).reshape([num_examples, 1])
  inputs = np.random.randn(num_examples, 1) + input_mean
  target = inputs - input_mean
  examples = []
  for row in range(num_examples):
    ex = example_pb2.Example()
    # BUGFIX: the original `bytes(ids[row, 0])` is a Python-2 leftover.
    # Under Python 3, `bytes(n)` for an integer n produces n zero bytes
    # (e.g. bytes(5) == b"\x00\x00\x00\x00\x00") instead of the decimal
    # id string, so the "id" feature was garbage. Encode the string form.
    ex.features.feature["id"].bytes_list.value.append(
        str(ids[row, 0]).encode("utf-8"))
    ex.features.feature["target"].float_list.value.append(target[row, 0])
    ex.features.feature["inputs"].float_list.value.append(inputs[row, 0])
    examples.append(ex)
  return examples


def create_dir_test():
  """Verifies file_io directory handling methods.

  Exercises create_dir, recursive_create_dir, is_directory, list_directory,
  rename and delete_recursively against the GCS bucket named by
  FLAGS.gcs_bucket_url, asserting on each result. Raises AssertionError on
  any mismatch.
  """
  # Test directory creation. The millisecond timestamp doubles as a
  # (sufficiently) unique directory-name suffix.
  starttime_ms = _now_ms()
  dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime_ms)
  print("Creating dir %s" % dir_name)
  file_io.create_dir(dir_name)
  elapsed_ms = _now_ms() - starttime_ms
  print("Created directory in: %d milliseconds" % elapsed_ms)

  # Check that the directory exists.
  dir_exists = file_io.is_directory(dir_name)
  assert dir_exists
  print("%s directory exists: %s" % (dir_name, dir_exists))

  # Test recursive directory creation.
  starttime_ms = _now_ms()
  recursive_dir_name = "%s/%s/%s" % (dir_name,
                                     "nested_dir1",
                                     "nested_dir2")
  print("Creating recursive dir %s" % recursive_dir_name)
  file_io.recursive_create_dir(recursive_dir_name)
  elapsed_ms = _now_ms() - starttime_ms
  print("Created directory recursively in: %d milliseconds" % elapsed_ms)

  # Check that the recursively-created directory exists.
  recursive_dir_exists = file_io.is_directory(recursive_dir_name)
  assert recursive_dir_exists
  print("%s directory exists: %s" % (recursive_dir_name, recursive_dir_exists))

  # Create some contents in the just created directory and list the contents.
  num_files = 10
  files_to_create = ["file_%d.txt" % n for n in range(num_files)]
  for file_num in files_to_create:
    file_name = "%s/%s" % (dir_name, file_num)
    print("Creating file %s." % file_name)
    file_io.write_string_to_file(file_name, "test file.")

  print("Listing directory %s." % dir_name)
  starttime_ms = _now_ms()
  directory_contents = file_io.list_directory(dir_name)
  print(directory_contents)
  elapsed_ms = _now_ms() - starttime_ms
  print("Listed directory %s in %s milliseconds" % (dir_name, elapsed_ms))
  # The listing must contain exactly the files written above plus the
  # nested directory created earlier (GCS reports it with a trailing "/").
  assert set(directory_contents) == set(files_to_create + ["nested_dir1/"])

  # Test directory renaming.
  dir_to_rename = "%s/old_dir" % dir_name
  new_dir_name = "%s/new_dir" % dir_name
  file_io.create_dir(dir_to_rename)
  assert file_io.is_directory(dir_to_rename)
  assert not file_io.is_directory(new_dir_name)

  starttime_ms = _now_ms()
  print("Will try renaming directory %s to %s" % (dir_to_rename, new_dir_name))
  file_io.rename(dir_to_rename, new_dir_name)
  elapsed_ms = _now_ms() - starttime_ms
  print("Renamed directory %s to %s in %s milliseconds" % (
      dir_to_rename, new_dir_name, elapsed_ms))
  assert not file_io.is_directory(dir_to_rename)
  assert file_io.is_directory(new_dir_name)

  # Test Delete directory recursively.
  print("Deleting directory recursively %s." % dir_name)
  starttime_ms = _now_ms()
  file_io.delete_recursively(dir_name)
  elapsed_ms = _now_ms() - starttime_ms
  dir_exists = file_io.is_directory(dir_name)
  assert not dir_exists
  print("Deleted directory recursively %s in %s milliseconds" % (
      dir_name, elapsed_ms))


def create_object_test():
  """Verifies file_io's object manipulation methods .

  Exercises write_string_to_file, get_matching_files (glob), file_exists,
  rename and delete_recursively on objects inside a fresh test directory
  under FLAGS.gcs_bucket_url, asserting on each result.
  """
  starttime_ms = _now_ms()
  dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime_ms)
  print("Creating dir %s." % dir_name)
  file_io.create_dir(dir_name)

  num_files = 5
  # Create files of 2 different patterns in this directory so that glob
  # matching can be checked to separate them.
  files_pattern_1 = ["%s/test_file_%d.txt" % (dir_name, n)
                     for n in range(num_files)]
  files_pattern_2 = ["%s/testfile%d.txt" % (dir_name, n)
                     for n in range(num_files)]

  starttime_ms = _now_ms()
  files_to_create = files_pattern_1 + files_pattern_2
  for file_name in files_to_create:
    print("Creating file %s." % file_name)
    file_io.write_string_to_file(file_name, "test file creation.")
  elapsed_ms = _now_ms() - starttime_ms
  print("Created %d files in %s milliseconds" % (
      len(files_to_create), elapsed_ms))

  # Listing files of pattern1.
  list_files_pattern = "%s/test_file*.txt" % dir_name
  print("Getting files matching pattern %s." % list_files_pattern)
  starttime_ms = _now_ms()
  files_list = file_io.get_matching_files(list_files_pattern)
  elapsed_ms = _now_ms() - starttime_ms
  print("Listed files in %s milliseconds" % elapsed_ms)
  print(files_list)
  assert set(files_list) == set(files_pattern_1)

  # Listing files of pattern2.
  list_files_pattern = "%s/testfile*.txt" % dir_name
  print("Getting files matching pattern %s." % list_files_pattern)
  starttime_ms = _now_ms()
  files_list = file_io.get_matching_files(list_files_pattern)
  elapsed_ms = _now_ms() - starttime_ms
  print("Listed files in %s milliseconds" % elapsed_ms)
  print(files_list)
  assert set(files_list) == set(files_pattern_2)

  # Test renaming file.
  file_to_rename = "%s/oldname.txt" % dir_name
  file_new_name = "%s/newname.txt" % dir_name
  file_io.write_string_to_file(file_to_rename, "test file.")
  assert file_io.file_exists(file_to_rename)
  assert not file_io.file_exists(file_new_name)

  print("Will try renaming file %s to %s" % (file_to_rename, file_new_name))
  starttime_ms = _now_ms()
  file_io.rename(file_to_rename, file_new_name)
  elapsed_ms = _now_ms() - starttime_ms
  print("File %s renamed to %s in %s milliseconds" % (
      file_to_rename, file_new_name, elapsed_ms))
  assert not file_io.file_exists(file_to_rename)
  assert file_io.file_exists(file_new_name)

  # Delete directory.
  print("Deleting directory %s." % dir_name)
  file_io.delete_recursively(dir_name)


def main(argv):
  """Runs the GCS smoke test: tfrecord round-trip, then file_io checks.

  Exits with status 1 on any failure (bad bucket URL, record-count
  mismatch, or a missing expected OutOfRangeError).
  """
  del argv  # Unused.

  # Sanity check on the GCS bucket URL.
  if not FLAGS.gcs_bucket_url or not FLAGS.gcs_bucket_url.startswith("gs://"):
    print("ERROR: Invalid GCS bucket URL: \"%s\"" % FLAGS.gcs_bucket_url)
    sys.exit(1)

  # Generate random tfrecord path name.
  input_path = FLAGS.gcs_bucket_url + "/"
  input_path += "".join(random.choice("0123456789ABCDEF") for i in range(8))
  input_path += ".tfrecord"
  print("Using input path: %s" % input_path)

  # Verify that writing to the records file in GCS works.
  print("\n=== Testing writing and reading of GCS record file... ===")
  example_data = create_examples(FLAGS.num_examples, 5)
  with tf.io.TFRecordWriter(input_path) as hf:
    for e in example_data:
      hf.write(e.SerializeToString())

  print("Data written to: %s" % input_path)

  # Verify that reading from the tfrecord file works and that
  # tf_record_iterator works.
  record_iter = tf.compat.v1.python_io.tf_record_iterator(input_path)
  read_count = 0
  for _ in record_iter:
    read_count += 1
  print("Read %d records using tf_record_iterator" % read_count)

  if read_count != FLAGS.num_examples:
    print("FAIL: The number of records read from tf_record_iterator (%d) "
          "differs from the expected number (%d)" % (read_count,
                                                     FLAGS.num_examples))
    sys.exit(1)

  # Verify that running the read op in a session works.
  print("\n=== Testing TFRecordReader.read op in a session... ===")
  with tf.Graph().as_default():
    filename_queue = tf.compat.v1.train.string_input_producer([input_path],
                                                              num_epochs=1)
    reader = tf.compat.v1.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    with tf.compat.v1.Session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      # Local variables are needed for the num_epochs counter of the
      # string_input_producer queue.
      sess.run(tf.compat.v1.local_variables_initializer())
      tf.compat.v1.train.start_queue_runners()
      index = 0
      for _ in range(FLAGS.num_examples):
        print("Read record: %d" % index)
        sess.run(serialized_example)
        index += 1

      # Reading one more record should trigger an exception, because the
      # queue was limited to exactly one epoch of the input.
      try:
        sess.run(serialized_example)
        print("FAIL: Failed to catch the expected OutOfRangeError while "
              "reading one more record than is available")
        sys.exit(1)
      except tf.errors.OutOfRangeError:
        print("Successfully caught the expected OutOfRangeError while "
              "reading one more record than is available")

  create_dir_test()
  create_object_test()


if __name__ == "__main__":
  tf.compat.v1.app.run(main)