xref: /aosp_15_r20/external/pytorch/test/inductor/test_metrics.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: inductor"]
2import torch
3from torch._inductor import config, metrics
4from torch._inductor.test_case import run_tests, TestCase
5from torch._inductor.utils import collect_defined_kernels
6from torch._inductor.wrapper_benchmark import get_kernel_category_by_source_code
7from torch.testing._internal.common_device_type import largeTensorTest
8from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
9
10
# Source text of a Triton reduction kernel, used as a fixed input for the
# metrics parsing tests in this file.  It is kept as a plain string and is
# never executed here.  The literal placeholder 'GPU_TYPE' inside the
# `triton_meta` dict is substituted with the value of the imported GPU_TYPE
# by the trailing `.replace(...)` call.
example_kernel = """
@triton_heuristics.reduction(
    size_hints=[1024, 2048],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={
        'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'},
        'device': 0,
        'device_type': 'GPU_TYPE',
        'constants': {},
        'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2, 3))]},
    inductor_meta={
        'autotune_hints': set(),
        'kernel_name': 'triton_red_fused_add_sum_2',
        'mutated_arg_names': ['in_out_ptr0'],
        'no_x_dim': False,
        'kernel_num_gb': 0.0083968
    }
)
@triton.jit
def triton_red_fused_add_sum_2(in_out_ptr0, in_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1024
    rnumel = 2048
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp2 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (2048*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(rmask & xmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tmp4 = tl.load(in_out_ptr0 + (x0), xmask, eviction_policy='evict_last')
    tmp5 = tmp4 + tmp2
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (x0), tmp5, xmask)
""".replace(
    "GPU_TYPE", GPU_TYPE
)
56
57
class TestMetrics(TestCase):
    """Tests for the kernel source-code helpers in ``torch._inductor.metrics``.

    The first four tests parse the static ``example_kernel`` string defined
    at module level.  The last two compile small functions with
    ``torch.compile`` on ``GPU_TYPE`` and inspect the kernels that get
    collected.
    """

    def test_parse_proper_kernel_fn_code(self):
        # The parsed text is expected to begin at the function definition,
        # i.e. with the decorators stripped.
        proper_kernel_fn_code = metrics._parse_proper_kernel_fn_code(example_kernel)
        # Use a unittest assertion rather than a bare `assert`, so the check
        # is not stripped under `python -O` and matches the other tests.
        self.assertTrue(proper_kernel_fn_code.startswith("def "))

    def test_count_args(self):
        # The example kernel signature declares six parameters:
        # in_out_ptr0, in_ptr0, xnumel, rnumel, XBLOCK and RBLOCK.
        proper_kernel_fn_code = metrics._parse_proper_kernel_fn_code(example_kernel)
        self.assertEqual(6, metrics._count_args(proper_kernel_fn_code))

    def test_count_pattern(self):
        # The example kernel body contains two `tl.load` calls, one
        # `tl.store` call and a single `for` loop.
        proper_kernel_fn_code = metrics._parse_proper_kernel_fn_code(example_kernel)
        self.assertEqual(2, metrics._count_pattern(proper_kernel_fn_code, "tl.load"))
        self.assertEqual(1, metrics._count_pattern(proper_kernel_fn_code, "tl.store"))
        self.assertEqual(1, metrics._count_pattern(proper_kernel_fn_code, "for "))

    def test_parse_reduction_hint(self):
        # The example kernel is decorated with `triton_heuristics.reduction`
        # and `reduction_hint=ReductionHint.INNER`.
        kernel_category = get_kernel_category_by_source_code(example_kernel)
        self.assertEqual("reduction", kernel_category)
        self.assertEqual(
            "INNER", metrics._parse_reduction_hint(kernel_category, example_kernel)
        )

    @config.patch("fx_graph_remote_cache", False)
    def test_atomic_add(self):
        # An accumulating index_put_ is expected to produce exactly one
        # kernel that contains exactly one `tl.atomic_add`.
        @torch.compile
        def f(lhs, index, rhs):
            return lhs.index_put_([index], rhs, accumulate=True)

        lhs = torch.randn(1024, device=GPU_TYPE)
        index = torch.randint(0, 1024, [32], device=GPU_TYPE, dtype=torch.int32)
        rhs = torch.randn(32, device=GPU_TYPE)

        kernel_list = []
        with collect_defined_kernels(kernel_list):
            f(lhs, index, rhs)

        self.assertEqual(len(kernel_list), 1)
        kernel_code = kernel_list[0]
        self.assertEqual(metrics._count_pattern(kernel_code, "tl.atomic_add"), 1)

    @largeTensorTest(25e7 * 2 * 4, device=GPU_TYPE)
    @config.patch("fx_graph_remote_cache", False)
    @config.patch("benchmark_kernel", True)
    def test_kernel_args_num_gb(self):
        # NOTE(review): presumably 25e7 elements of 4 bytes each (assuming the
        # default float32 dtype) is 1 GB read plus 1 GB written, which gives
        # the expected 2.0 below -- confirm against the metrics
        # implementation.
        @torch.compile
        def f(x):
            return x + 1

        x = torch.randn(int(25e7), device=GPU_TYPE)
        kernel_list = []
        with collect_defined_kernels(kernel_list):
            f(x)

        self.assertEqual(len(kernel_list), 1)
        kernel_code = kernel_list[0]
        self.assertEqual(
            metrics._parse_kernel_args_num_gb(kernel_code, "pointwise"), 2.0
        )
116
117
# Script entry point: the suite is only run when a GPU is available
# (HAS_GPU); otherwise the script exits without running any tests.
if __name__ == "__main__" and HAS_GPU:
    run_tests()
121