# mypy: allow-untyped-defs
from __future__ import annotations

import contextlib
import functools
import getpass
import operator
import os
import re
import tempfile

import torch


def conditional_product(*args):
    """Multiply together the truthy args, skipping any that are falsy (0 or None)."""
    return functools.reduce(operator.mul, [x for x in args if x])
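# A minimal sketch of how conditional_product behaves (illustrative values only):
#
#   conditional_product(2, 3, 4)     # == 24
#   conditional_product(2, 0, 4)     # == 8; falsy entries are dropped before multiplying
#   conditional_product(2, None, 4)  # == 8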


def ceildiv(numer: int, denom: int) -> int:
    """Ceiling division that stays in integer arithmetic."""
    return -(numer // -denom)
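# Worked example of the negation trick above (illustrative values only): Python's
# floor division rounds toward negative infinity, so negating twice gives a ceiling.
#
#   ceildiv(10, 4)  # == 3, since -(10 // -4) == -(-3) == 3
#   ceildiv(8, 4)   # == 2; exact divisions are unchanged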


def is_power_of_2(n: int) -> bool:
    """Returns whether n = 2 ** m for some integer m."""
    return n > 0 and n & n - 1 == 0
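# Why the bit trick works (illustrative values only): a power of two has exactly one
# set bit, so clearing its lowest set bit with n & (n - 1) yields zero.
#
#   is_power_of_2(8)  # True:  0b1000 & 0b0111 == 0
#   is_power_of_2(6)  # False: 0b0110 & 0b0101 == 0b0100 != 0
#   is_power_of_2(0)  # False: the n > 0 check rules out zero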


def next_power_of_2(n: int) -> int:
    """Return the smallest power of 2 greater than or equal to n (assumes n >= 1)."""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    n += 1
    return n
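# Worked example of the bit-smearing above (illustrative, assumes n >= 1): the shifts
# copy the highest set bit into every lower position, producing 2**k - 1, and the
# final increment rounds up to 2**k.
#
#   next_power_of_2(17)  # 17 - 1 == 0b10000; after the OR-shifts -> 0b11111 == 31; 31 + 1 == 32
#   next_power_of_2(32)  # == 32; exact powers of two are returned unchanged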


def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int:
    """
    Return the total number of bytes taken by the tensor arguments.

    For in/out args, tensor sizes are counted twice: once for reading and
    once for writing.

    The first num_in_out_args arguments are in/out tensors.
    """
    return sum(
        arg.numel() * arg.element_size() * (1 + int(i < num_in_out_args))
        for i, arg in enumerate(args)
        if isinstance(arg, torch.Tensor)
    )
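# A minimal sketch of the accounting above (illustrative shapes and dtypes): each
# tensor contributes numel * element_size bytes, and in/out tensors count twice.
#
#   a = torch.empty(1024, dtype=torch.float32)   # 1024 * 4 = 4096 bytes
#   b = torch.empty(256, dtype=torch.float16)    # 256 * 2 = 512 bytes
#   get_num_bytes(a, b)                          # == 4608
#   get_num_bytes(a, b, num_in_out_args=1)       # == 8704; `a` is read and written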


def triton_config_to_hashable(cfg):
    """
    Convert triton config to a tuple that can uniquely identify it. We can use
    the return value as a dictionary key.
    """
    items = sorted(cfg.kwargs.items())
    items.append(("num_warps", cfg.num_warps))
    items.append(("num_stages", cfg.num_stages))
    return tuple(items)
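# A minimal sketch, assuming a triton.Config with illustrative kernel kwargs:
#
#   cfg = triton.Config({"XBLOCK": 64, "RBLOCK": 8}, num_warps=4, num_stages=2)
#   triton_config_to_hashable(cfg)
#   # -> (("RBLOCK", 8), ("XBLOCK", 64), ("num_warps", 4), ("num_stages", 2))
#
# The kwargs are sorted so two configs with identical settings map to the same key.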


def validate_triton_config(cfg):
    # [Note: Triton pre_hook in inductor]
    # pre_hook is a lambda function, which we don't attempt to serialize.
    # Right now, if a pre_hook is attached to the config, it will not be saved,
    # and so it won't be used when the config is loaded from the cache.
    # Assert here so that a pre_hook is never silently ignored after caching.
    assert (
        getattr(cfg, "pre_hook", None) is None
    ), "triton configs with pre_hooks not supported"


def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix="", color=True):
    info_str = f"{prefix}{ms:.3f}ms    \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
    # Heuristic: flag kernels that run longer than 0.012 ms yet stay under 650 GB/s.
    slow = ms > 0.012 and gb_per_s < 650
    return red_text(info_str) if color and slow else info_str
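# A minimal sketch of the formatted output (illustrative numbers only):
#
#   create_bandwidth_info_str(0.050, 1.0, 20.0)
#   # -> "0.050ms    \t1.000 GB \t   20.00GB/s"
#   # rendered in red when colorama is available, since 0.050 > 0.012 and 20.0 < 650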


def get_max_y_grid():
    # Maximum size of the Y dimension of a CUDA launch grid (2**16 - 1).
    return 65535


def cache_dir() -> str:
    cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
    if cache_dir is None:
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = default_cache_dir()
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir
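# Usage sketch: the cache location can be pinned via the environment variable,
# otherwise it falls back to default_cache_dir() under the system temp directory
# (the exact path is platform- and user-dependent; the path below is illustrative):
#
#   os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/my_inductor_cache"
#   cache_dir()  # -> "/tmp/my_inductor_cache", created if it does not exist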


def default_cache_dir():
    # Replace characters that are illegal in file/directory names (e.g. on Windows)
    # with "_" before using the username as part of the path.
    sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
    return os.path.join(
        tempfile.gettempdir(),
        "torchinductor_" + sanitized_username,
    )


try:
    import colorama

    HAS_COLORAMA = True
except ModuleNotFoundError:
    HAS_COLORAMA = False
    colorama = None  # type: ignore[assignment]


def _color_text(msg, color):
    if not HAS_COLORAMA:
        return msg

    return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET
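# A minimal sketch of the colorized output (assuming colorama is installed; without
# it the message is returned unchanged):
#
#   _color_text("slow kernel", "red")
#   # -> roughly "\x1b[31mslow kernel\x1b[39m", i.e. the message wrapped in ANSI codes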


def green_text(msg):
    return _color_text(msg, "green")


def yellow_text(msg):
    return _color_text(msg, "yellow")


def red_text(msg):
    return _color_text(msg, "red")


def blue_text(msg):
    return _color_text(msg, "blue")


def get_first_attr(obj, *attrs):
    """
    Return the first available attribute or raise an exception if none is present.
    """
    for attr in attrs:
        if hasattr(obj, attr):
            return getattr(obj, attr)

    raise AssertionError(f"{obj} does not have any of the attributes: {attrs}")
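# Usage sketch (hypothetical attribute names, for illustration only): this helper is
# handy when a dependency renames an attribute between releases and either spelling
# may be present.
#
#   shared = get_first_attr(kernel_binary, "shared", "shared_mem_bytes")
#   # returns the first attribute that exists; raises AssertionError if neither does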


try:
    dynamo_timed = torch._dynamo.utils.dynamo_timed  # type: ignore[has-type]
except AttributeError:  # Compile workers only have a mock version of torch

    @contextlib.contextmanager
    def dynamo_timed(key, phase_name=None, fwd_only=True):
        yield
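# Usage sketch: dynamo_timed is used as a context manager around regions whose time
# should show up in dynamo's timing stats; the no-op fallback above keeps such code
# importable on compile workers ("example_key" and run_benchmark are illustrative):
#
#   with dynamo_timed("example_key"):
#       run_benchmark()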