# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities to run benchmarks."""
import math
import numbers
import os
import re
import sys
import time
import types

from absl import app

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.util import test_log_pb2
from tensorflow.python.client import timeline
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


# When a subclass of the Benchmark class is created, it is added to
# the registry automatically
GLOBAL_BENCHMARK_REGISTRY = set()

# Environment variable that determines whether benchmarks are written.
# See also tensorflow/core/util/reporter.h TestReporter::kTestReporterEnv.
TEST_REPORTER_TEST_ENV = "TEST_REPORT_FILE_PREFIX"

# Environment variable that lets the TensorFlow runtime allocate a new
# threadpool for each benchmark.
OVERRIDE_GLOBAL_THREADPOOL = "TF_OVERRIDE_GLOBAL_THREADPOOL"


def _rename_function(f, arg_num, name):
52  """Rename the given function's name appears in the stack trace."""
  func_code = f.__code__
  if sys.version_info > (3, 8, 0, "alpha", 3):
    # Python3.8 / PEP570 added co_posonlyargcount argument to CodeType.
    new_code = types.CodeType(
        arg_num, func_code.co_posonlyargcount, 0, func_code.co_nlocals,
        func_code.co_stacksize, func_code.co_flags, func_code.co_code,
        func_code.co_consts, func_code.co_names, func_code.co_varnames,
        func_code.co_filename, name, func_code.co_firstlineno,
        func_code.co_lnotab, func_code.co_freevars, func_code.co_cellvars)
  else:
    new_code = types.CodeType(arg_num, 0, func_code.co_nlocals,
                              func_code.co_stacksize, func_code.co_flags,
                              func_code.co_code, func_code.co_consts,
                              func_code.co_names, func_code.co_varnames,
                              func_code.co_filename, name,
                              func_code.co_firstlineno, func_code.co_lnotab,
                              func_code.co_freevars, func_code.co_cellvars)

  return types.FunctionType(new_code, f.__globals__, name, f.__defaults__,
                            f.__closure__)


def _global_report_benchmark(
    name, iters=None, cpu_time=None, wall_time=None,
    throughput=None, extras=None, metrics=None):
  """Method for recording a benchmark directly.

  Args:
    name: The BenchmarkEntry name.
    iters: (optional) How many iterations were run
    cpu_time: (optional) Total cpu time in seconds
    wall_time: (optional) Total wall time in seconds
    throughput: (optional) Throughput (in MB/s)
    extras: (optional) Dict mapping string keys to additional benchmark info.
    metrics: (optional) A list of dicts representing metrics generated by the
      benchmark. Each dict should contain keys 'name' and 'value'. A dict
      can optionally contain keys 'min_value' and 'max_value'.

  Raises:
    TypeError: if extras is not a dict.
    IOError: if the benchmark output file already exists.
  """
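  # A minimal illustration of the expected `metrics` format (example values
  # are hypothetical, not taken from this module):
  #   metrics = [{"name": "accuracy", "value": 0.71, "min_value": 0.7}]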
  logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g, "
               "throughput: %g, extras: %s, metrics: %s", name,
               iters if iters is not None else -1,
               wall_time if wall_time is not None else -1,
               cpu_time if cpu_time is not None else -1,
               throughput if throughput is not None else -1,
               str(extras) if extras else "None",
               str(metrics) if metrics else "None")

  entries = test_log_pb2.BenchmarkEntries()
  entry = entries.entry.add()
  entry.name = name
  if iters is not None:
    entry.iters = iters
  if cpu_time is not None:
    entry.cpu_time = cpu_time
  if wall_time is not None:
    entry.wall_time = wall_time
  if throughput is not None:
    entry.throughput = throughput
  if extras is not None:
    if not isinstance(extras, dict):
      raise TypeError("extras must be a dict")
    for (k, v) in extras.items():
      if isinstance(v, numbers.Number):
        entry.extras[k].double_value = v
      else:
        entry.extras[k].string_value = str(v)
  if metrics is not None:
    if not isinstance(metrics, list):
      raise TypeError("metrics must be a list")
    for metric in metrics:
      if "name" not in metric:
        raise TypeError("metric must have a 'name' field")
      if "value" not in metric:
        raise TypeError("metric must have a 'value' field")

      metric_entry = entry.metrics.add()
      metric_entry.name = metric["name"]
      metric_entry.value = metric["value"]
      if "min_value" in metric:
        metric_entry.min_value.value = metric["min_value"]
      if "max_value" in metric:
        metric_entry.max_value.value = metric["max_value"]

  test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None)
  if test_env is None:
    # Reporting was not requested, just print the proto
    print(str(entries))
    return

  serialized_entry = entries.SerializeToString()

  mangled_name = name.replace("/", "__")
  output_path = "%s%s" % (test_env, mangled_name)
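  # Illustrative example (hypothetical values): with
  # TEST_REPORT_FILE_PREFIX="/tmp/benchmark_" and name "MyBench.benchmark_foo",
  # the serialized proto is written to "/tmp/benchmark_MyBench.benchmark_foo".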
  if gfile.Exists(output_path):
    raise IOError("File already exists: %s" % output_path)
  with gfile.GFile(output_path, "wb") as out:
    out.write(serialized_entry)


class _BenchmarkRegistrar(type):
  """The Benchmark class registrar.  Used by abstract Benchmark class."""

  def __new__(mcs, clsname, base, attrs):
    newclass = type.__new__(mcs, clsname, base, attrs)
    if not newclass.is_abstract():
      GLOBAL_BENCHMARK_REGISTRY.add(newclass)
    return newclass


@tf_export("__internal__.test.ParameterizedBenchmark", v1=[])
class ParameterizedBenchmark(_BenchmarkRegistrar):
  """Metaclass to generate parameterized benchmarks.

  Use this class as a metaclass and override the `_benchmark_parameters` to
  generate multiple benchmark test cases. For example:

  class FooBenchmark(tf.test.Benchmark,
                     metaclass=tf.test.ParameterizedBenchmark):
    # `_benchmark_parameters` is expected to be a list of test cases.
    # Each test case is a tuple whose first item is the test case name,
    # followed by any number of parameters needed by the test case.
    _benchmark_parameters = [
      ('case_1', Foo, 1, 'one'),
      ('case_2', Bar, 2, 'two'),
    ]

    def benchmark_test(self, target_class, int_param, string_param):
      # benchmark test body

  The example above will generate two benchmark test cases:
  "benchmark_test__case_1" and "benchmark_test__case_2".
  """

  def __new__(mcs, clsname, base, attrs):
    param_config_list = attrs["_benchmark_parameters"]

    def create_benchmark_function(original_benchmark, params):
      return lambda self: original_benchmark(self, *params)

    for name in attrs.copy().keys():
      if not name.startswith("benchmark"):
        continue

      original_benchmark = attrs[name]
      del attrs[name]

      for param_config in param_config_list:
        test_name_suffix = param_config[0]
        params = param_config[1:]
        benchmark_name = name + "__" + test_name_suffix
        if benchmark_name in attrs:
          raise Exception(
              "Benchmark named {} already defined.".format(benchmark_name))

        benchmark = create_benchmark_function(original_benchmark, params)
        # Renaming is important because the `report_benchmark` function looks
        # up the function name in the stack trace.
        attrs[benchmark_name] = _rename_function(benchmark, 1, benchmark_name)

    return super().__new__(mcs, clsname, base, attrs)


class Benchmark(metaclass=_BenchmarkRegistrar):
  """Abstract class that provides helper functions for running benchmarks.

  Any class subclassing this one is immediately registered in the global
  benchmark registry.

  Only methods whose names start with the word "benchmark" will be run during
  benchmarking.
  """
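
  # A minimal usage sketch (hypothetical subclass; the names and timings below
  # are illustrative, not part of this module):
  #
  #   class SumBenchmark(Benchmark):
  #     def benchmark_sum(self):
  #       start = time.time()
  #       sum(range(10**6))
  #       self.report_benchmark(iters=1, wall_time=time.time() - start)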

  @classmethod
  def is_abstract(cls):
    # mro: (Benchmark, object) means this is the Benchmark base class itself.
    return len(cls.mro()) <= 2

  def _get_name(self, overwrite_name=None):
    """Returns full name of class and method calling report_benchmark."""

    # Find the caller method (outermost Benchmark class)
    stack = tf_inspect.stack()
    calling_class = None
    name = None
    for frame in stack[::-1]:
      f_locals = frame[0].f_locals
      f_self = f_locals.get("self", None)
      if isinstance(f_self, Benchmark):
        calling_class = f_self  # Get the outermost stack Benchmark call
        name = frame[3]  # Get the method name
        break
    if calling_class is None:
      raise ValueError("Unable to determine calling Benchmark class.")

    # Use the method name, or overwrite_name if provided.
    name = overwrite_name or name
    # Prefix the name with the class name.
    class_name = type(calling_class).__name__
    name = "%s.%s" % (class_name, name)
    return name

  def report_benchmark(
      self,
      iters=None,
      cpu_time=None,
      wall_time=None,
      throughput=None,
      extras=None,
      name=None,
      metrics=None):
    """Report a benchmark.

    Args:
      iters: (optional) How many iterations were run
      cpu_time: (optional) Median or mean cpu time in seconds.
      wall_time: (optional) Median or mean wall time in seconds.
      throughput: (optional) Throughput (in MB/s)
      extras: (optional) Dict mapping string keys to additional benchmark info.
        Values may be either floats or values that are convertible to strings.
      name: (optional) Override the BenchmarkEntry name with `name`.
        Otherwise it is inferred from the top-level method name.
      metrics: (optional) A list of dicts, where each dict has the keys below:
        name (required), string, metric name
        value (required), double, metric value
        min_value (optional), double, minimum acceptable metric value
        max_value (optional), double, maximum acceptable metric value
    """
    name = self._get_name(overwrite_name=name)
    _global_report_benchmark(
        name=name, iters=iters, cpu_time=cpu_time, wall_time=wall_time,
        throughput=throughput, extras=extras, metrics=metrics)


@tf_export("test.benchmark_config")
def benchmark_config():
  """Returns a tf.compat.v1.ConfigProto for disabling the dependency optimizer.

  Returns:
    A TensorFlow ConfigProto object.
  """
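  # Example use (illustrative sketch; `bench` and `op_or_tensor` are
  # hypothetical names): pass this config to the session used for benchmarking
  # so the dependency optimizer does not rewrite the graph under test:
  #
  #   with tf.compat.v1.Session(config=benchmark_config()) as sess:
  #     bench.run_op_benchmark(sess, op_or_tensor)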
  config = config_pb2.ConfigProto()
  config.graph_options.rewrite_options.dependency_optimization = (
      rewriter_config_pb2.RewriterConfig.OFF)
  return config


@tf_export("test.Benchmark")
class TensorFlowBenchmark(Benchmark):
  """Abstract class that provides helpers for TensorFlow benchmarks."""

  def __init__(self):
    # Allow the TensorFlow runtime to allocate a new threadpool with a
    # different number of threads for each new benchmark.
    os.environ[OVERRIDE_GLOBAL_THREADPOOL] = "1"
    super().__init__()

  @classmethod
  def is_abstract(cls):
    # mro: (TensorFlowBenchmark, Benchmark, object) means this is
    # TensorFlowBenchmark itself.
    return len(cls.mro()) <= 3

  def run_op_benchmark(self,
                       sess,
                       op_or_tensor,
                       feed_dict=None,
                       burn_iters=2,
                       min_iters=10,
                       store_trace=False,
                       store_memory_usage=True,
                       name=None,
                       extras=None,
                       mbs=0):
    """Run an op or tensor in the given session.  Report the results.

    Args:
      sess: `Session` object to use for timing.
      op_or_tensor: `Operation` or `Tensor` to benchmark.
      feed_dict: A `dict` of values to feed for each op iteration (see the
        `feed_dict` parameter of `Session.run`).
      burn_iters: Number of burn-in iterations to run.
      min_iters: Minimum number of iterations to use for timing.
      store_trace: Boolean, whether to run an extra untimed iteration and
        store the trace of that iteration in the returned extras.
        The trace will be stored as a string in Google Chrome trace format
        in the extras field "full_trace_chrome_format". Note that the trace
        will not be stored in the test_log_pb2.TestResults proto.
      store_memory_usage: Boolean, whether to run an extra untimed iteration,
        calculate memory usage, and store that in extras fields.
      name: (optional) Override the BenchmarkEntry name with `name`.
        Otherwise it is inferred from the top-level method name.
      extras: (optional) Dict mapping string keys to additional benchmark info.
        Values may be either floats or values that are convertible to strings.
      mbs: (optional) The number of megabytes moved by this op, used to
        calculate the op's throughput.

    Returns:
      A `dict` containing the key-value pairs that were passed to
      `report_benchmark`. If the `store_trace` option is used, then
      `full_trace_chrome_format` will be included in the returned dictionary
      even though it is not passed to `report_benchmark` in `extras`.
358    """
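    # Illustrative call (hypothetical op and names):
    #   results = self.run_op_benchmark(
    #       sess, matmul_op, burn_iters=2, min_iters=20, name="matmul_1kx1k")
    #   # results["wall_time"] is the median wall time over `min_iters` runs.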
    for _ in range(burn_iters):
      sess.run(op_or_tensor, feed_dict=feed_dict)

    deltas = [None] * min_iters

    for i in range(min_iters):
      start_time = time.time()
      sess.run(op_or_tensor, feed_dict=feed_dict)
      end_time = time.time()
      delta = end_time - start_time
      deltas[i] = delta

    extras = extras if extras is not None else {}
    unreported_extras = {}
    if store_trace or store_memory_usage:
      run_options = config_pb2.RunOptions(
          trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata = config_pb2.RunMetadata()
      sess.run(op_or_tensor, feed_dict=feed_dict,
               options=run_options, run_metadata=run_metadata)
      tl = timeline.Timeline(run_metadata.step_stats)

      if store_trace:
        unreported_extras["full_trace_chrome_format"] = (
            tl.generate_chrome_trace_format())

      if store_memory_usage:
        step_stats_analysis = tl.analyze_step_stats(show_memory=True)
        allocator_maximums = step_stats_analysis.allocator_maximums
        for k, v in allocator_maximums.items():
          extras["allocator_maximum_num_bytes_%s" % k] = v.num_bytes

    def _median(x):
      if not x:
        return -1
      s = sorted(x)
      l = len(x)
      lm1 = l - 1
      return (s[l//2] + s[lm1//2]) / 2.0

    def _mean_and_stdev(x):
      if not x:
        return -1, -1
      l = len(x)
      mean = sum(x) / l
      if l == 1:
        return mean, -1
      variance = sum([(e - mean) * (e - mean) for e in x]) / (l - 1)
      return mean, math.sqrt(variance)

    median_delta = _median(deltas)

    benchmark_values = {
        "iters": min_iters,
        "wall_time": median_delta,
        "extras": extras,
        "name": name,
        "throughput": mbs / median_delta
    }
    self.report_benchmark(**benchmark_values)

    mean_delta, stdev_delta = _mean_and_stdev(deltas)
    unreported_extras["wall_time_mean"] = mean_delta
    unreported_extras["wall_time_stdev"] = stdev_delta
    benchmark_values["extras"].update(unreported_extras)
    return benchmark_values

  def evaluate(self, tensors):
    """Evaluates tensors and returns numpy values.

    Args:
      tensors: A Tensor or a nested list/tuple of Tensors.

    Returns:
      The numpy values of `tensors`.
    """
    sess = ops.get_default_session() or self.cached_session()
    return sess.run(tensors)


def _run_benchmarks(regex):
  """Run benchmarks that match regex `regex`.

  This function goes through the global benchmark registry, and matches
  benchmark class and method names of the form
  `module.name.BenchmarkClass.benchmarkMethod` to the given regex.
  If a method matches, it is run.

  Args:
    regex: The string regular expression to match Benchmark classes against.

  Raises:
    ValueError: If no benchmarks were selected by the input regex.
  """
  registry = list(GLOBAL_BENCHMARK_REGISTRY)

  selected_benchmarks = []
  # Match benchmarks in registry against regex
  for benchmark in registry:
    benchmark_name = "%s.%s" % (benchmark.__module__, benchmark.__name__)
    attrs = dir(benchmark)
    # Don't instantiate the benchmark class unless necessary
    benchmark_instance = None

    for attr in attrs:
      if not attr.startswith("benchmark"):
        continue
      candidate_benchmark_fn = getattr(benchmark, attr)
      if not callable(candidate_benchmark_fn):
        continue
      full_benchmark_name = "%s.%s" % (benchmark_name, attr)
      if regex == "all" or re.search(regex, full_benchmark_name):
        selected_benchmarks.append(full_benchmark_name)
        # Instantiate the class if it hasn't been instantiated
        benchmark_instance = benchmark_instance or benchmark()
        # Get the method tied to the class
        instance_benchmark_fn = getattr(benchmark_instance, attr)
        # Call the instance method
        instance_benchmark_fn()

  if not selected_benchmarks:
    raise ValueError("No benchmarks matched the pattern: '{}'".format(regex))
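
  # Illustrative match (hypothetical names): the regex "FooBenchmark.benchmark_bar"
  # selects "my_project.foo_benchmark.FooBenchmark.benchmark_bar", because
  # `re.search` is applied to the full "module.Class.method" name.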


def benchmarks_main(true_main, argv=None):
  """Run benchmarks as declared in argv.

  Args:
    true_main: True main function to run if benchmarks are not requested.
    argv: the command line arguments (if None, uses sys.argv).
  """
  if argv is None:
    argv = sys.argv
  found_arg = [arg for arg in argv
               if arg.startswith("--benchmarks=")
               or arg.startswith("-benchmarks=")]
  if found_arg:
    # Remove --benchmarks arg from sys.argv
    argv.remove(found_arg[0])

    regex = found_arg[0].split("=")[1]
    app.run(lambda _: _run_benchmarks(regex), argv=argv)
  else:
    true_main()

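# Illustrative invocation (assuming a benchmark binary whose main is wired
# through `benchmarks_main`; the file and class names below are hypothetical):
#   python foo_benchmark.py --benchmarks=FooBenchmark.benchmark_bar
# If no --benchmarks/-benchmarks flag is passed, `true_main` runs instead.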