xref: /aosp_15_r20/external/pytorch/torch/_inductor/loop_body.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import functools
5import itertools
6import re
7from enum import auto, Enum
8from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple
9
10import sympy
11
12import torch.fx
13from torch._dynamo.utils import identity
14from torch.utils._sympy.symbol import SymT
15
16from . import config, dependencies
17from .codegen.common import index_prevent_reordering
18from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
19from .virtualized import ops, V
20
21
class InterpreterShim(torch.fx.Interpreter):
    """Minimal FX interpreter that executes a bare ``fx.Graph`` against a
    plain dict of submodules, sidestepping GraphModule construction."""

    @staticmethod
    @functools.lru_cache(None)
    def _dummy_gm():
        # One trivially-traced module, cached and shared by every instance;
        # it exists only so the base Interpreter constructor has a module.
        return torch.fx.symbolic_trace(identity)

    def __init__(self, graph, submodules):
        # Feed the base class the cached placeholder instead of a real
        # GraphModule, which would be very expensive (it does codegen).
        super().__init__(self._dummy_gm(), garbage_collect_values=False)
        # Patch in the actual graph/submodules after the base init ran.
        self.module = self  # type: ignore[assignment]
        self.current_node = None
        self.extra_traceback = False
        # Attribute fetches resolve straight out of the submodule dict.
        self.fetch_attr = submodules.__getitem__  # type: ignore[method-assign]
        self.graph = graph
        self.submodules = submodules

    def run_node(self, n: torch.fx.Node) -> Any:
        # Remember which node is executing (useful for error reporting).
        self.current_node = n
        return super().run_node(n)

    def run(self, *args, **kwargs):
        # Advertise this interpreter as active for the duration of the run.
        with V.set_interpreter_handler(self):
            return super().run(*args, **kwargs)
47
class MemoryEntry(NamedTuple):
    """One recorded memory access; collected in LoopBody.memory_usage."""

    index_name: str  # LoopBody.indexing_exprs[index_name]
    buffer_name: Optional[str]  # None for usages with no named buffer (e.g. index_expr)
    mode: Optional[str]  # V.ops.store(..., mode=mode)
52
53
class MemoryUsageType(Enum):
    """Kind of memory access recorded in LoopBody.memory_usage."""

    # These are 1:1 with the opcode generating the usage
    LOAD = auto()  # ops.load
    LOAD_SEED = auto()  # ops.load_seed
    STORE = auto()  # ops.store
    STORE_REDUCTION = auto()  # ops.store_reduction
    INDEX_EXPR = auto()  # ops.index_expr
    CHECK_BOUNDS = auto()  # ops.check_bounds
    BUCKETIZE = auto()  # ops.bucketize
63
64
class LoopBody:
    """
    Captures the body of a Loops subclass into an FX graph.  Persists any
    indexing simplifications and makes it easier to analyze loop bodies.
    """

    # name -> simplified sympy indexing expression
    indexing_exprs: Dict[str, sympy.Expr]
    # reverse map of the above; only exists during _init_with_tracing
    indexing_exprs_name: Dict[sympy.Expr, str]
    # name -> callable invoked via FX call_module opcodes
    submodules: Dict[str, Any]
    # name -> nested block (e.g. the masked-out body of ops.masked)
    subblocks: Dict[str, LoopBodyBlock]
    indirect_vars: List[str]
    indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr]
    root_block: LoopBodyBlock
    memory_usage: Dict[MemoryUsageType, List[MemoryEntry]]

    def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars):
        super().__init__()

        # Split var_ranges' values positionally: the first len(iter_vars)
        # entries are iteration sizes, the rest are reduction sizes.
        _flat_sizes = tuple(var_ranges.values())
        self.sizes = (
            _flat_sizes[: len(iter_vars)],
            _flat_sizes[len(iter_vars) :],
        )

        self.iter_vars = iter_vars
        self.reduce_vars = reduce_vars
        self.var_ranges = var_ranges

        # Copying an existing LoopBody is much cheaper than re-tracing fn.
        if isinstance(fn, LoopBody):
            self._init_with_copy(fn, args)
        else:
            self._init_with_tracing(fn, args)

        # Populated per-call in __call__() with the concrete indexing
        # expressions for the indices being evaluated; None between calls.
        self.indexing = None

    def _init_with_tracing(self, fn, args):
        """Do an FX trace of an arbitrary callable to construct self"""
        self.indexing_exprs = {}
        self.indexing_exprs_name = {}
        self.submodules = {"get_index": self.get_index}
        self.subblocks = {}
        self.indirect_vars = []
        self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {}
        self.memory_usage = {t: [] for t in MemoryUsageType}
        self.root_block = LoopBodyBlock(self, fn, args)  # traces
        del self.indexing_exprs_name  # not used after _init_with_tracing

    def _init_with_copy(self, other: LoopBody, args):
        """
        _init_with_tracing() is slow, so this is a fast path in the case
        where we are just reordering/merging/splitting the args of an
        existing LoopBody.
        """
        # Substitute the new args into the old indexing, then re-simplify
        # under this body's var_ranges.
        indexing_exprs = other.indexing_from_args(args)
        self.indexing_exprs = {
            name: V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges)
            for name, expr in indexing_exprs.items()
        }
        self.subblocks = {k: v.clone(self) for k, v in other.subblocks.items()}
        # indirect_vars/indirect_var_ranges/memory_usage are reused by
        # reference (not copied) from the source body.
        self.indirect_vars = other.indirect_vars
        self.indirect_var_ranges = other.indirect_var_ranges
        self.memory_usage = other.memory_usage
        self.root_block = other.root_block.clone(self)

        # Rebuild submodules: "get_index" must be bound to *this* body;
        # every other submodule is a shim exposing a .clone() hook.
        submodules = {**other.submodules}
        submodules.pop("get_index")
        self.submodules = {
            "get_index": self.get_index,
            **{k: v.clone(self) for k, v in submodules.items()},  # type: ignore[attr-defined]
        }

    def merge_loops(self) -> LoopBody:
        """
        Merge both iteration and reduction loops and return a new LoopBody.
        """
        old_body = self
        old_sizes = self.sizes
        old_iter_vars, old_reduce_vars = old_body.vars
        old_iter_sizes, old_reduce_sizes = old_sizes

        index_exprs = [*old_body.indexing_exprs.values()]

        # Simplify (merge contiguous dimensions of) the iteration loops,
        # preventing reorderings that would break the indexing expressions.
        iter_sizes, iter_reindex, _ = V.graph.sizevars._simplify_loops(
            old_iter_vars,
            old_iter_sizes,
            index_prevent_reordering(index_exprs, old_iter_vars, old_iter_sizes),
        )

        # Same for the reduction loops.
        reduce_sizes, reduce_reindex, _ = V.graph.sizevars._simplify_loops(
            old_reduce_vars,
            old_reduce_sizes,
            index_prevent_reordering(index_exprs, old_reduce_vars, old_reduce_sizes),
        )

        # if iter_sizes == old_iter_sizes:
        #     # no dimensions get merged.
        #     return old_sizes, old_body

        # Note: if no dimension get merges, the symbol prefix will
        # remain 'y'. But if we merge dimensions, we change prefix to
        # 'z'. If this is an issue, we can always retrace the LoopBody
        # to change symbol prefix to 'z'.
        #
        # There is indeed an issue due to symbol name conflicting.
        # y0 maybe reused for the y dimension later.
        #
        # First pass: rebuild with a temporary 't' prefix so the new
        # symbols cannot collide with the old ones during substitution.
        (
            iter_vars,
            reduce_vars,
        ), var_ranges = dependencies.index_vars_no_squeeze(
            iter_sizes, reduce_sizes, prefix="t"
        )
        new_body = LoopBody(
            old_body,
            [iter_reindex(iter_vars), reduce_reindex(reduce_vars)],
            var_ranges,
            iter_vars,
            reduce_vars,
        )

        # use the original symbol prefix
        # Can try to optimize if this is a bottleneck for compilation time
        # Second pass: rebuild again with the final 'z' prefix.
        (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
            iter_sizes, reduce_sizes, prefix="z"
        )
        new_body2 = LoopBody(
            new_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
        )
        return new_body2

    def reorder_iter_loops(self, new_order) -> LoopBody:
        """
        Reorder iteration loops and return a new LoopBody.
        """
        from .ir import same_reorder

        old_body = self
        old_sizes = self.sizes
        assert len(old_sizes[0]) == len(new_order)
        reorder_fn = same_reorder(new_order)

        iter_size, reduce_size = old_sizes
        new_iter_size = reorder_fn(iter_size)

        # Only iteration sizes are permuted; reduction sizes are untouched.
        new_sizes = (new_iter_size, reduce_size)

        # First pass with a temporary 't' prefix to avoid symbol-name
        # collisions (see merge_loops for the same two-pass trick).
        (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
            *new_sizes, prefix="t"  # type: ignore[arg-type]
        )

        # inverse_order[i] gives the old position of new position i.
        inverse_order = {b: a for a, b in enumerate(new_order)}
        inverse_order = [inverse_order[i] for i in range(len(new_order))]

        def new_body(*indices: Sequence[sympy.Expr]) -> Any:
            # Map indices given in the new loop order back to the old order
            # before invoking the old body.
            index = list(itertools.chain(*indices))
            assert len(index) == len(iter_size) + len(reduce_size)
            iter_idx = index[: len(iter_size)]
            reduce_idx = index[len(iter_size) :]
            iter_idx = [iter_idx[i] for i in inverse_order]
            return old_body(iter_idx, reduce_idx)

        loop_body = LoopBody(
            new_body, (iter_vars, reduce_vars), var_ranges, iter_vars, reduce_vars
        )

        # use the original symbol prefix so we can do multiple round of reordering
        (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
            *new_sizes, prefix="z"  # type: ignore[arg-type]
        )
        new_body = LoopBody(
            loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
        )
        return new_body

    @property
    def vars(self):
        """Return (iter_vars, reduce_vars); both must have been set."""
        assert self.iter_vars is not None
        assert self.reduce_vars is not None
        return self.iter_vars, self.reduce_vars

    @cache_on_self
    def get_nodes(self):
        """Return all FX nodes from the root block and every subblock."""
        all_graphs = itertools.chain(
            (self.root_block.graph,),
            (block.graph for block in self.subblocks.values()),
        )
        return [node for graph in all_graphs for node in graph.nodes]

    @cache_on_self
    def bounds(self):
        """Return a (cached) BoundVars bounds analysis over this body."""
        # Doing a local import to avoid dumping all the code here
        from .bounds import BoundVars

        return BoundVars(self)

    def get_read_expr(self, buffer_name):
        """Return one indexing expr used to load buffer_name; KeyError if none."""
        # reversed to match old behavior
        for entry in reversed(self.memory_usage[MemoryUsageType.LOAD]):
            if entry.buffer_name == buffer_name:
                return self.indexing_exprs[entry.index_name]
        raise KeyError(buffer_name)

    def get_write_expr(self, buffer_name):
        """Return one indexing expr used to store buffer_name; KeyError if none."""
        for entry in itertools.chain(
            self.memory_usage[MemoryUsageType.STORE],
            self.memory_usage[MemoryUsageType.STORE_REDUCTION],
        ):
            if entry.buffer_name == buffer_name:
                return self.indexing_exprs[entry.index_name]
        raise KeyError(buffer_name)

    def get_read_exprs(self):
        """Return the indexing exprs of every load, in recorded order."""
        return [
            self.indexing_exprs[entry.index_name]
            for entry in self.memory_usage[MemoryUsageType.LOAD]
        ]

    def get_write_exprs(self):
        """Return the indexing exprs of every store/store_reduction."""
        return [
            self.indexing_exprs[entry.index_name]
            for entry in itertools.chain(
                self.memory_usage[MemoryUsageType.STORE],
                self.memory_usage[MemoryUsageType.STORE_REDUCTION],
            )
        ]

    def debug_str(self):
        """Render var_ranges, indexing exprs, and all blocks as readable text."""
        lines = [f"var_ranges = {dict(self.var_ranges)}"]
        lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()])
        lines.extend(
            [
                block.debug_str(name)
                for name, block in itertools.chain(
                    [("body", self.root_block)], self.subblocks.items()
                )
            ]
        )
        return "\n".join(lines)

    def is_memory_copy(self) -> bool:
        """
        True if this contains only a single load and store.
        Note, this could involve a layout change.
        """
        return (
            len(self.memory_usage[MemoryUsageType.LOAD]) == 1
            and len(self.memory_usage[MemoryUsageType.STORE]) == 1
            and len(self.submodules) == 1  # get_index
            and self.root_block.contains_only_ops(("load", "store"))
        )

    __repr__ = debug_str

    def add_index_expr(
        self,
        expr: sympy.Expr,
        mtype: MemoryUsageType,
        buffer_name: Optional[str] = None,
        mode: Optional[str] = None,
    ):
        """Intern expr (deduplicated), record the usage, and return its name."""
        name = self.indexing_exprs_name.get(expr)
        if not name:
            name = f"index{len(self.indexing_exprs)}"
            self.indexing_exprs_name[expr] = name
            self.indexing_exprs[name] = expr
        self.memory_usage[mtype].append(MemoryEntry(name, buffer_name, mode))
        return name

    def add_submodule(self, block, prefix):
        """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes"""
        # Reuse prefix verbatim when it already ends in a digit and is free;
        # otherwise append a counter to make the name unique.
        if prefix[-1].isnumeric() and prefix not in self.submodules:
            name = prefix
        else:
            name = f"{prefix}{len(self.submodules)}"
        self.submodules[name] = block
        return name

    def add_indirect(self, size):
        """Allocate a fresh indirect-indexing symbol with the given range."""
        var = sympy_index_symbol_with_prefix(SymT.INDIRECT, len(self.indirect_vars))
        assert var not in self.indirect_var_ranges
        self.indirect_vars.append(var)
        self.indirect_var_ranges[var] = size
        return var

    def replace_indirect(self, old, new):
        """Swap in a variable used in indirect indexing"""
        if str(old) == str(new):
            return
        assert self.indexing is not None
        self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}

    def get_index(self, name):
        """Look up a concrete indexing expr; only valid during __call__()."""
        assert self.indexing is not None
        return self.indexing[name]

    def indexing_from_args(self, indices):
        """Substitute the given (iter, reduce) indices into indexing_exprs."""
        index = [*itertools.chain.from_iterable(indices)]
        assert len(index) == len(self.var_ranges), (index, self.var_ranges)
        # The substituted values must not themselves be this body's loop
        # variables, or the substitution would be ill-defined.
        assert all(
            v not in self.var_ranges for v in index
        ), f"{self.var_ranges=}, {indices=}"
        replacements = dict(zip(self.var_ranges.keys(), index))
        return {
            name: sympy_subs(expr, replacements)
            for name, expr in self.indexing_exprs.items()
        }

    def __call__(self, *indices):
        # Bind concrete indexing for this evaluation, run the root block,
        # then clear it again so stale indexing can't leak between calls.
        self.indexing = self.indexing_from_args(indices)
        result = self.root_block()
        self.indexing = None
        return result

    def bind_set_indirect_shim(self, var, size, check, wrap_neg):
        """Build the call_module shim that wires an indirect index into var."""
        def set_indirect(new_var):
            self.replace_indirect(
                var, V.ops.indirect_indexing(new_var, size, check, wrap_neg)
            )

        # .clone() lets _init_with_copy rebind the shim to a new LoopBody.
        set_indirect.clone = functools.partial(  # type: ignore[attr-defined]
            LoopBody.bind_set_indirect_shim,
            var=var,
            size=size,
            check=check,
            wrap_neg=wrap_neg,
        )
        return set_indirect

    def bind_scan_shim(self, combine_fn):
        """Build the call_module shim that carries combine_fn into ops.scan."""
        def shim(dtypes, values):
            return V.ops.scan(dtypes, combine_fn, values)

        shim.clone = functools.partial(LoopBody.bind_scan_shim, combine_fn=combine_fn)  # type: ignore[attr-defined]
        return shim

    def bind_masked_shim(self, name):
        """Build the call_module shim that runs subblock `name` under a mask."""
        def shim(mask, other):
            return V.ops.masked(mask, self.subblocks[name], other)

        shim.clone = functools.partial(LoopBody.bind_masked_shim, name=name)  # type: ignore[attr-defined]
        return shim
405
406
class LoopBodyBlock:
    """
    Captures the body of a Loops subclass into an FX graph.
    In normal cases there will be a 1:1 mapping between LoopBody and
    LoopBodyBlock, however in the case of ops.masked() the masked out
    operations will manifest as an extra LoopBodyBlock.
    """

    def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]):
        self.body = body

        def add_index(expr: sympy.Expr, mtype: MemoryUsageType, **kwargs):
            # Intern the expr on the parent body and emit a call_module to
            # "get_index" so the traced graph resolves it by name at runtime.
            return tracer.create_proxy(
                "call_module",
                "get_index",
                (body.add_index_expr(expr, mtype, **kwargs),),
                {},
            )

        class CaptureIndexing(V.WrapperHandler):  # type: ignore[name-defined]
            # NOTE: this class body executes inside __init__, so `self` here
            # is the enclosing LoopBodyBlock instance (sets its .name), not
            # a CaptureIndexing attribute.
            self.name = "CaptureIndexing"

            def load(self, name: str, index: sympy.Expr):
                index = add_index(index, MemoryUsageType.LOAD, buffer_name=name)
                return self._inner.load(name, index)

            def load_seed(self, name: str, index: int):
                assert isinstance(index, int)
                # Seeds are constant indices; recorded but not proxied.
                body.add_index_expr(
                    sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name
                )
                return self._inner.load_seed(name, index)

            def store(self, name, index, value, mode=None):
                index = add_index(
                    index, MemoryUsageType.STORE, buffer_name=name, mode=mode
                )
                return self._inner.store(name, index, value, mode)

            def store_reduction(self, name, index, value):
                index = add_index(
                    index, MemoryUsageType.STORE_REDUCTION, buffer_name=name
                )
                return self._inner.store_reduction(name, index, value)

            def reduction(self, dtype, src_dtype, reduction_type, value):
                result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
                # Welford reductions produce three outputs; expand the proxy
                # into a real tuple for downstream consumers.
                if "welford" in reduction_type:
                    return tuple(result[i] for i in range(3))
                return result

            def index_expr(self, index, dtype):
                # A constant index degenerates to an ops.constant.
                if isinstance(index, (int, sympy.Integer)):
                    return self._inner.constant(int(index), dtype)
                index = add_index(index, MemoryUsageType.INDEX_EXPR)
                return self._inner.index_expr(index, dtype)

            def check_bounds(self, index, size, lower, upper):
                index = add_index(index, MemoryUsageType.CHECK_BOUNDS)
                size = add_index(size, MemoryUsageType.CHECK_BOUNDS)
                return self._inner.check_bounds(index, size, lower, upper)

            def bucketize(
                self,
                values,
                offsets_name: str,
                offsets_size: sympy.Expr,
                indexing_dtype: torch.dtype,
                right: bool,
            ):
                offsets_size = add_index(
                    offsets_size, MemoryUsageType.BUCKETIZE, buffer_name=offsets_name
                )
                return self._inner.bucketize(
                    values, offsets_name, offsets_size, indexing_dtype, right
                )

            @staticmethod
            def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy):
                """
                Recursively capture the masked out body in another LoopBodyBlock
                """
                # `self` is the enclosing LoopBodyBlock (closure), not a
                # CaptureIndexing instance — hence @staticmethod.
                name = self.body.add_submodule(None, "masked_subblock")
                self.body.submodules[name] = self.body.bind_masked_shim(name)
                self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, [])
                return tracer.create_proxy(
                    "call_module", name, (mask_proxy, other_proxy), {}
                )

            @staticmethod
            def scan(
                dtype_proxy,
                combine_fn: Callable[
                    [Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...]
                ],
                value_proxy,
            ):
                # combine_fn can't be traced as a node; carry it in a shim
                # submodule invoked via call_module.
                shim = self.body.bind_scan_shim(combine_fn)
                name = self.body.add_submodule(shim, "scan")
                result = tracer.create_proxy(
                    "call_module",
                    name,
                    (dtype_proxy, value_proxy),
                    {},
                )
                # Proxies are iterable, but some methods expect tuples/lists
                return tuple(result[i] for i in range(len(value_proxy)))

            def sort(self, dtypes, values, stable, descending):
                result = self._inner.sort(dtypes, values, stable, descending)
                # Proxies are iterable, but some methods expect tuples/lists
                return tuple(result[i] for i in range(len(values)))

            def frexp(self, value_proxy):
                result = self._inner.frexp(value_proxy)
                # Proxies are iterable, but some methods expect tuples/lists
                return (result[0], result[1])

            @staticmethod
            def indirect_indexing(index_proxy, size, check=True, wrap_neg=True):
                """
                Flow data from tensors into indexing formulas.
                Introduce a call_module to update the indexing.
                """

                var = self.body.add_indirect(size)
                set_indirect = self.body.bind_set_indirect_shim(
                    var, size, check, wrap_neg
                )
                tracer.create_proxy(
                    "call_module",
                    self.body.add_submodule(set_indirect, f"set_{var}"),
                    (index_proxy,),
                    {},
                )
                return var

            @staticmethod
            def output(result):
                tracer.create_proxy("output", "output", (result,), {})

        # Trace into a bare Graph (no GraphModule) with "ops" as the sole
        # placeholder input.
        tracer = torch.fx.Tracer()
        tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
        proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})

        from .index_propagation import IndexPropagation
        from .sizevars import SimplifyIndexing

        # Layer the tracing handler: simplify indexing under var_ranges,
        # optionally with constant/index propagation on top.
        handler: Any = SimplifyIndexing(
            CaptureIndexing(proxy_ops), self.body.var_ranges
        )
        if config.constant_and_index_propagation:
            handler = IndexPropagation(
                handler, self.body.var_ranges, self.body.indirect_var_ranges
            )

        with V.set_ops_handler(handler):
            # This indirection is just a cute way to get IndexPropagation to
            # unwrap the return value.
            ops.output(fn(*args))
        self.graph = tracer.graph

    def __call__(self):
        # Replay the captured graph against the currently-installed ops
        # handler, using the parent body's submodules for call_module nodes.
        graph = self.graph
        submodules = self.body.submodules

        return InterpreterShim(graph, submodules).run(V.get_ops_handler())

    def debug_str(self, name="block"):
        """Return generated-code text for this block, titled with `name`."""
        code = torch.fx.GraphModule(self.body.submodules, self.graph).code
        return re.sub(
            # strip `; del var0` suffixes to make output prettier
            r";[^\n]*",
            "",
            code.strip().replace("def forward(", f"def {name}("),
        )

    def contains_only_ops(self, allowed_ops) -> bool:
        """True if every call_method node's target is in allowed_ops."""
        return all(
            node.target in allowed_ops
            for node in self.graph.find_nodes(op="call_method")
        )

    def clone(self, body: LoopBody):
        """Shallow copy with a new parent LoopBody"""
        copy = LoopBodyBlock.__new__(LoopBodyBlock)
        copy.__dict__.update({**self.__dict__, "body": body})
        return copy
595