# mypy: allow-untyped-defs
from __future__ import annotations

import contextlib
import functools
import getpass
import operator
import os
import re
import tempfile

import torch


def conditional_product(*args):
    return functools.reduce(operator.mul, [x for x in args if x])


def ceildiv(numer: int, denom: int) -> int:
    return -(numer // -denom)


def is_power_of_2(n: int) -> bool:
    """Returns whether n = 2 ** m for some integer m."""
    return n > 0 and n & n - 1 == 0


def next_power_of_2(n: int) -> int:
    """Return the smallest power of 2 greater than or equal to n."""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    n += 1
    return n


def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int:
    """
    Return the total number of bytes that the tensor-typed arguments take.

    For in/out args, tensor sizes are counted twice: once for reading and
    once for writing.

    The first num_in_out_args arguments are in/out tensors.
    """
    return sum(
        arg.numel() * arg.element_size() * (1 + int(i < num_in_out_args))
        for i, arg in enumerate(args)
        if isinstance(arg, torch.Tensor)
    )


def triton_config_to_hashable(cfg):
    """
    Convert a triton config to a tuple that uniquely identifies it. We can use
    the return value as a dictionary key.
    """
    items = sorted(cfg.kwargs.items())
    items.append(("num_warps", cfg.num_warps))
    items.append(("num_stages", cfg.num_stages))
    return tuple(items)
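

# Illustrative examples for the helpers above (hypothetical values, shown only
# as doctest-style comments): ceildiv rounds the quotient up, next_power_of_2
# rounds up to a power of two, and triton_config_to_hashable flattens a
# triton.Config (assuming Triton's autotuner Config class) into a tuple usable
# as a dict key. The "XBLOCK"/"YBLOCK" kwargs are made up for the example.
#
#   >>> ceildiv(7, 2)
#   4
#   >>> next_power_of_2(5)
#   8
#   >>> cfg = triton.Config({"XBLOCK": 64, "YBLOCK": 32}, num_warps=4, num_stages=2)
#   >>> triton_config_to_hashable(cfg)
#   (('XBLOCK', 64), ('YBLOCK', 32), ('num_warps', 4), ('num_stages', 2))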


def validate_triton_config(cfg):
    # [Note: Triton pre_hook in inductor]
    # A pre_hook is a lambda function, which we don't attempt to serialize.
    # Right now, if a pre_hook is attached to the config, it will not be saved,
    # and then it won't be used when the config is loaded from the cache.
    # So we assert: if we do get a pre_hook, it might get ignored after caching.
    assert (
        getattr(cfg, "pre_hook", None) is None
    ), "triton configs with pre_hooks not supported"


def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix="", color=True):
    info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
    slow = ms > 0.012 and gb_per_s < 650
    return red_text(info_str) if color and slow else info_str


def get_max_y_grid():
    return 65535


def cache_dir() -> str:
    cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
    if cache_dir is None:
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = default_cache_dir()
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir


def default_cache_dir():
    sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser())
    return os.path.join(
        tempfile.gettempdir(),
        "torchinductor_" + sanitized_username,
    )


try:
    import colorama

    HAS_COLORAMA = True
except ModuleNotFoundError:
    HAS_COLORAMA = False
    colorama = None  # type: ignore[assignment]


def _color_text(msg, color):
    if not HAS_COLORAMA:
        return msg

    return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET


def green_text(msg):
    return _color_text(msg, "green")


def yellow_text(msg):
    return _color_text(msg, "yellow")


def red_text(msg):
    return _color_text(msg, "red")


def blue_text(msg):
    return _color_text(msg, "blue")


def get_first_attr(obj, *attrs):
    """
    Return the first available attribute or throw an exception if none is present.
    """
    for attr in attrs:
        if hasattr(obj, attr):
            return getattr(obj, attr)

    raise AssertionError(f"{obj} does not have any of the attributes: {attrs}")


try:
    dynamo_timed = torch._dynamo.utils.dynamo_timed  # type: ignore[has-type]
except AttributeError:  # Compile workers only have a mock version of torch

    @contextlib.contextmanager
    def dynamo_timed(key, phase_name=None, fwd_only=True):
        yield
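

# A minimal usage sketch (illustrative only; the tensor shapes and the timing
# below are made-up numbers, not measurements): a benchmarking harness can
# combine get_num_bytes and create_bandwidth_info_str to report the achieved
# memory bandwidth of a kernel. Passing num_in_out_args=1 makes the first
# tensor count twice (once for reading, once for writing).
#
#   x = torch.empty(1024, 1024, dtype=torch.float32)
#   y = torch.empty(1024, 1024, dtype=torch.float32)
#   num_gb = get_num_bytes(y, x, num_in_out_args=1) / 1e9  # ~0.0126 GB
#   ms = 0.015  # hypothetical measured kernel time in milliseconds
#   gb_per_s = num_gb / (ms / 1e3)
#   print(create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="kernel: "))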