# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for running TensorFlow tests and gathering their benchmark logs."""
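
# Example usage (a minimal sketch; the target label, arguments, and benchmark
# type below are illustrative placeholders, not real values from this repo):
#
#   from tensorflow.tools.test import run_and_gather_logs_lib
#
#   results, adjusted_name = run_and_gather_logs_lib.run_and_gather_logs(
#       name="//path/to:my_benchmark",
#       test_name="//path/to:my_benchmark",
#       test_args="",
#       benchmark_type="python_benchmark")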

import os
import re
import shlex
import subprocess
import tempfile
import time

from tensorflow.core.util import test_log_pb2
from tensorflow.python.platform import gfile
from tensorflow.tools.test import gpu_info_lib
from tensorflow.tools.test import system_info_lib


class MissingLogsError(Exception):
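  """Raised when no benchmark log files could be found for the test."""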
  pass


def get_git_commit_sha():
  """Get git commit SHA for this build.

  Attempt to get the SHA from environment variable GIT_COMMIT, which should
  be available on Jenkins build agents.

  Returns:
    SHA hash of the git commit used for the build, if available
  """

  return os.getenv("GIT_COMMIT")


def process_test_logs(name, test_name, test_args, benchmark_type,
                      start_time, run_time, log_files):
  """Gather test information and put it in a TestResults proto.

  Args:
    name: Benchmark target identifier.
    test_name: A unique bazel target, e.g. "//path/to:test"
    test_args: A list of strings, the arguments the target was run with.
    benchmark_type: A string representing the BenchmarkType enum; the
      benchmark type for this target.
    start_time: Test starting time (epoch)
    run_time:   Wall time that the test ran for
    log_files:  Paths to the log files

  Returns:
    A TestResults proto
  """

  results = test_log_pb2.TestResults()
  results.name = name
  results.target = test_name
  results.start_time = start_time
  results.run_time = run_time
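  # Map the benchmark_type string onto the TestResults.BenchmarkType enum;
  # the value is upper-cased to match the proto enum names.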
  results.benchmark_type = test_log_pb2.TestResults.BenchmarkType.Value(
      benchmark_type.upper())

  # Gather source code information
  git_sha = get_git_commit_sha()
  if git_sha:
    results.commit_id.hash = git_sha

  results.entries.CopyFrom(process_benchmarks(log_files))
  results.run_configuration.argument.extend(test_args)
  results.machine_configuration.CopyFrom(
      system_info_lib.gather_machine_configuration())
  return results


def process_benchmarks(log_files):
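  """Parse benchmark log files into a single BenchmarkEntries proto.

  Args:
    log_files: Paths to files containing serialized BenchmarkEntries protos.

  Returns:
    A test_log_pb2.BenchmarkEntries proto with the merged entries.

  Raises:
    Exception: If a log file cannot be fully parsed as BenchmarkEntries.
  """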
  benchmarks = test_log_pb2.BenchmarkEntries()
  for f in log_files:
    content = gfile.GFile(f, "rb").read()
    if benchmarks.MergeFromString(content) != len(content):
      raise Exception("Failed parsing benchmark entry from %s" % f)
  return benchmarks


def run_and_gather_logs(name,
                        test_name,
                        test_args,
                        benchmark_type,
                        skip_processing_logs=False):
  """Run the bazel test given by test_name.  Gather and return the logs.

  Args:
    name: Benchmark target identifier.
    test_name: A unique bazel target, e.g. "//path/to:test"
    test_args: A string containing all arguments to run the target with.
    benchmark_type: A string representing the BenchmarkType enum; the
      benchmark type for this target.
    skip_processing_logs: Whether to skip processing test results from log
      files.

  Returns:
    A tuple (test_results, test_adjusted_name), where
    test_results: A test_log_pb2.TestResults proto, or None if log processing
      is skipped.
    test_adjusted_name: Unique benchmark name, consisting of the benchmark
      name optionally followed by the GPU type.

  Raises:
    ValueError: If the test_name is not a valid target.
    subprocess.CalledProcessError: If the target itself fails.
    IOError: If there are problems gathering test log output from the test.
    MissingLogsError: If we couldn't find benchmark logs.
  """
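  # Accept only a single, fully-qualified bazel label of the form
  # "//package:target"; reject package wildcards and relative labels.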
  if not (test_name and test_name.startswith("//") and ".." not in test_name and
          not test_name.endswith(":") and not test_name.endswith(":all") and
          not test_name.endswith("...") and len(test_name.split(":")) == 2):
    raise ValueError("Expected test_name parameter with a unique test, e.g.: "
                     "--test_name=//path/to:test")
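  # Convert the label into a relative binary path,
  # e.g. "//path/to:test" -> "path/to/test".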
  test_executable = test_name.rstrip().strip("/").replace(":", "/")

  if gfile.Exists(os.path.join("bazel-bin", test_executable)):
    # Running in standalone mode from the root of the repository
    test_executable = os.path.join("bazel-bin", test_executable)
  else:
    # Hopefully running in sandboxed mode
    test_executable = os.path.join(".", test_executable)

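  # If the first GPU is a recognized Tesla model, append it to the benchmark
  # name so results from different GPU types are reported separately.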
  test_adjusted_name = name
  gpu_config = gpu_info_lib.gather_gpu_devices()
  if gpu_config:
    gpu_name = gpu_config[0].model
    gpu_short_name_match = re.search(r"Tesla (K40|K80|P100|V100)", gpu_name)
    if gpu_short_name_match:
      gpu_short_name = gpu_short_name_match.group(0)
      test_adjusted_name = name + "|" + gpu_short_name.replace(" ", "_")

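  # Benchmark logs are written to files under a temporary directory; build a
  # filesystem-safe file prefix from the adjusted benchmark name.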
  temp_directory = tempfile.mkdtemp(prefix="run_and_gather_logs")
  mangled_test_name = (
      test_adjusted_name.strip("/").replace("|", "_").replace("/", "_")
      .replace(":", "_"))
  test_file_prefix = os.path.join(temp_directory, mangled_test_name)
  test_file_prefix = "%s." % test_file_prefix

  try:
    if not gfile.Exists(test_executable):
      test_executable_py3 = test_executable + ".python3"
      if not gfile.Exists(test_executable_py3):
        raise ValueError("Executable does not exist: %s" % test_executable)
      test_executable = test_executable_py3
    test_args = shlex.split(test_args)

    # This key is defined in tf/core/util/reporter.h as
    # TestReporter::kTestReporterEnv.
    os.environ["TEST_REPORT_FILE_PREFIX"] = test_file_prefix
    start_time = time.time()
    subprocess.check_call([test_executable] + test_args)
    if skip_processing_logs:
      return None, test_adjusted_name
    run_time = time.time() - start_time
    log_files = gfile.Glob("{}*".format(test_file_prefix))
    if not log_files:
      raise MissingLogsError("No log files found at %s." % test_file_prefix)

    return (process_test_logs(
        test_adjusted_name,
        test_name=test_name,
        test_args=test_args,
        benchmark_type=benchmark_type,
        start_time=int(start_time),
        run_time=run_time,
        log_files=log_files), test_adjusted_name)

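  # Always remove the temporary log directory, even if the test run failed.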
  finally:
    try:
      gfile.DeleteRecursively(temp_directory)
    except OSError:
      pass