# mypy: allow-untyped-defs
from __future__ import annotations

import functools
import itertools
import re
from enum import auto, Enum
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple

import sympy

import torch.fx
from torch._dynamo.utils import identity
from torch.utils._sympy.symbol import SymT

from . import config, dependencies
from .codegen.common import index_prevent_reordering
from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
from .virtualized import ops, V


class InterpreterShim(torch.fx.Interpreter):
    """
    Thin torch.fx.Interpreter that runs a raw fx.Graph with a plain dict of
    submodules, skipping GraphModule construction entirely.
    """

    @staticmethod
    @functools.lru_cache(None)
    def _dummy_gm():
        # Cached trivial GraphModule, used only to satisfy the
        # torch.fx.Interpreter constructor (see __init__ below).
        return torch.fx.symbolic_trace(identity)

    def __init__(self, graph, submodules):
        # call super() with a placeholder to avoid constructing a
        # GraphModule which is very expensive (it does codegen).
        super().__init__(self._dummy_gm(), garbage_collect_values=False)
        self.module = self  # type: ignore[assignment]
        self.graph = graph
        self.submodules = submodules
        self.extra_traceback = False
        # call_module lookups resolve directly against the submodules dict.
        self.fetch_attr = submodules.__getitem__  # type: ignore[method-assign]
        # Node currently being executed; updated by run_node().
        self.current_node = None

    def run_node(self, n: torch.fx.Node) -> Any:
        # Record the active node so handlers can inspect it during execution.
        self.current_node = n
        return super().run_node(n)

    def run(self, *args, **kwargs):
        # Install this interpreter on the virtualized handler stack for the
        # duration of the run.
        with V.set_interpreter_handler(self):
            return super().run(*args, **kwargs)


class MemoryEntry(NamedTuple):
    # One recorded memory access; see LoopBody.memory_usage.
    index_name: str  # LoopBody.indexing_exprs[index_name]
    buffer_name: Optional[str]
    mode: Optional[str]  # V.ops.store(..., mode=mode)


class MemoryUsageType(Enum):
    # These are 1:1 with the opcode generating the usage
    LOAD = auto()
    LOAD_SEED = auto()
    STORE = auto()
    STORE_REDUCTION = auto()
    INDEX_EXPR = auto()
    CHECK_BOUNDS = auto()
    BUCKETIZE = auto()


class LoopBody:
    """
    Captures the body of a Loops subclass into an FX graph.  Persists any
    indexing simplifications and makes it easier to analyze loop bodies.
    """

    indexing_exprs: Dict[str, sympy.Expr]
    # Reverse map expr -> name used for de-duplication during tracing only;
    # deleted at the end of _init_with_tracing().
    indexing_exprs_name: Dict[sympy.Expr, str]
    submodules: Dict[str, Any]
    subblocks: Dict[str, LoopBodyBlock]
    indirect_vars: List[str]
    indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr]
    root_block: LoopBodyBlock
    memory_usage: Dict[MemoryUsageType, List[MemoryEntry]]

    def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars):
        super().__init__()

        # var_ranges is ordered with iteration vars first, then reduction
        # vars; split the flat sizes accordingly.
        _flat_sizes = tuple(var_ranges.values())
        self.sizes = (
            _flat_sizes[: len(iter_vars)],
            _flat_sizes[len(iter_vars) :],
        )

        self.iter_vars = iter_vars
        self.reduce_vars = reduce_vars
        self.var_ranges = var_ranges

        # fn may be an existing LoopBody (fast copy path) or an arbitrary
        # callable to be FX-traced.
        if isinstance(fn, LoopBody):
            self._init_with_copy(fn, args)
        else:
            self._init_with_tracing(fn, args)

        # Set per-call by __call__(); None when not executing.
        self.indexing = None

    def _init_with_tracing(self, fn, args):
        """Do an FX trace of an arbitrary callable to construct self"""
        self.indexing_exprs = {}
        self.indexing_exprs_name = {}
        self.submodules = {"get_index": self.get_index}
        self.subblocks = {}
        self.indirect_vars = []
        self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {}
        self.memory_usage = {t: [] for t in MemoryUsageType}
        self.root_block = LoopBodyBlock(self, fn, args)  # traces
        del self.indexing_exprs_name  # not used after _init_with_tracing

    def _init_with_copy(self, other: LoopBody, args):
        """
        _init_with_tracing() is slow, so this is a fast path in the case
        where we are just reordering/merging/splitting the args of an
        existing LoopBody.
        """
        # Re-derive indexing for the new args, then simplify under our ranges.
        indexing_exprs = other.indexing_from_args(args)
        self.indexing_exprs = {
            name: V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges)
            for name, expr in indexing_exprs.items()
        }
        self.subblocks = {k: v.clone(self) for k, v in other.subblocks.items()}
        self.indirect_vars = other.indirect_vars
        self.indirect_var_ranges = other.indirect_var_ranges
        self.memory_usage = other.memory_usage
        self.root_block = other.root_block.clone(self)

        # Rebind every submodule shim to this LoopBody; "get_index" is
        # replaced with our own bound method rather than cloned.
        submodules = {**other.submodules}
        submodules.pop("get_index")
        self.submodules = {
            "get_index": self.get_index,
            **{k: v.clone(self) for k, v in submodules.items()},  # type: ignore[attr-defined]
        }

    def merge_loops(self) -> LoopBody:
        """
        Merge both iteration and reduction loops and return a new LoopBody.
        """
        old_body = self
        old_sizes = self.sizes
        old_iter_vars, old_reduce_vars = old_body.vars
        old_iter_sizes, old_reduce_sizes = old_sizes

        index_exprs = [*old_body.indexing_exprs.values()]

        iter_sizes, iter_reindex, _ = V.graph.sizevars._simplify_loops(
            old_iter_vars,
            old_iter_sizes,
            index_prevent_reordering(index_exprs, old_iter_vars, old_iter_sizes),
        )

        reduce_sizes, reduce_reindex, _ = V.graph.sizevars._simplify_loops(
            old_reduce_vars,
            old_reduce_sizes,
            index_prevent_reordering(index_exprs, old_reduce_vars, old_reduce_sizes),
        )

        # if iter_sizes == old_iter_sizes:
        #     # no dimensions get merged.
        #     return old_sizes, old_body

        # Note: if no dimension get merges, the symbol prefix will
        # remain 'y'. But if we merge dimensions, we change prefix to
        # 'z'. If this is an issue, we can always retrace the LoopBody
        # to change symbol prefix to 'z'.
        #
        # There is indeed an issue due to symbol name conflicting.
        # y0 maybe reused for the y dimension later.
        (
            iter_vars,
            reduce_vars,
        ), var_ranges = dependencies.index_vars_no_squeeze(
            iter_sizes, reduce_sizes, prefix="t"
        )
        # First pass uses a temporary 't' prefix so the reindexed expressions
        # cannot collide with the original symbols.
        new_body = LoopBody(
            old_body,
            [iter_reindex(iter_vars), reduce_reindex(reduce_vars)],
            var_ranges,
            iter_vars,
            reduce_vars,
        )

        # use the original symbol prefix
        # Can try to optimize if this is a bottleneck for compilation time
        (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
            iter_sizes, reduce_sizes, prefix="z"
        )
        new_body2 = LoopBody(
            new_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
        )
        return new_body2

    def reorder_iter_loops(self, new_order) -> LoopBody:
        """
        Reorder iteration loops and return a new LoopBody.
        """
        from .ir import same_reorder

        old_body = self
        old_sizes = self.sizes
        assert len(old_sizes[0]) == len(new_order)
        reorder_fn = same_reorder(new_order)

        iter_size, reduce_size = old_sizes
        new_iter_size = reorder_fn(iter_size)

        new_sizes = (new_iter_size, reduce_size)

        (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
            *new_sizes, prefix="t"  # type: ignore[arg-type]
        )

        # Build the inverse permutation of new_order as a list.
        inverse_order = {b: a for a, b in enumerate(new_order)}
        inverse_order = [inverse_order[i] for i in range(len(new_order))]

        def new_body(*indices: Sequence[sympy.Expr]) -> Any:
            # Map indices in the new (reordered) loop order back to the
            # original order before calling the old body.
            index = list(itertools.chain(*indices))
            assert len(index) == len(iter_size) + len(reduce_size)
            iter_idx = index[: len(iter_size)]
            reduce_idx = index[len(iter_size) :]
            iter_idx = [iter_idx[i] for i in inverse_order]
            return old_body(iter_idx, reduce_idx)

        loop_body = LoopBody(
            new_body, (iter_vars, reduce_vars), var_ranges, iter_vars, reduce_vars
        )

        # use the original symbol prefix so we can do multiple round of reordering
        (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
            *new_sizes, prefix="z"  # type: ignore[arg-type]
        )
        new_body = LoopBody(
            loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
        )
        return new_body

    @property
    def vars(self):
        # (iteration vars, reduction vars) pair.
        assert self.iter_vars is not None
        assert self.reduce_vars is not None
        return self.iter_vars, self.reduce_vars

    @cache_on_self
    def get_nodes(self):
        # All FX nodes across the root block and every subblock.
        all_graphs = itertools.chain(
            (self.root_block.graph,),
            (block.graph for block in self.subblocks.values()),
        )
        return [node for graph in all_graphs for node in graph.nodes]

    @cache_on_self
    def bounds(self):
        # Doing a local import to avoid dumping all the code here
        from .bounds import BoundVars

        return BoundVars(self)

    def get_read_expr(self, buffer_name):
        """Return the indexing expr of the last recorded load of buffer_name."""
        # reversed to match old behavior
        for entry in reversed(self.memory_usage[MemoryUsageType.LOAD]):
            if entry.buffer_name == buffer_name:
                return self.indexing_exprs[entry.index_name]
        raise KeyError(buffer_name)

    def get_write_expr(self, buffer_name):
        """Return the indexing expr of the first recorded store of buffer_name."""
        for entry in itertools.chain(
            self.memory_usage[MemoryUsageType.STORE],
            self.memory_usage[MemoryUsageType.STORE_REDUCTION],
        ):
            if entry.buffer_name == buffer_name:
                return self.indexing_exprs[entry.index_name]
        raise KeyError(buffer_name)

    def get_read_exprs(self):
        # Indexing expressions of every recorded load, in recording order.
        return [
            self.indexing_exprs[entry.index_name]
            for entry in self.memory_usage[MemoryUsageType.LOAD]
        ]

    def get_write_exprs(self):
        # Indexing expressions of every recorded store / store_reduction.
        return [
            self.indexing_exprs[entry.index_name]
            for entry in itertools.chain(
                self.memory_usage[MemoryUsageType.STORE],
                self.memory_usage[MemoryUsageType.STORE_REDUCTION],
            )
        ]

    def debug_str(self):
        # Human-readable dump: var ranges, indexing exprs, then the generated
        # code of the root block ("body") and each subblock.
        lines = [f"var_ranges = {dict(self.var_ranges)}"]
        lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()])
        lines.extend(
            [
                block.debug_str(name)
                for name, block in itertools.chain(
                    [("body", self.root_block)], self.subblocks.items()
                )
            ]
        )
        return "\n".join(lines)

    def is_memory_copy(self) -> bool:
        """
        True if this body contains only a single load and store.
        Note, this could involve a layout change.
        """
        return (
            len(self.memory_usage[MemoryUsageType.LOAD]) == 1
            and len(self.memory_usage[MemoryUsageType.STORE]) == 1
            and len(self.submodules) == 1  # get_index
            and self.root_block.contains_only_ops(("load", "store"))
        )

    __repr__ = debug_str

    def add_index_expr(
        self,
        expr: sympy.Expr,
        mtype: MemoryUsageType,
        buffer_name: Optional[str] = None,
        mode: Optional[str] = None,
    ):
        """
        Record an indexing expression (de-duplicated by value) plus a
        memory-usage entry for it; returns the expression's name.
        Only callable during tracing (indexing_exprs_name exists then).
        """
        name = self.indexing_exprs_name.get(expr)
        if not name:
            name = f"index{len(self.indexing_exprs)}"
            self.indexing_exprs_name[expr] = name
            self.indexing_exprs[name] = expr
        self.memory_usage[mtype].append(MemoryEntry(name, buffer_name, mode))
        return name

    def add_submodule(self, block, prefix):
        """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes"""
        # Prefixes already ending in a digit are used as-is when free;
        # otherwise a counter suffix keeps names unique.
        if prefix[-1].isnumeric() and prefix not in self.submodules:
            name = prefix
        else:
            name = f"{prefix}{len(self.submodules)}"
        self.submodules[name] = block
        return name

    def add_indirect(self, size):
        # Allocate a fresh indirect-indexing symbol bounded by `size`.
        var = sympy_index_symbol_with_prefix(SymT.INDIRECT, len(self.indirect_vars))
        assert var not in self.indirect_var_ranges
        self.indirect_vars.append(var)
        self.indirect_var_ranges[var] = size
        return var

    def replace_indirect(self, old, new):
        """Swap in a variable used in indirect indexing"""
        if str(old) == str(new):
            return
        assert self.indexing is not None
        self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}

    def get_index(self, name):
        # Submodule hook: resolve an indexing expression by name for the
        # current call (see __call__).
        assert self.indexing is not None
        return self.indexing[name]

    def indexing_from_args(self, indices):
        """Substitute concrete index expressions for the loop variables."""
        index = [*itertools.chain.from_iterable(indices)]
        assert len(index) == len(self.var_ranges), (index, self.var_ranges)
        # The replacement values must not themselves be loop variables.
        assert all(
            v not in self.var_ranges for v in index
        ), f"{self.var_ranges=}, {indices=}"
        replacements = dict(zip(self.var_ranges.keys(), index))
        return {
            name: sympy_subs(expr, replacements)
            for name, expr in self.indexing_exprs.items()
        }

    def __call__(self, *indices):
        # Bind indexing for this call, run the traced root block, then clear.
        self.indexing = self.indexing_from_args(indices)
        result = self.root_block()
        self.indexing = None
        return result

    def bind_set_indirect_shim(self, var, size, check, wrap_neg):
        # Shim invoked as a call_module during execution: rewrites `var` in
        # the active indexing with the runtime indirect index value.
        def set_indirect(new_var):
            self.replace_indirect(
                var, V.ops.indirect_indexing(new_var, size, check, wrap_neg)
            )

        # .clone lets _init_with_copy rebind the shim to a new LoopBody
        # (the new body is supplied as the positional `self` argument).
        set_indirect.clone = functools.partial(  # type: ignore[attr-defined]
            LoopBody.bind_set_indirect_shim,
            var=var,
            size=size,
            check=check,
            wrap_neg=wrap_neg,
        )
        return set_indirect

    def bind_scan_shim(self, combine_fn):
        # Shim wrapping V.ops.scan with a captured combine function.
        def shim(dtypes, values):
            return V.ops.scan(dtypes, combine_fn, values)

        shim.clone = functools.partial(LoopBody.bind_scan_shim, combine_fn=combine_fn)  # type: ignore[attr-defined]
        return shim

    def bind_masked_shim(self, name):
        # Shim wrapping V.ops.masked around the named subblock.
        def shim(mask, other):
            return V.ops.masked(mask, self.subblocks[name], other)

        shim.clone = functools.partial(LoopBody.bind_masked_shim, name=name)  # type: ignore[attr-defined]
        return shim


class LoopBodyBlock:
    """
    Captures the body of a Loops subclass into an FX graph.
    In normal cases there will be a 1:1 mapping between LoopBody and
    LoopBodyBlock, however in the case of ops.masked() the masked out
    operations will manifest as an extra LoopBodyBlock.
    """

    def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]):
        self.body = body

        def add_index(expr: sympy.Expr, mtype: MemoryUsageType, **kwargs):
            # Record the expression on the LoopBody and emit a
            # call_module("get_index") proxy that will resolve it at runtime.
            return tracer.create_proxy(
                "call_module",
                "get_index",
                (body.add_index_expr(expr, mtype, **kwargs),),
                {},
            )

        # NOTE: this class body executes inside __init__, so free variables
        # like `self`, `body`, `add_index`, and `tracer` resolve to the
        # enclosing scope (including inside the staticmethods below).
        class CaptureIndexing(V.WrapperHandler):  # type: ignore[name-defined]
            # Executed at class-creation time with `self` being the enclosing
            # LoopBodyBlock: this sets a `name` attribute on it.
            self.name = "CaptureIndexing"

            def load(self, name: str, index: sympy.Expr):
                index = add_index(index, MemoryUsageType.LOAD, buffer_name=name)
                return self._inner.load(name, index)

            def load_seed(self, name: str, index: int):
                assert isinstance(index, int)
                # Seed loads use a constant index; record it but pass the raw
                # int through to the inner handler.
                body.add_index_expr(
                    sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name
                )
                return self._inner.load_seed(name, index)

            def store(self, name, index, value, mode=None):
                index = add_index(
                    index, MemoryUsageType.STORE, buffer_name=name, mode=mode
                )
                return self._inner.store(name, index, value, mode)

            def store_reduction(self, name, index, value):
                index = add_index(
                    index, MemoryUsageType.STORE_REDUCTION, buffer_name=name
                )
                return self._inner.store_reduction(name, index, value)

            def reduction(self, dtype, src_dtype, reduction_type, value):
                result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
                # Welford reductions produce three outputs; unpack the proxy
                # into a real tuple.
                if "welford" in reduction_type:
                    return tuple(result[i] for i in range(3))
                return result

            def index_expr(self, index, dtype):
                # Constant indices become constants directly, skipping the
                # get_index indirection.
                if isinstance(index, (int, sympy.Integer)):
                    return self._inner.constant(int(index), dtype)
                index = add_index(index, MemoryUsageType.INDEX_EXPR)
                return self._inner.index_expr(index, dtype)

            def check_bounds(self, index, size, lower, upper):
                index = add_index(index, MemoryUsageType.CHECK_BOUNDS)
                size = add_index(size, MemoryUsageType.CHECK_BOUNDS)
                return self._inner.check_bounds(index, size, lower, upper)

            def bucketize(
                self,
                values,
                offsets_name: str,
                offsets_size: sympy.Expr,
                indexing_dtype: torch.dtype,
                right: bool,
            ):
                offsets_size = add_index(
                    offsets_size, MemoryUsageType.BUCKETIZE, buffer_name=offsets_name
                )
                return self._inner.bucketize(
                    values, offsets_name, offsets_size, indexing_dtype, right
                )

            @staticmethod
            def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy):
                """
                Recursively capture the masked out body in another LoopBodyBlock
                """
                # Register the name first so the recursive trace sees it,
                # then install the shim and the traced subblock.
                name = self.body.add_submodule(None, "masked_subblock")
                self.body.submodules[name] = self.body.bind_masked_shim(name)
                self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, [])
                return tracer.create_proxy(
                    "call_module", name, (mask_proxy, other_proxy), {}
                )

            @staticmethod
            def scan(
                dtype_proxy,
                combine_fn: Callable[
                    [Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]
                ],
                value_proxy,
            ):
                shim = self.body.bind_scan_shim(combine_fn)
                name = self.body.add_submodule(shim, "scan")
                result = tracer.create_proxy(
                    "call_module",
                    name,
                    (dtype_proxy, value_proxy),
                    {},
                )
                # Proxies are iterable, but some methods expect tuples/lists
                return tuple(result[i] for i in range(len(value_proxy)))

            def sort(self, dtypes, values, stable, descending):
                result = self._inner.sort(dtypes, values, stable, descending)
                # Proxies are iterable, but some methods expect tuples/lists
                return tuple(result[i] for i in range(len(values)))

            def frexp(self, value_proxy):
                result = self._inner.frexp(value_proxy)
                # Proxies are iterable, but some methods expect tuples/lists
                return (result[0], result[1])

            @staticmethod
            def indirect_indexing(index_proxy, size, check=True, wrap_neg=True):
                """
                Flow data from tensors into indexing formulas.
                Introduce a call_module to update the indexing.
                """

                var = self.body.add_indirect(size)
                set_indirect = self.body.bind_set_indirect_shim(
                    var, size, check, wrap_neg
                )
                tracer.create_proxy(
                    "call_module",
                    self.body.add_submodule(set_indirect, f"set_{var}"),
                    (index_proxy,),
                    {},
                )
                return var

            @staticmethod
            def output(result):
                tracer.create_proxy("output", "output", (result,), {})

        # Build an FX graph by hand, without symbolic_trace: a single "ops"
        # placeholder feeds the handler chain below.
        tracer = torch.fx.Tracer()
        tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
        proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})

        from .index_propagation import IndexPropagation
        from .sizevars import SimplifyIndexing

        handler: Any = SimplifyIndexing(
            CaptureIndexing(proxy_ops), self.body.var_ranges
        )
        if config.constant_and_index_propagation:
            handler = IndexPropagation(
                handler, self.body.var_ranges, self.body.indirect_var_ranges
            )

        with V.set_ops_handler(handler):
            # This indirection is just a cute way to get IndexPropagation to
            # unwrap the return value.
            ops.output(fn(*args))
        self.graph = tracer.graph

    def __call__(self):
        # Execute the captured graph against the current ops handler.
        graph = self.graph
        submodules = self.body.submodules

        return InterpreterShim(graph, submodules).run(V.get_ops_handler())

    def debug_str(self, name="block"):
        """Render this block's graph as Python-like code for debugging."""
        code = torch.fx.GraphModule(self.body.submodules, self.graph).code
        return re.sub(
            # strip `; del var0` suffixes to make output prettier
            r";[^\n]*",
            "",
            code.strip().replace("def forward(", f"def {name}("),
        )

    def contains_only_ops(self, allowed_ops) -> bool:
        # True when every call_method node's target is in allowed_ops.
        return all(
            node.target in allowed_ops
            for node in self.graph.find_nodes(op="call_method")
        )

    def clone(self, body: LoopBody):
        """Shallow copy with a new parent LoopBody"""
        copy = LoopBodyBlock.__new__(LoopBodyBlock)
        copy.__dict__.update({**self.__dict__, "body": body})
        return copy