# mypy: allow-untyped-defs
import itertools
import logging
import typing
from collections import Counter
from typing import Any, Dict, List, Set, Union

import torch
import torch._guards
import torch.utils._pytree as pytree
from torch._inductor.constant_folding import ConstantFolder
from torch._inductor.fx_passes.dedupe_symint_uses import _SymHashingDict
from torch.fx.experimental.symbolic_shapes import statically_known_true
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.multiprocessing.reductions import StorageWeakRef

from ...utils._ordered_set import OrderedSet
from .. import config
from ..pattern_matcher import (
    CallFunction,
    init_once_fakemode,
    KeywordArg,
    Match,
    MULTIPLE,
    PatternMatcherPass,
    register_graph_pattern,
    stable_topological_sort,
)
from .replace_random import replace_random_passes


log = logging.getLogger(__name__)
patterns = PatternMatcherPass()
aten = torch.ops.aten
prims = torch.ops.prims

pass_patterns = [
    patterns,
    PatternMatcherPass(),
]


@init_once_fakemode
def lazy_init():
    from .fuse_attention import _sfdp_init
    from .misc_patterns import _misc_patterns_init
    from .pad_mm import _pad_mm_init

    _pad_mm_init()
    _sfdp_init()
    _misc_patterns_init()


def remove_no_ops(
    gm: torch.fx.GraphModule, zeros: Set[torch.fx.Node], ones: Set[torch.fx.Node]
):
    with torch.utils._python_dispatch._disable_current_modes():
        "Removes no-ops: (+ 0, - 0, * 1, / 1)"
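        # Illustrative (not executed): a node such as
        #   y = aten.add.Tensor(x, full_of_zeros)
        # has its uses redirected to x; if only the dtype differs, a
        # prims.convert_element_type is inserted instead of bailing out.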
        graph = gm.graph

        def fake_tensors_eq(t1, t2, fields=("shape", "dtype", "device")):
            if any(not isinstance(t, torch.Tensor) for t in (t1, t2)):
                return False
            for field in fields:
                if getattr(t1, field) != getattr(t2, field):
                    return False
            return True

        def replace_no_op(node, replace_input_index):
            replacement = node.args[replace_input_index]

            # https://github.com/pytorch/pytorch/issues/86128 causes
            # non-Tensor inputs even for ops with only Tensor inputs.
            # TODO - decompose/type promote to avoid this
            if not all(isinstance(arg, torch.fx.Node) for arg in node.args):
                return

            if not fake_tensors_eq(node.meta["val"], replacement.meta["val"]):
                if fake_tensors_eq(
                    node.meta["val"],
                    replacement.meta["val"],
                    ("shape", "device"),
                ):
                    with graph.inserting_after(node):
                        replacement = graph.call_function(
                            torch.ops.prims.convert_element_type.default,
                            args=(replacement, node.meta["val"].dtype),
                        )
                else:
                    return

            node.replace_all_uses_with(replacement)
            replacement.meta.update(node.meta)
            graph.erase_node(node)

        for node in graph.find_nodes(op="call_function", target=aten.add.Tensor):
            # TODO handle Tensor-Scalar adds, it's a different schema
            if len(node.args) == 2:
                if (
                    not any(e in zeros for e in node.args)
                    or node.kwargs.get("alpha", 1) != 1
                ):
                    continue

                replace_index = 1 if node.args[0] in zeros else 0
                replace_no_op(node, replace_index)

        for node in graph.find_nodes(op="call_function", target=aten.sub.Tensor):
            if len(node.args) == 2:
                if node.args[1] not in zeros or node.kwargs.get("alpha", 1) != 1:
                    continue

                replace_no_op(node, 0)

        for node in graph.find_nodes(op="call_function", target=aten.mul.Tensor):
            if len(node.args) == 2:
                if not any(e in ones for e in node.args):
                    continue

                replace_input_index = 1 if node.args[0] in ones else 0
                replace_no_op(node, replace_input_index)

        for node in graph.find_nodes(op="call_function", target=aten.div.Tensor):
            if len(node.args) == 2 and node.args[1] in ones:
                replace_no_op(node, 0)

        # meta tensors returned from the graph have no data and can be replaced with empty_strided
        for output_node in graph.find_nodes(op="output"):
            had_meta_return = False

            def visit(n):
                nonlocal had_meta_return
                val = n.meta.get("val")
                if isinstance(val, torch.Tensor) and val.device.type == "meta":
                    with graph.inserting_before(output_node):
                        n.replace_all_uses_with(
                            graph.call_function(
                                torch.ops.aten.empty_strided.default,
                                args=(val.size(), val.stride()),
                                kwargs={"dtype": val.dtype, "device": val.device},
                            )
                        )
                    had_meta_return = True

            torch.fx.map_arg(output_node.args, visit)
            if had_meta_return:
                graph.eliminate_dead_code()


def remove_redundant_views(gm: torch.fx.GraphModule):
    """
    Removes redundant views by reusing existing ones.
    """
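    # Illustrative (not executed): given
    #   a = aten.view.dtype(x, torch.bfloat16)
    #   b = aten.view.dtype(x, torch.bfloat16)
    # the second view is replaced by the first, and a later
    # aten.view.dtype(a, x.dtype) is replaced by x itself.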
    with torch.utils._python_dispatch._disable_current_modes():
        # A dictionary mapping a tensor to all aliased views.
        views: Dict[torch.fx.Node, Dict[torch.dtype, torch.fx.Node]] = {}
        graph = gm.graph

        for node in graph.find_nodes(
            op="call_function", target=torch.ops.aten.view.dtype
        ):
            src = node.args[0]
            to_type = node.args[1]
            existing_views = views.get(src)
            is_needed = True

            if existing_views:
                # Replace the view with an existing view if available.
                alias = existing_views.get(to_type)
                if alias:
                    is_needed = False
                    node.replace_all_uses_with(alias)
                    alias.meta.update(node.meta)
                    graph.erase_node(node)
            else:
                from_type = src.meta["val"].dtype
                existing_views = {from_type: src}
                views[src] = existing_views

            if is_needed:
                # Save the new alias but do not replace an existing one.
                existing_views.setdefault(to_type, node)
                views[node] = existing_views

        # Clean up unused views.
        while True:
            unused_views = [alias for alias in views if not alias.users]
            if len(unused_views) == 0:
                break
            for unused in unused_views:
                views.pop(unused)
                graph.erase_node(unused)


class UniformValueConstantFolder(ConstantFolder):
    """
    Runs constant folding and replaces tensors that have a uniform value
    with a tensor constructor call: aten.full([shape], value, ...)
    """
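    # Illustrative (not executed): a folded intermediate that is uniformly 2.0
    # with shape [s0, 8] is later re-materialized in the graph as
    #   aten.full.default([<node for s0>, 8], 2.0, dtype=..., device=...)
    # (see constant_fold_uniform_value below).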

    def __init__(self, gm, skip_constructors=False) -> None:
        super().__init__(gm, skip_constructors)
        self.node_storages_ptrs: Dict[torch.fx.Node, int] = {}
        self.constant_data_ptrs: Dict[torch.fx.Node, StorageWeakRef] = {}
        # we may constant fold a tensor which in the graph has a sym size
        # see: [constant folding refining of symints]
        self.node_replacements_shapes: Dict[torch.fx.Node, List[int]] = {}

        # initialize symint -> node mapping so that we can
        # use symint nodes in full constructors
        self.symint_nodes = _SymHashingDict()
        for n in self.module.graph.nodes:
            if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt):
                self.symint_nodes[n.meta["val"]] = n

        # reference from torch/_functorch/partitioners.py:get_default_op_list
        self.view_op_packets = [
            aten.squeeze,
            aten.unsqueeze,
            aten.alias,
            aten.view,
            aten.slice,
            aten.t,
            prims.broadcast_in_dim,
            aten.expand,
            aten.as_strided,
            aten.permute,
        ]

        self.indexing_op_packets = {
            aten.slice,
        }

    def _support_dynamic_shape(self):
        return True

    def insertable_tensor_check(self, t: torch.Tensor) -> bool:
        return True

    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        self.node_replacements[node] = tensor.flatten()[0].item()
        self.node_replacements_shapes[node] = node.meta["val"].shape
        self.constant_data_ptrs[node] = StorageWeakRef(tensor.untyped_storage())

    def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
        for n in self.module.graph.find_nodes(op="placeholder"):
            if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt):
                env[n] = n.meta["val"]
            else:
                env[n] = self.unknown_value

    def _deduce_value(self, node: torch.fx.Node):
        # deduce value for full-like nodes
        # 1. for constructors, the substitute value is a tensor of size [1]
        # 2. for view ops/indexing, the substitute value is the same as the input
        # 3. for pointwise ops, run the node to get the substitute value
        # 4. deal with some special ops
        # otherwise, stop deducing and return the unknown value

        # TODO: cat, more indexing
        # TODO - do on cpu to avoid syncs
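        # Illustrative (not executed): for full([s0], 3.0) we record a size-[1]
        # substitute tensor full([1], 3.0); a view of it reuses the same
        # substitute, and a pointwise mul by 2 runs on the size-[1] tensor,
        # yielding a uniform value of 6.0.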

        # single-elem attrs
        if node.op == "get_attr" or (
            node.op == "call_function"
            and node.target == torch.ops.aten.lift_fresh_copy.default
        ):
            out = super(ConstantFolder, self).run_node(node)
            if isinstance(out, torch.Tensor) and out.numel() == 1:
                return out

        # handle device_put op
        if node.target == prims.device_put.default:
            return super(ConstantFolder, self).run_node(node)

        # constructor ops
        if (
            node.op == "call_function"
            and node.target == aten.full.default
            and len(node.args) == 2
        ):
            args, kwargs = self.fetch_args_kwargs_from_env(node)
            new_args = [[1], args[1]]
            return aten.full.default(*new_args, **node.kwargs)

        # handle aten.view.dtype before the view ops below because it changes the value
        if node.target == aten.view.dtype:
            return super(ConstantFolder, self).run_node(node)

        # view ops: return the input tensor (the first argument)
        if hasattr(node.target, "overloadpacket") and (
            node.target.overloadpacket in self.view_op_packets
            or node.target.overloadpacket in self.indexing_op_packets
        ):
            assert isinstance(node.args[0], torch.fx.Node)
            return self.env[node.args[0]]

        # we don't want to return unknown value for symints so that we can
        # still constant fold through their use in constructors or views;
        # if we see them in a pointwise node (e.g., tensor * symint)
        # we will bail
        if "val" in node.meta and isinstance(node.meta["val"], torch.SymInt):
            return node.meta["val"]

        # pointwise ops
        if isinstance(node.target, torch._ops.OpOverload) and (
            torch.Tag.pointwise in node.target.tags
            or node.target is torch.ops.aten.scalar_tensor.default
        ):
            args, kwargs = self.fetch_args_kwargs_from_env(node)
            flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)

            if any(isinstance(inp, torch.SymInt) for inp in flattened_inputs):
                return self.unknown_value

            # we run the ops on size-[1] tensors, so remove memory_format to avoid errors
            kwargs = dict(kwargs)
            kwargs.pop("memory_format", None)

            return node.target(*args, **kwargs)

        return self.unknown_value


def constant_fold_uniform_value(gm: torch.fx.GraphModule):
    with torch.utils._python_dispatch._disable_current_modes():
        "Runs constant folding and replaces constants which can be constructed with a single `full` call. Calls into remove_no_ops."
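        # Sketch of the flow: run UniformValueConstantFolder, turn each foldable
        # uniform-value node into an aten.full call, collect the resulting
        # zero/one constants, then let remove_no_ops strip +0 / -0 / *1 / /1
        # and remove_redundant_views clean up dtype views.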
        aten = torch.ops.aten

        # Constant folding can leak memory, especially with repeated compilation, so we are only going to
        # remove constants which can be replaced with a constructor.
        cf = UniformValueConstantFolder(gm)
        cf.run()

        node_replacements = cf.node_replacements

        # note: [constant folding refining of symints]
        # constant folding will partially evaluate a graph such that values whose dependencies
        # are entirely known at compile time may also become compile-time constants. in some cases,
        # this includes symints that we had not previously deduced to be constant, whose
        # constant value is only discovered during constant folding. an example is:
        # unbacked_symint_eq_11 = torch.full((), 11).item()
        # torch.full((unbacked_symint_eq_11,), 0)
        node_replacements_shapes = cf.node_replacements_shapes

        graph = gm.graph

        zeros = set()
        ones = set()

        # Got failures in `test_is_set_to_cuda` if we change aliasing on constants,
        # so just constant-ify if a Tensor is unaliased
        constant_data_ptr_count: typing.Counter[StorageWeakRef] = Counter()

        for node in cf.node_replacements:
            constant_data_ptr_count[cf.constant_data_ptrs[node]] += 1

        for node, value in node_replacements.items():
            # we don't currently have a functional way of instantiating a non-contiguous
            # tensor with full/zeros/ones; this hasn't shown up as important yet
            if "val" not in node.meta:
                # This can only happen in AOTI
                continue

            fake_tensor = node.meta["val"]
            if not fake_tensor.is_contiguous(memory_format=torch.contiguous_format):
                continue

            # TODO - not sure about lossy uint->python value->uint conversions
            if fake_tensor.dtype in (
                torch.uint8,
                torch.uint16,
                torch.uint32,
                torch.uint64,
            ):
                continue

            if constant_data_ptr_count[cf.constant_data_ptrs[node]] > 1:
                continue

            with graph.inserting_after(node):
                # the conversion from tensor to value and back can be lossy; just use the original full ctor value
                if (
                    node.op == "call_function"
                    and node.target == aten.full.default
                    and len(node.args) == 2
                ):
                    value = node.args[1]

                # refines symints, see [constant folding refining of symints] above
                for runtime_size, compile_time_size in zip(
                    node_replacements_shapes[node], fake_tensor.shape
                ):
                    torch._check(runtime_size == compile_time_size)

                # replace SymInt with its Node before creating a new full node
                # e.g. (1, s0) -> (1, arg0_1)
                node_shape = node_replacements_shapes[node]
                if not all(
                    not isinstance(s, torch.SymInt) or s in cf.symint_nodes
                    for s in node_shape
                ):
                    continue

                shapes = [
                    cf.symint_nodes[s] if isinstance(s, torch.SymInt) else s
                    for s in node_replacements_shapes[node]
                ]

                # zeros and ones just get traced into full, so we insert those
                new_node = graph.call_function(
                    aten.full.default,
                    args=(shapes, value),
                    kwargs={
                        "dtype": fake_tensor.dtype,
                        "layout": torch.strided,
                        "device": fake_tensor.device,
                        "pin_memory": False,
                    },
                )

                new_node.meta.update(node.meta)
                node.replace_all_uses_with(new_node)
                graph.erase_node(node)

                if value == 0:
                    zeros.add(new_node)
                elif value == 1:
                    ones.add(new_node)

        remove_no_ops(gm, zeros, ones)
        remove_redundant_views(gm)


def joint_graph_passes(graph: torch.fx.GraphModule):
    """
    Run FX transformations on the joint forwards+backwards graph.
    """
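    # Pass order: optional joint_custom_pre_pass, remove_noop_ops, optional
    # uniform-value constant folding, pattern-matcher passes, random-op
    # replacement (unless fallback_random), optional joint_custom_post_pass;
    # the graph is re-sorted, linted and recompiled only if something changed.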
    lazy_init()
    count = 0
    if config.joint_custom_pre_pass is not None:
        with GraphTransformObserver(
            graph, "joint_custom_pre_pass", config.trace.log_url_for_graph_xform
        ):
            config.joint_custom_pre_pass(graph.graph)
            count += 1

    from .post_grad import remove_noop_ops

    remove_noop_ops(graph.graph)

    if config.joint_graph_constant_folding:
        with GraphTransformObserver(
            graph, "constant_fold_uniform_value", config.trace.log_url_for_graph_xform
        ):
            constant_fold_uniform_value(graph)

    if config.pattern_matcher:
        for patterns in pass_patterns:
            count += patterns.apply(graph.graph)  # type: ignore[arg-type]

    if not config.fallback_random:
        count += replace_random_passes(graph)

    if config.joint_custom_post_pass is not None:
        with GraphTransformObserver(
            graph, "joint_custom_post_pass", config.trace.log_url_for_graph_xform
        ):
            config.joint_custom_post_pass(graph.graph)
            count += 1

    if count:
        stable_topological_sort(graph.graph)
        graph.graph.lint()
        graph.recompile()
    return graph


@register_graph_pattern(
    CallFunction(
        torch.ops.prims.iota.default,
        KeywordArg("length"),
        start=KeywordArg("start"),
        step=KeywordArg("step"),
        dtype=KeywordArg("dtype"),
        device=KeywordArg("device"),
        requires_grad=KeywordArg("requires_grad"),
    ),
    pass_dict=patterns,
)
def fix_iota_device(match: Match, length, start, step, dtype, device, requires_grad):
    """
    Eager supports:

        aten.index(cuda_tensor, torch.arange(..., device="cpu"))

    But this results in an implicit host-device-copy and breaks cudagraphs.
    Rewrite the arange to use CUDA.
    """
    (node,) = match.nodes
    user_devices: OrderedSet[torch.device] = OrderedSet()
    for user in node.users:
        if (
            user.op == "call_function"
            and user.target in (aten.index.Tensor, aten.index_put.default)
            and hasattr(user.meta.get("val"), "device")
        ):
            user_devices.add(user.meta["val"].device)  # type: ignore[union-attr]
        else:
            return  # bail out

    if len(user_devices) == 1 and "val" in node.meta:
        (user_device,) = user_devices
        if device.type != user_device.type:
            repl = match.graph.call_function(
                torch.ops.prims.iota.default,
                (length,),
                {
                    "start": start,
                    "step": step,
                    "dtype": dtype,
                    "device": user_device,
                    "requires_grad": requires_grad,
                },
            )
            repl.meta.update(node.meta)
            repl.meta["val"] = repl.meta["val"].to(user_device)
            node.replace_all_uses_with(repl)
            match.erase_nodes()


@register_graph_pattern(
    CallFunction(
        torch.ops.prims.convert_element_type.default,
        CallFunction(
            torch.ops.prims.convert_element_type.default,
            KeywordArg("arg"),
            KeywordArg("dtype1"),
        ),
        KeywordArg("dtype2"),
    ),
    pass_dict=patterns,
)
def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtype):
    """Remove chain of dtype conversions often created by AMP"""
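    # Illustrative (not executed):
    #   convert_element_type(convert_element_type(x, torch.float16), torch.float32)
    # becomes convert_element_type(x, torch.float32), applied only when both
    # dtypes are in the floating-point allowlist below.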
    graph = match.graph
    node = match.output_node()
    allowed = {torch.float16, torch.bfloat16, torch.float32, torch.float64}
    if dtype1 in allowed and dtype2 in allowed:
        repl = graph.call_function(
            torch.ops.prims.convert_element_type.default, (arg, dtype2)
        )
        repl.meta.update(node.meta)
        node.replace_all_uses_with(repl)
        match.erase_nodes()


@register_graph_pattern(
    CallFunction(torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")),
    pass_dict=patterns,
)
def pointless_view(match: Match, arg, size):
    """Remove no-op view"""
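    # Illustrative (not executed): aten.view.default(x, [2, 3]) where x already
    # has shape [2, 3] is replaced by x.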
    node = match.output_node()
    arg_size = list(node.args[0].meta["val"].shape)  # type: ignore[union-attr]
    if size == arg_size:
        node.replace_all_uses_with(node.args[0])  # type: ignore[arg-type]
        match.erase_nodes()


# When softmax is used with temperature or other scaling, we get the pattern
#
#   scale(x) - scale(x).amax(dim, keepdim=True)
#
# which is expected to be at most zero, but we may end up with numerical
# discrepancies between the recomputed values of scale(x) inside and out
# of the reduction, depending on compiler optimizations, e.g. use of fma
# instructions.
#
# Here we replace it with the mathematically equivalent,
#
#   scale(x - x.amax(dim, keepdim=True))
#
# which is more stable as we only compute the scaling once.
#
# NOTE: This pattern must come after fused attention matching!
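#
# If the scale can be negative, the rewrites below first multiply by sign(other)
# so that amax still selects the same element; see mul_softmax_pattern and
# div_softmax_pattern.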


def _partial_softmax_pattern(linear_func, reverse=False, to_dtype=False):
    # Allow matching inp * other and other * inp
    if reverse:
        scaled = CallFunction(
            linear_func, KeywordArg("other"), KeywordArg("inp"), _users=MULTIPLE
        )
    else:
        scaled = CallFunction(
            linear_func, KeywordArg("inp"), KeywordArg("other"), _users=MULTIPLE
        )
    if to_dtype:
        scaled = CallFunction(
            prims.convert_element_type, scaled, KeywordArg("dtype"), _users=MULTIPLE
        )
    amax = CallFunction(
        aten.amax.default, scaled, KeywordArg("dim"), KeywordArg("keepdim")
    )
    return CallFunction(aten.sub.Tensor, scaled, amax)


def _other_is_broadcasted_in_dim(match):
    # Check that the scaling factor is constant across the reduction dim,
    # so scaling doesn't change which index corresponds to the maximum value
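    # Illustrative: for inp of shape [B, H, S, S], other of shape [B, H, 1, 1]
    # and dim=-1, the scale has size 1 along the reduction dim, so the rewrite
    # is safe.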
    other = match.kwargs["other"]
    if isinstance(other, (int, float)):
        return True

    inp = match.kwargs["inp"]
    if not all(isinstance(x, torch.fx.Node) for x in (inp, other)):
        return False

    inp_example = inp.meta["val"]
    other_example = other.meta["val"]
    if isinstance(other_example, (torch.SymInt, torch.SymFloat)):
        return True

    if not all(isinstance(x, torch.Tensor) for x in (inp_example, other_example)):
        return False

    inp_ndim = inp_example.ndim
    other_shape = other_example.shape
    if inp_ndim < len(other_shape):
        return False

    # Pad other_shape to the same ndim as inp
    other_shape = [1] * (inp_ndim - len(other_shape)) + list(other_shape)

    dim = match.kwargs["dim"]
    if isinstance(dim, int):
        dim = (dim,)

    return all(statically_known_true(other_shape[d] == 1) for d in dim)


def mul_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None):
    def repl(inp, other):
        if dtype is not None:
            inp = inp.to(dtype)

        sign: Union[int, float, torch.Tensor]
        if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)):
            sign = 1 if other >= 0 else -1
        else:
            one = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device)
            sign = torch.where(other >= 0, one, -one)

        inp = inp * sign
        max_ = torch.amax(inp, dim=dim, keepdim=keepdim)
        return (inp - max_) * (sign * other)

    match.replace_by_example(repl, [inp, other])


for reverse, to_dtype in itertools.product((False, True), repeat=2):
    register_graph_pattern(
        _partial_softmax_pattern(aten.mul.Tensor, reverse=reverse, to_dtype=to_dtype),
        pass_dict=pass_patterns[1],
        extra_check=_other_is_broadcasted_in_dim,
    )(mul_softmax_pattern)


def div_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None):
    def repl(inp, other):
        if dtype is not None:
            inp = inp.to(dtype)

        sign: Union[int, float, torch.Tensor]
        if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)):
            sign = 1 if other >= 0 else -1
        else:
            one = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device)
            sign = torch.where(other >= 0, one, -one)

        inp = inp * sign
        max_ = torch.amax(inp, dim=dim, keepdim=keepdim)
        return (inp - max_) / (sign * other)

    match.replace_by_example(repl, [inp, other])


for to_dtype in (False, True):
    register_graph_pattern(
        _partial_softmax_pattern(aten.div.Tensor, to_dtype=to_dtype),
        pass_dict=pass_patterns[1],
        extra_check=_other_is_broadcasted_in_dim,
    )(div_softmax_pattern)
695