# xref: /aosp_15_r20/external/tensorflow/tensorflow/tools/test/performance.bzl (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1load("//tensorflow:tensorflow.bzl", "tf_py_test")
2
def tf_cc_logged_benchmark(
        name = None,
        target = None,
        benchmarks = "..",
        tags = [],
        benchmark_type = "cpp_microbenchmark",
        size = "large",
        **kwargs):
    """Creates a benchmark test target of a TensorFlow C++ test (tf_cc_*_test).

    The generated target is a `tf_py_test` whose entry point is
    `run_and_gather_logs.py` (from `//tensorflow/tools/test:run_and_gather_logs`).
    It is invoked with the wrapped test's label and with
    `--benchmarks=<benchmarks>` as the argument for that test.

    Args:
      name: Name of the generated test target. Required.
      target: Absolute label of a single test, e.g. `//path/to:test`. Labels
        without `:`, not starting with `//`, ending in `:all`, or ending in
        `.` (such as `//path/...`) are rejected. Required.
      benchmarks: Value passed to the wrapped test as `--benchmarks=`.
      tags: Extra tags for the generated target. "benchmark-test", "local",
        "manual" and "regression-test" are always appended. `None` is
        treated as an empty list.
      benchmark_type: Value passed to the runner as `--benchmark_type=`.
      size: Test size of the generated target. Defaults to "large", the value
        that was previously hard-coded. Before this was a parameter, passing
        `size` via `**kwargs` caused a duplicate-keyword error.
      **kwargs: Additional attributes forwarded to `tf_py_test`.
    """
    if not name:
        fail("Must provide a name")
    if not target:
        fail("Must provide a target")

    # The label is used both as a `data` dependency and as `--test_name`, so
    # it must name exactly one test rather than a wildcard pattern.
    if (":" not in target or
        not target.startswith("//") or
        target.endswith(":all") or
        target.endswith(".")):
        fail(" ".join((
            "Target must be a single well-defined test, e.g.,",
            "//path/to:test. Received: %s" % target,
        )))

    # `+` builds a fresh list, so the (frozen) default `tags` is never mutated.
    all_tags = (tags or []) + [
        "benchmark-test",
        "local",
        "manual",
        "regression-test",
    ]

    tf_py_test(
        name = name,
        tags = all_tags,
        size = size,
        srcs = ["//tensorflow/tools/test:run_and_gather_logs"],
        args = [
            # Fully qualified label of the target generated by this macro.
            "--name=//%s:%s" % (native.package_name(), name),
            "--test_name=" + target,
            "--test_args=--benchmarks=%s" % benchmarks,
            "--benchmark_type=%s" % benchmark_type,
        ],
        # Presumably listed in `data` so the wrapped test is built and present
        # in runfiles for the runner to execute.
        data = [
            target,
        ],
        main = "run_and_gather_logs.py",
        deps = [
            "//tensorflow/tools/test:run_and_gather_logs_main_lib",
        ],
        **kwargs
    )
46
def tf_py_logged_benchmark(
        name = None,
        target = None,
        benchmarks = "..",
        tags = [],
        **kwargs):
    """Creates a benchmark test target of a TensorFlow Python test (*py_tests).

    This currently delegates entirely to `tf_cc_logged_benchmark`. The only
    difference is that `benchmark_type` is fixed to "python_benchmark". It is
    kept as a separate macro so the Python variant can diverge later without
    changing callers.

    Args:
      name: Name of the generated test target.
      target: Label of the single Python test to benchmark, e.g.
        `//path/to:test`.
      benchmarks: Passed to the wrapped test as `--benchmarks=`.
      tags: Extra tags for the generated target.
      **kwargs: Additional attributes forwarded to `tf_cc_logged_benchmark`.
    """
    python_benchmark_type = "python_benchmark"
    tf_cc_logged_benchmark(
        benchmark_type = python_benchmark_type,
        tags = tags,
        benchmarks = benchmarks,
        target = target,
        name = name,
        **kwargs
    )
65