# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test runner for TensorFlow tests."""

import os
import shlex
import sys
import time

from absl import app
from absl import flags

from google.protobuf import json_format
from google.protobuf import text_format
from tensorflow.core.util import test_log_pb2
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.tools.test import run_and_gather_logs_lib

# pylint: disable=g-import-not-at-top
# pylint: disable=g-bad-import-order
# pylint: disable=unused-import
# Note: cpuinfo and psutil are not installed for you in the TensorFlow
# OSS tree. They are installable via pip.
# These modules are not referenced in this file; importing them here turns a
# missing dependency into a logged error plus a zero exit status ("soft exit")
# instead of an ImportError traceback.
try:
  import cpuinfo
  import psutil
except ImportError as e:
  tf_logging.error("\n\n\nERROR: Unable to import necessary library: {}. "
                   "Issuing a soft exit.\n\n\n".format(e))
  sys.exit(0)
# pylint: enable=g-bad-import-order
# pylint: enable=unused-import

FLAGS = flags.FLAGS

flags.DEFINE_string("name", "", """Benchmark target identifier.""")
flags.DEFINE_string("test_name", "", """Test target to run.""")
flags.DEFINE_multi_string(
    "test_args", "", """\
Test arguments, space separated. May be specified more than once, in which case
the args are all appended.""")
flags.DEFINE_boolean("test_log_output_use_tmpdir", False,
                     "Whether to store the log output into tmpdir.")
flags.DEFINE_string("benchmark_type", "",
                    """Benchmark type (BenchmarkType enum string).""")
flags.DEFINE_string("compilation_mode", "",
                    """Mode used during this build (e.g. opt, dbg).""")
flags.DEFINE_string("cc_flags", "", """CC flags used during this build.""")
flags.DEFINE_string("test_log_output_dir", "",
                    """Directory for benchmark results output.""")
flags.DEFINE_string(
    "test_log_output_filename", "",
    """Filename to write output benchmark results to. If the filename
    is not specified, it will be automatically created.""")
flags.DEFINE_boolean("skip_export", False,
                     "Whether to skip exporting test results.")


def gather_build_configuration():
  """Builds a `BuildConfiguration` proto from the build-related flags.

  `mode` is taken verbatim from --compilation_mode. `cc_flags` holds the
  shell-split tokens of --cc_flags, minus every token that starts with "-i".

  Returns:
    A populated `test_log_pb2.BuildConfiguration`.
  """
  build_config = test_log_pb2.BuildConfiguration()
  build_config.mode = FLAGS.compilation_mode
  # Include all flags except includes.
  # NOTE(review): only tokens with a lowercase "-i" prefix are dropped (e.g.
  # -isystem, -iquote, -include). Uppercase -I flags are kept, and so is a
  # path passed as a separate token after -isystem. Confirm this is intended.
  cc_flags = [
      flag for flag in shlex.split(FLAGS.cc_flags) if not flag.startswith("-i")
  ]
  build_config.cc_flags.extend(cc_flags)
  return build_config


def main(unused_args):
  """Runs the test target and exports the gathered results.

  The results come from `run_and_gather_logs_lib.run_and_gather_logs`. If
  --skip_export is set, nothing further happens. Otherwise the build
  configuration and the current environment variables are attached to the
  results. Then either:
    * the results are printed as a text proto when --test_log_output_dir is
      empty, or
    * they are written as JSON to `<output dir>/<file name>.json`.

  Args:
    unused_args: Positional command-line arguments left over by absl (unused).
  """
  name = FLAGS.name
  test_name = FLAGS.test_name
  test_args = " ".join(FLAGS.test_args)
  benchmark_type = FLAGS.benchmark_type
  test_results, _ = run_and_gather_logs_lib.run_and_gather_logs(
      name,
      test_name=test_name,
      test_args=test_args,
      benchmark_type=benchmark_type,
      skip_processing_logs=FLAGS.skip_export)
  if FLAGS.skip_export:
    return

  # Additional bits we receive from bazel
  test_results.build_configuration.CopyFrom(gather_build_configuration())
  # Add os.environ data to test_results.
  test_results.run_configuration.env_vars.update(os.environ)

  if not FLAGS.test_log_output_dir:
    print(text_format.MessageToString(test_results))
    return

  if FLAGS.test_log_output_filename:
    file_name = FLAGS.test_log_output_filename
  else:
    # Derive a filesystem-safe name from the target label ("/" and ":" become
    # "_") and append a UTC timestamp.
    file_name = (
        name.strip("/").translate(str.maketrans("/:", "__")) +
        time.strftime("%Y%m%d%H%M%S", time.gmtime()))
  if FLAGS.test_log_output_use_tmpdir:
    tmpdir = test.get_temp_dir()
    output_path = os.path.join(tmpdir, FLAGS.test_log_output_dir, file_name)
  else:
    output_path = os.path.join(
        os.path.abspath(FLAGS.test_log_output_dir), file_name)
  json_test_results = json_format.MessageToJson(test_results)
  json_path = output_path + ".json"
  # Use a context manager so the handle is flushed and closed
  # deterministically instead of relying on garbage collection.
  with gfile.GFile(json_path, "w") as f:
    f.write(json_test_results)
  # Report the path that was actually written (including the ".json" suffix).
  tf_logging.info("Test results written to: %s", json_path)


if __name__ == "__main__":
  app.run(main)