1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Utilities to run benchmarks.""" 17import math 18import numbers 19import os 20import re 21import sys 22import time 23import types 24 25from absl import app 26 27from tensorflow.core.protobuf import config_pb2 28from tensorflow.core.protobuf import rewriter_config_pb2 29from tensorflow.core.util import test_log_pb2 30from tensorflow.python.client import timeline 31from tensorflow.python.framework import ops 32from tensorflow.python.platform import gfile 33from tensorflow.python.platform import tf_logging as logging 34from tensorflow.python.util import tf_inspect 35from tensorflow.python.util.tf_export import tf_export 36 37 38# When a subclass of the Benchmark class is created, it is added to 39# the registry automatically 40GLOBAL_BENCHMARK_REGISTRY = set() 41 42# Environment variable that determines whether benchmarks are written. 43# See also tensorflow/core/util/reporter.h TestReporter::kTestReporterEnv. 44TEST_REPORTER_TEST_ENV = "TEST_REPORT_FILE_PREFIX" 45 46# Environment variable that lets the TensorFlow runtime allocate a new 47# threadpool for each benchmark. 48OVERRIDE_GLOBAL_THREADPOOL = "TF_OVERRIDE_GLOBAL_THREADPOOL" 49 50 51def _rename_function(f, arg_num, name): 52 """Rename the given function's name appears in the stack trace.""" 53 func_code = f.__code__ 54 if sys.version_info > (3, 8, 0, "alpha", 3): 55 # Python3.8 / PEP570 added co_posonlyargcount argument to CodeType. 56 new_code = types.CodeType( 57 arg_num, func_code.co_posonlyargcount, 0, func_code.co_nlocals, 58 func_code.co_stacksize, func_code.co_flags, func_code.co_code, 59 func_code.co_consts, func_code.co_names, func_code.co_varnames, 60 func_code.co_filename, name, func_code.co_firstlineno, 61 func_code.co_lnotab, func_code.co_freevars, func_code.co_cellvars) 62 else: 63 new_code = types.CodeType(arg_num, 0, func_code.co_nlocals, 64 func_code.co_stacksize, func_code.co_flags, 65 func_code.co_code, func_code.co_consts, 66 func_code.co_names, func_code.co_varnames, 67 func_code.co_filename, name, 68 func_code.co_firstlineno, func_code.co_lnotab, 69 func_code.co_freevars, func_code.co_cellvars) 70 71 return types.FunctionType(new_code, f.__globals__, name, f.__defaults__, 72 f.__closure__) 73 74 75def _global_report_benchmark( 76 name, iters=None, cpu_time=None, wall_time=None, 77 throughput=None, extras=None, metrics=None): 78 """Method for recording a benchmark directly. 79 80 Args: 81 name: The BenchmarkEntry name. 82 iters: (optional) How many iterations were run 83 cpu_time: (optional) Total cpu time in seconds 84 wall_time: (optional) Total wall time in seconds 85 throughput: (optional) Throughput (in MB/s) 86 extras: (optional) Dict mapping string keys to additional benchmark info. 87 metrics: (optional) A list of dict representing metrics generated by the 88 benchmark. Each dict should contain keys 'name' and'value'. A dict 89 can optionally contain keys 'min_value' and 'max_value'. 90 91 Raises: 92 TypeError: if extras is not a dict. 93 IOError: if the benchmark output file already exists. 94 """ 95 logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g," 96 "throughput: %g, extras: %s, metrics: %s", name, 97 iters if iters is not None else -1, 98 wall_time if wall_time is not None else -1, 99 cpu_time if cpu_time is not None else -1, 100 throughput if throughput is not None else -1, 101 str(extras) if extras else "None", 102 str(metrics) if metrics else "None") 103 104 entries = test_log_pb2.BenchmarkEntries() 105 entry = entries.entry.add() 106 entry.name = name 107 if iters is not None: 108 entry.iters = iters 109 if cpu_time is not None: 110 entry.cpu_time = cpu_time 111 if wall_time is not None: 112 entry.wall_time = wall_time 113 if throughput is not None: 114 entry.throughput = throughput 115 if extras is not None: 116 if not isinstance(extras, dict): 117 raise TypeError("extras must be a dict") 118 for (k, v) in extras.items(): 119 if isinstance(v, numbers.Number): 120 entry.extras[k].double_value = v 121 else: 122 entry.extras[k].string_value = str(v) 123 if metrics is not None: 124 if not isinstance(metrics, list): 125 raise TypeError("metrics must be a list") 126 for metric in metrics: 127 if "name" not in metric: 128 raise TypeError("metric must has a 'name' field") 129 if "value" not in metric: 130 raise TypeError("metric must has a 'value' field") 131 132 metric_entry = entry.metrics.add() 133 metric_entry.name = metric["name"] 134 metric_entry.value = metric["value"] 135 if "min_value" in metric: 136 metric_entry.min_value.value = metric["min_value"] 137 if "max_value" in metric: 138 metric_entry.max_value.value = metric["max_value"] 139 140 test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None) 141 if test_env is None: 142 # Reporting was not requested, just print the proto 143 print(str(entries)) 144 return 145 146 serialized_entry = entries.SerializeToString() 147 148 mangled_name = name.replace("/", "__") 149 output_path = "%s%s" % (test_env, mangled_name) 150 if gfile.Exists(output_path): 151 raise IOError("File already exists: %s" % output_path) 152 with gfile.GFile(output_path, "wb") as out: 153 out.write(serialized_entry) 154 155 156class _BenchmarkRegistrar(type): 157 """The Benchmark class registrar. Used by abstract Benchmark class.""" 158 159 def __new__(mcs, clsname, base, attrs): 160 newclass = type.__new__(mcs, clsname, base, attrs) 161 if not newclass.is_abstract(): 162 GLOBAL_BENCHMARK_REGISTRY.add(newclass) 163 return newclass 164 165 166@tf_export("__internal__.test.ParameterizedBenchmark", v1=[]) 167class ParameterizedBenchmark(_BenchmarkRegistrar): 168 """Metaclass to generate parameterized benchmarks. 169 170 Use this class as a metaclass and override the `_benchmark_parameters` to 171 generate multiple benchmark test cases. For example: 172 173 class FooBenchmark(metaclass=tf.test.ParameterizedBenchmark, 174 tf.test.Benchmark): 175 # The `_benchmark_parameters` is expected to be a list with test cases. 176 # Each of the test case is a tuple, with the first time to be test case 177 # name, followed by any number of the parameters needed for the test case. 178 _benchmark_parameters = [ 179 ('case_1', Foo, 1, 'one'), 180 ('case_2', Bar, 2, 'two'), 181 ] 182 183 def benchmark_test(self, target_class, int_param, string_param): 184 # benchmark test body 185 186 The example above will generate two benchmark test cases: 187 "benchmark_test__case_1" and "benchmark_test__case_2". 188 """ 189 190 def __new__(mcs, clsname, base, attrs): 191 param_config_list = attrs["_benchmark_parameters"] 192 193 def create_benchmark_function(original_benchmark, params): 194 return lambda self: original_benchmark(self, *params) 195 196 for name in attrs.copy().keys(): 197 if not name.startswith("benchmark"): 198 continue 199 200 original_benchmark = attrs[name] 201 del attrs[name] 202 203 for param_config in param_config_list: 204 test_name_suffix = param_config[0] 205 params = param_config[1:] 206 benchmark_name = name + "__" + test_name_suffix 207 if benchmark_name in attrs: 208 raise Exception( 209 "Benchmark named {} already defined.".format(benchmark_name)) 210 211 benchmark = create_benchmark_function(original_benchmark, params) 212 # Renaming is important because `report_benchmark` function looks up the 213 # function name in the stack trace. 214 attrs[benchmark_name] = _rename_function(benchmark, 1, benchmark_name) 215 216 return super().__new__(mcs, clsname, base, attrs) 217 218 219class Benchmark(metaclass=_BenchmarkRegistrar): 220 """Abstract class that provides helper functions for running benchmarks. 221 222 Any class subclassing this one is immediately registered in the global 223 benchmark registry. 224 225 Only methods whose names start with the word "benchmark" will be run during 226 benchmarking. 227 """ 228 229 @classmethod 230 def is_abstract(cls): 231 # mro: (_BenchmarkRegistrar, Benchmark) means this is Benchmark 232 return len(cls.mro()) <= 2 233 234 def _get_name(self, overwrite_name=None): 235 """Returns full name of class and method calling report_benchmark.""" 236 237 # Find the caller method (outermost Benchmark class) 238 stack = tf_inspect.stack() 239 calling_class = None 240 name = None 241 for frame in stack[::-1]: 242 f_locals = frame[0].f_locals 243 f_self = f_locals.get("self", None) 244 if isinstance(f_self, Benchmark): 245 calling_class = f_self # Get the outermost stack Benchmark call 246 name = frame[3] # Get the method name 247 break 248 if calling_class is None: 249 raise ValueError("Unable to determine calling Benchmark class.") 250 251 # Use the method name, or overwrite_name is provided. 252 name = overwrite_name or name 253 # Prefix the name with the class name. 254 class_name = type(calling_class).__name__ 255 name = "%s.%s" % (class_name, name) 256 return name 257 258 def report_benchmark( 259 self, 260 iters=None, 261 cpu_time=None, 262 wall_time=None, 263 throughput=None, 264 extras=None, 265 name=None, 266 metrics=None): 267 """Report a benchmark. 268 269 Args: 270 iters: (optional) How many iterations were run 271 cpu_time: (optional) Median or mean cpu time in seconds. 272 wall_time: (optional) Median or mean wall time in seconds. 273 throughput: (optional) Throughput (in MB/s) 274 extras: (optional) Dict mapping string keys to additional benchmark info. 275 Values may be either floats or values that are convertible to strings. 276 name: (optional) Override the BenchmarkEntry name with `name`. 277 Otherwise it is inferred from the top-level method name. 278 metrics: (optional) A list of dict, where each dict has the keys below 279 name (required), string, metric name 280 value (required), double, metric value 281 min_value (optional), double, minimum acceptable metric value 282 max_value (optional), double, maximum acceptable metric value 283 """ 284 name = self._get_name(overwrite_name=name) 285 _global_report_benchmark( 286 name=name, iters=iters, cpu_time=cpu_time, wall_time=wall_time, 287 throughput=throughput, extras=extras, metrics=metrics) 288 289 290@tf_export("test.benchmark_config") 291def benchmark_config(): 292 """Returns a tf.compat.v1.ConfigProto for disabling the dependency optimizer. 293 294 Returns: 295 A TensorFlow ConfigProto object. 296 """ 297 config = config_pb2.ConfigProto() 298 config.graph_options.rewrite_options.dependency_optimization = ( 299 rewriter_config_pb2.RewriterConfig.OFF) 300 return config 301 302 303@tf_export("test.Benchmark") 304class TensorFlowBenchmark(Benchmark): 305 """Abstract class that provides helpers for TensorFlow benchmarks.""" 306 307 def __init__(self): 308 # Allow TensorFlow runtime to allocate a new threadpool with different 309 # number of threads for each new benchmark. 310 os.environ[OVERRIDE_GLOBAL_THREADPOOL] = "1" 311 super().__init__() 312 313 @classmethod 314 def is_abstract(cls): 315 # mro: (_BenchmarkRegistrar, Benchmark, TensorFlowBenchmark) means 316 # this is TensorFlowBenchmark. 317 return len(cls.mro()) <= 3 318 319 def run_op_benchmark(self, 320 sess, 321 op_or_tensor, 322 feed_dict=None, 323 burn_iters=2, 324 min_iters=10, 325 store_trace=False, 326 store_memory_usage=True, 327 name=None, 328 extras=None, 329 mbs=0): 330 """Run an op or tensor in the given session. Report the results. 331 332 Args: 333 sess: `Session` object to use for timing. 334 op_or_tensor: `Operation` or `Tensor` to benchmark. 335 feed_dict: A `dict` of values to feed for each op iteration (see the 336 `feed_dict` parameter of `Session.run`). 337 burn_iters: Number of burn-in iterations to run. 338 min_iters: Minimum number of iterations to use for timing. 339 store_trace: Boolean, whether to run an extra untimed iteration and 340 store the trace of iteration in returned extras. 341 The trace will be stored as a string in Google Chrome trace format 342 in the extras field "full_trace_chrome_format". Note that trace 343 will not be stored in test_log_pb2.TestResults proto. 344 store_memory_usage: Boolean, whether to run an extra untimed iteration, 345 calculate memory usage, and store that in extras fields. 346 name: (optional) Override the BenchmarkEntry name with `name`. 347 Otherwise it is inferred from the top-level method name. 348 extras: (optional) Dict mapping string keys to additional benchmark info. 349 Values may be either floats or values that are convertible to strings. 350 mbs: (optional) The number of megabytes moved by this op, used to 351 calculate the ops throughput. 352 353 Returns: 354 A `dict` containing the key-value pairs that were passed to 355 `report_benchmark`. If `store_trace` option is used, then 356 `full_chrome_trace_format` will be included in return dictionary even 357 though it is not passed to `report_benchmark` with `extras`. 358 """ 359 for _ in range(burn_iters): 360 sess.run(op_or_tensor, feed_dict=feed_dict) 361 362 deltas = [None] * min_iters 363 364 for i in range(min_iters): 365 start_time = time.time() 366 sess.run(op_or_tensor, feed_dict=feed_dict) 367 end_time = time.time() 368 delta = end_time - start_time 369 deltas[i] = delta 370 371 extras = extras if extras is not None else {} 372 unreported_extras = {} 373 if store_trace or store_memory_usage: 374 run_options = config_pb2.RunOptions( 375 trace_level=config_pb2.RunOptions.FULL_TRACE) 376 run_metadata = config_pb2.RunMetadata() 377 sess.run(op_or_tensor, feed_dict=feed_dict, 378 options=run_options, run_metadata=run_metadata) 379 tl = timeline.Timeline(run_metadata.step_stats) 380 381 if store_trace: 382 unreported_extras["full_trace_chrome_format"] = ( 383 tl.generate_chrome_trace_format()) 384 385 if store_memory_usage: 386 step_stats_analysis = tl.analyze_step_stats(show_memory=True) 387 allocator_maximums = step_stats_analysis.allocator_maximums 388 for k, v in allocator_maximums.items(): 389 extras["allocator_maximum_num_bytes_%s" % k] = v.num_bytes 390 391 def _median(x): 392 if not x: 393 return -1 394 s = sorted(x) 395 l = len(x) 396 lm1 = l - 1 397 return (s[l//2] + s[lm1//2]) / 2.0 398 399 def _mean_and_stdev(x): 400 if not x: 401 return -1, -1 402 l = len(x) 403 mean = sum(x) / l 404 if l == 1: 405 return mean, -1 406 variance = sum([(e - mean) * (e - mean) for e in x]) / (l - 1) 407 return mean, math.sqrt(variance) 408 409 median_delta = _median(deltas) 410 411 benchmark_values = { 412 "iters": min_iters, 413 "wall_time": median_delta, 414 "extras": extras, 415 "name": name, 416 "throughput": mbs / median_delta 417 } 418 self.report_benchmark(**benchmark_values) 419 420 mean_delta, stdev_delta = _mean_and_stdev(deltas) 421 unreported_extras["wall_time_mean"] = mean_delta 422 unreported_extras["wall_time_stdev"] = stdev_delta 423 benchmark_values["extras"].update(unreported_extras) 424 return benchmark_values 425 426 def evaluate(self, tensors): 427 """Evaluates tensors and returns numpy values. 428 429 Args: 430 tensors: A Tensor or a nested list/tuple of Tensors. 431 432 Returns: 433 tensors numpy values. 434 """ 435 sess = ops.get_default_session() or self.cached_session() 436 return sess.run(tensors) 437 438 439def _run_benchmarks(regex): 440 """Run benchmarks that match regex `regex`. 441 442 This function goes through the global benchmark registry, and matches 443 benchmark class and method names of the form 444 `module.name.BenchmarkClass.benchmarkMethod` to the given regex. 445 If a method matches, it is run. 446 447 Args: 448 regex: The string regular expression to match Benchmark classes against. 449 450 Raises: 451 ValueError: If no benchmarks were selected by the input regex. 452 """ 453 registry = list(GLOBAL_BENCHMARK_REGISTRY) 454 455 selected_benchmarks = [] 456 # Match benchmarks in registry against regex 457 for benchmark in registry: 458 benchmark_name = "%s.%s" % (benchmark.__module__, benchmark.__name__) 459 attrs = dir(benchmark) 460 # Don't instantiate the benchmark class unless necessary 461 benchmark_instance = None 462 463 for attr in attrs: 464 if not attr.startswith("benchmark"): 465 continue 466 candidate_benchmark_fn = getattr(benchmark, attr) 467 if not callable(candidate_benchmark_fn): 468 continue 469 full_benchmark_name = "%s.%s" % (benchmark_name, attr) 470 if regex == "all" or re.search(regex, full_benchmark_name): 471 selected_benchmarks.append(full_benchmark_name) 472 # Instantiate the class if it hasn't been instantiated 473 benchmark_instance = benchmark_instance or benchmark() 474 # Get the method tied to the class 475 instance_benchmark_fn = getattr(benchmark_instance, attr) 476 # Call the instance method 477 instance_benchmark_fn() 478 479 if not selected_benchmarks: 480 raise ValueError("No benchmarks matched the pattern: '{}'".format(regex)) 481 482 483def benchmarks_main(true_main, argv=None): 484 """Run benchmarks as declared in argv. 485 486 Args: 487 true_main: True main function to run if benchmarks are not requested. 488 argv: the command line arguments (if None, uses sys.argv). 489 """ 490 if argv is None: 491 argv = sys.argv 492 found_arg = [arg for arg in argv 493 if arg.startswith("--benchmarks=") 494 or arg.startswith("-benchmarks=")] 495 if found_arg: 496 # Remove --benchmarks arg from sys.argv 497 argv.remove(found_arg[0]) 498 499 regex = found_arg[0].split("=")[1] 500 app.run(lambda _: _run_benchmarks(regex), argv=argv) 501 else: 502 true_main() 503