# mypy: allow-untyped-defs
import logging
import os
import pathlib
from typing import Any, List

from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
from torch.utils._ordered_set import OrderedSet

from .. import config
from ..codecache import get_path, TritonFuture
from ..runtime.benchmarking import benchmarker
from ..utils import cache_on_self, IndentedBuffer
from ..virtualized import V
from .common import TensorArg


log = logging.getLogger(__name__)


def get_kernel_argdefs(kernel):
    arg_defs, _, _, _ = kernel.args.python_argdefs()
    return arg_defs


def _get_all_args(args_list, arg_types_list=None):
    all_args = max(args_list, key=len)[:]
    arg_types = max(arg_types_list, key=len)[:] if arg_types_list is not None else None
    for args in args_list:
        assert set(args).issubset(set(all_args)), f"{args} vs. {all_args}"

    return all_args, arg_types


def get_all_kernel_argdefs(kernels):
    """
    The logic here must match `get_all_call_args`, except that there is no need
    to get arg_types here.
    """
    argdefs_list = [get_kernel_argdefs(kernel) for kernel in kernels]

    return _get_all_args(argdefs_list)[0]


def get_all_call_args(call_args_list, arg_types_list):
    """
    Given the call_args for each subkernel, return the call_args for the
    combined multi-kernel.

    Note that an algorithm like the following does not always work:
    ```
    all_call_args: Dict[
        Any, None
    ] = {}  # use a dict rather than a set to maintain insertion order
    for call_args in call_args_list:
        all_call_args.update({arg: None for arg in call_args})

    all_call_args = list(all_call_args.keys())
    ```
    It fails if any kernel has the same argument passed in multiple times.
    See test_pass_same_arg_multi_times in test_multi_kernel.py.

    Instead, we pick the longest call_args and assert that the other call_args
    are a subset of it.
    """
    return _get_all_args(call_args_list, arg_types_list)


def get_numel_argdefs(kernel):
    numel_argdefs = []
    for tree in kernel.range_trees:
        if tree.prefix != "r" or kernel.inside_reduction:
            numel_argdefs.append(f"{tree.prefix}numel")

    return numel_argdefs


class MultiKernelState:
    """
    Maintain the state of multi-kernel compilation so we don't define duplicate
    multi-kernels for the same set of sub-kernels.

    V.graph.wrapper_code has a reference to the MultiKernelState instance.
    """

    def __init__(self):
        self.subkernel_to_kernel_name = {}

    def define_kernel(self, kernels):
        """
        Previously we named the multi-kernel "multi_kernel_{kernel_names[0]}".
        This has a minor issue.

        E.g. for the persistent reduction https://gist.github.com/shunting314/39e7c00ff8bb2055942ed5a3255d61ca ,
        there are 2 flavors of non-persistent reduction:
          https://gist.github.com/shunting314/056d43d35907e87efb883970b35c17d4
        and
          https://gist.github.com/shunting314/02ee753b65c513c54e695626afe682bd

        The only difference is the cache eviction policy.

        We should name the multi-kernel differently in these 2 cases.
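
        For illustration (a sketch only; kernel_a/kernel_b/... are placeholder
        sub-kernel names), the multi-kernels are now simply numbered in the
        order they are defined, so the generated definitions look like:
        ```
        multi_kernel_0 = async_compile.multi_kernel('multi_kernel_0', [
            kernel_a,
            kernel_b,
        ])
        multi_kernel_1 = async_compile.multi_kernel('multi_kernel_1', [
            kernel_c,
            kernel_d,
        ])
        ```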
102 """ 103 kernel_names = tuple(k.kernel_name for k in kernels) 104 if kernel_names in self.subkernel_to_kernel_name: 105 return self.subkernel_to_kernel_name[kernel_names] 106 107 # name the multi kernel based on the first kernel 108 multi_kernel_name = f"multi_kernel_{len(self.subkernel_to_kernel_name)}" 109 self.subkernel_to_kernel_name[kernel_names] = multi_kernel_name 110 111 if V.graph.cpp_wrapper: 112 # we should not generate any python code for multi-kernel during 113 # the second pass of cpp-wrapper. 114 return multi_kernel_name 115 116 buf = IndentedBuffer() 117 buf.writeline( 118 f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, [" 119 ) 120 with buf.indent(): 121 for name in kernel_names: 122 buf.writeline(f"{name},") 123 buf.writeline("])") 124 125 wrapper = V.graph.wrapper_code 126 wrapper.header.splice(buf) 127 if config.triton.autotune_at_compile_time: 128 wrapper.kernel_autotune_defs.splice(buf) 129 130 return multi_kernel_name 131 132 133class MultiKernel: 134 """ 135 This class maintains the compile time state for multi kernels. 136 137 Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2. 138 The generated definition for the multi-kernel will looks like: 139 ``` 140 multi_kernel_kernel1 = MultiKernelCall([kernel1, kernel2], multi_kernel_definition_code) 141 ``` 142 143 Here is an concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39 144 """ 145 146 def __init__(self, kernels): 147 assert len(kernels) >= 2 148 149 self.kernels = kernels 150 self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel( 151 kernels 152 ) 153 154 # need this since some code in inductor check if the kernel object has an args 155 # attribute to decide if it's a non-null kernel. 156 self.args = object() 157 158 def call_kernel(self, kernel_name): 159 """ 160 Collect the union of arguments from all subkernels as the arguments 161 for the multi-kernel. 162 """ 163 assert kernel_name == self.kernel_name 164 V.graph.wrapper_code.write_triton_header_once() 165 _, call_args, _, arg_types = self.kernels[0].args.python_argdefs() 166 for kernel in self.kernels[1:]: 167 _, other_call_args, _, other_arg_types = kernel.args.python_argdefs() 168 assert call_args == other_call_args 169 assert arg_types == other_arg_types 170 171 grid: List[Any] = [] 172 173 if V.graph.cpp_wrapper: 174 # for the second pass of cpp-wrapper codegen, we should call 175 # the fast kernel directly 176 picked_kernel = MultiKernelCall.lookup_choice(kernel_name) 177 kernel_name = self.kernels[picked_kernel].kernel_name 178 179 # numels for all subkernels should be the same. 
        self.kernels[0].add_numel_to_call_args_and_grid(
            kernel_name, call_args, arg_types, grid
        )

        grid = V.graph.wrapper_code.generate_default_grid(kernel_name, grid)
        V.graph.wrapper_code.generate_kernel_call(
            kernel_name,
            call_args,
            grid,
            arg_types=arg_types,
        )

    def codegen_nan_check(self):
        wrapper = V.graph.wrapper_code
        seen = set()
        for k in self.kernels:
            _, call_args, precompile_args, _ = k.args.python_argdefs()
            for arg, precompile_arg in zip(call_args, precompile_args):
                if arg in seen:
                    continue
                seen.add(arg)
                if isinstance(precompile_arg, TensorArg):
                    line = f"assert not {arg}.isnan().any().item()"
                    wrapper.writeline(line)
                    line = f"assert not {arg}.isinf().any().item()"
                    wrapper.writeline(line)

    @property
    def removed_buffers(self):
        return OrderedSet.intersection(*[k.removed_buffers for k in self.kernels])

    @property
    def inplaced_to_remove(self):
        return OrderedSet.intersection(*[k.inplaced_to_remove for k in self.kernels])

    @property
    @cache_on_self
    def inplace_update_buffers(self):
        """
        Make sure all kernels have the same inplace update mappings.
        """
        for k in self.kernels[1:]:
            assert k.inplace_update_buffers == self.kernels[0].inplace_update_buffers
        return self.kernels[0].inplace_update_buffers

    def warn_mix_layout(self, kernel_name: str):
        pass


class MultiKernelCall:
    """
    This class is called at runtime to actually run the kernel.
    """

    def __init__(self, multi_kernel_name, kernels):
        assert len(kernels) >= 2
        self._kernels = kernels
        self.multi_kernel_name = multi_kernel_name

        self.disable_cache = os.environ.get(
            "TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE"
        ) == "1" or is_metric_table_enabled("persistent_red_perf")

        self.picked_kernel = None
        if config.triton.multi_kernel > 1:
            # manually force a subkernel to ease perf testing:
            # multi_kernel == 2 forces sub-kernel 0, == 3 forces sub-kernel 1, etc.
            picked_by_config = config.triton.multi_kernel - 2
            assert picked_by_config < len(self._kernels)
            self.picked_kernel = picked_by_config
        elif not self.disable_cache:
            self.load_cache()

        self._recorded = False

    def cache_file_path(self):
        _, _, path = get_path(self.kernels[0].fn.cache_key, "picked_kernel")
        return pathlib.Path(path)

    def load_cache(self):
        assert self.picked_kernel is None
        path = self.cache_file_path()
        if path.exists():
            with path.open() as fd:
                self.picked_kernel = int(fd.read())
                assert self.picked_kernel >= 0 and self.picked_kernel < len(
                    self._kernels
                )
                log.debug(
                    "Load picked kernel %d from cache file %s", self.picked_kernel, path
                )

    def store_cache(self):
        assert self.picked_kernel is not None
        path = self.cache_file_path()
        path.parent.mkdir(parents=True, exist_ok=True)

        with path.open("w") as fd:
            fd.write(str(self.picked_kernel))
        log.debug("Store picked kernel %d to cache file %s", self.picked_kernel, path)

    @property
    def kernels(self):
        """
        Read results from the futures.

        This should be called after parallel compilation is done.
        Calling it before compilation has finished may slow down the
        parallel compilation.
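
        Accessing this property also replaces any TritonFuture entries in
        self._kernels with their compiled results, so the wait only happens
        the first time.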
288 """ 289 for i, kernel in enumerate(self._kernels): 290 if isinstance(kernel, TritonFuture): 291 self._kernels[i] = kernel.result() 292 293 return self._kernels 294 295 def benchmark_sub_kernels(self, *args, **kwargs): 296 """ 297 Benchmark all the sub kernels and return the execution time 298 (in milliseconds) for each of time. 299 300 Unit test may mock this method to force a specific kernel to 301 be picked. 302 """ 303 304 def wrap_fn(kernel): 305 def inner(): 306 args_clone, kwargs_clone = kernel.clone_args(*args, **kwargs) 307 return kernel.run(*args_clone, **kwargs_clone) 308 309 return inner 310 311 return [ 312 benchmarker.benchmark_gpu(wrap_fn(kernel), rep=40, fast_flush=True) 313 for kernel in self.kernels 314 ] 315 316 # record_choice and lookup_choice are helper functions for cpp-wrapper 317 # codegen. The first pass use record_choice to keep the choice and 318 # the second pass do lookup by calling lookup_choice. 319 # 320 # An alternative that reused the multi-kernel cache does not work well 321 # since during codegen of the second pass, it's very hard to know the 322 # path for the cache file. Also reading the cache file need do some IO 323 # which can be slower. 324 @staticmethod 325 def record_choice(multi_kernel_name, choice): 326 """ 327 Record the multi-kernel choice for cpp-wrapper first pass codegen 328 for the second pass. 329 330 We should do nothing if this function is not called during codegen. 331 """ 332 from torch._inductor.graph import GraphLowering 333 334 if not isinstance(V.graph, GraphLowering): 335 return 336 337 if not V.graph.record_multi_kernel_choice: 338 return 339 340 V.graph.multi_kernel_to_choice[multi_kernel_name] = choice 341 342 @staticmethod 343 def lookup_choice(multi_kernel_name): 344 # this should always been done during cpp-wrapper codegen 345 assert V.graph.record_multi_kernel_choice 346 # there should be no miss 347 return V.graph.multi_kernel_to_choice[multi_kernel_name] 348 349 def run(self, *args, **kwargs): 350 if self.picked_kernel is None: 351 timings = self.benchmark_sub_kernels(*args, **kwargs) 352 self.picked_kernel = timings.index(min(timings)) 353 k0 = self.kernels[0] 354 log.debug( 355 "pick %dth sub-kernel in %s. Size hints %s. Reduction hint %s. Timings %s", 356 self.picked_kernel, 357 [k.inductor_meta.get("kernel_name") for k in self.kernels], 358 k0.size_hints, 359 k0.inductor_meta.get("reduction_hint"), 360 timings, 361 ) 362 363 def get_kernel_path(k): 364 return k.fn.fn.__code__.co_filename 365 366 get_metric_table("persistent_red_perf").add_row( 367 lambda: { 368 "kernel1_name": get_kernel_path(self.kernels[0]), 369 "kernel2_name": get_kernel_path(self.kernels[1]), 370 "kernel1_latency": timings[0], 371 "kernel2_latency": timings[1], 372 "size_hints": k0.size_hints, 373 "reduction_hint": k0.inductor_meta.get("reduction_hint"), 374 "speedup": timings[1] / timings[0], 375 } 376 ) 377 378 if not self.disable_cache: 379 self.store_cache() 380 381 if not self._recorded: 382 self._recorded = True 383 self.record_choice(self.multi_kernel_name, self.picked_kernel) 384 self.run = self.kernels[self.picked_kernel].run # type: ignore[method-assign] 385 self.run(*args, **kwargs) 386