# mypy: allow-untyped-defs
import functools

import torch
from torch._inductor.compile_fx import fake_tensor_prop

from ..._dynamo.utils import counters
from .. import config
from ..pattern_matcher import (
    _return_true,
    CallFunction,
    fwd_only,
    Ignored,
    init_once_fakemode,
    KeywordArg,
    Match,
    PatternMatcherPass,
    register_graph_pattern,
    register_replacement,
    stable_topological_sort,
)


aten = torch.ops.aten

# pass_patterns[0] is applied first, then [1], then [2]
pass_patterns = [
    PatternMatcherPass(),
    PatternMatcherPass(),
    PatternMatcherPass(),
]

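# Separate pass holding the binary-folding patterns; applied repeatedly inside freezing_passes().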
binary_folding_pass = PatternMatcherPass()


def freezing_passes(gm: torch.fx.GraphModule, aot_example_inputs):
    """
    Passes that are applied to the graph during freezing.
    """

    from ..freezing import constant_fold

    lazy_init()
    # We need a few rounds of binary folding to get rid of all the unnecessary
    # nodes, but we may need a better way to choose the number of rounds.
    # The folded chains look like: conv + binary + binary.
    binary_folding = counters["inductor"]["binary_folding"]
    fake_tensor_prop(gm, aot_example_inputs, True)

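    # Convs whose constant inputs have mixed dtypes are marked as allowed for
    # folding here; their original precision is recovered after the loop below.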
    torch._inductor.fx_passes.binary_folding.mark_mixed_dtype_allowed_convs(gm)
    for _ in range(4):
        constant_fold(gm)
        # Make sure meta['val'] is properly set for all nodes
        fake_tensor_prop(gm, aot_example_inputs, True)
        binary_folding_pass.apply(gm.graph)  # type: ignore[arg-type]
        # If no binary folding happened in this round, we don't need to run the pass again.
        # TODO: remove the need to run fake_tensor_prop on the whole model.
        if counters["inductor"]["binary_folding"] == binary_folding:
            break
        binary_folding = counters["inductor"]["binary_folding"]

    torch._inductor.fx_passes.binary_folding.recover_original_precision_folded_convs(
        gm
    )

    constant_fold(gm)
    fake_tensor_prop(gm, aot_example_inputs, True)

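    # Apply the staged freezing patterns: pass_patterns[0] first, then [1], then [2].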
    for pattern in pass_patterns:
        pattern.apply(gm.graph)  # type: ignore[arg-type]

    # The CPU weight packing always assumes the conv's weight is channels last,
    # so make sure layout_optimization is on when doing it.
    if (
        torch._C._has_mkldnn
        and config.cpp.weight_prepack
        and config.layout_optimization
    ):
        from .mkldnn_fusion import _eliminate_duplicate_packed_nodes

        _eliminate_duplicate_packed_nodes(gm)

    stable_topological_sort(gm.graph)
    gm.recompile()
    gm.graph.lint()


@init_once_fakemode
def lazy_init():
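    """Register the pattern-matcher passes used for freezing (runs only once, via init_once_fakemode)."""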
    if torch._C._has_mkldnn and config.cpp.weight_prepack:
        from .mkldnn_fusion import _mkldnn_weight_pack_init

        _mkldnn_weight_pack_init()

    from .binary_folding import binary_folding_init

    addmm_patterns_init()
    binary_folding_init()


def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_number=0):
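    """Register a graph pattern to be applied by pass_patterns[pass_number] during freezing."""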
    return register_graph_pattern(
        pattern,
        extra_check=extra_check,
        pass_dict=pass_patterns[pass_number],
    )


def register_binary_folding_pattern(pattern, extra_check=_return_true):
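    """Register a graph pattern to be applied by the dedicated binary_folding_pass."""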
    return register_graph_pattern(
        pattern,
        extra_check=extra_check,
        pass_dict=binary_folding_pass,
    )


@functools.lru_cache(None)
def addmm_patterns_init():
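    """
    Register replacements that fuse parallel matmul / addmm ops sharing the same
    input by concatenating their weights (and biases) and chunking the result.
    """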
    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"
    val = functools.partial(torch.empty, (10, 10), device=device, requires_grad=False)

    def check_concat_weights(match):
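        # Only fuse when every weight (and bias, if present) is a get_attr constant
        # and all tensors within each group have the same shape.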
        weight_inputs = ["w1", "w2"]
        if "w3" in match.kwargs:
            weight_inputs.append("w3")

        equal_shape_inputs = [weight_inputs]

        if "b1" in match.kwargs:
            bias_inputs = ["b1", "b2"]
            if "b3" in match.kwargs:
                bias_inputs.append("b3")

            equal_shape_inputs.append(bias_inputs)

        for equal_shape_group in equal_shape_inputs:
            inps = [match.kwargs[name] for name in equal_shape_group]

            if not all(
                inp.op == "get_attr"
                and inp.meta["val"].shape == inps[0].meta["val"].shape
                for inp in inps
            ):
                return False

        return True

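    # Fuse matmuls that share the same input into a single matmul over the
    # concatenated weights, then chunk the result back apart.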
    def matmul_fuse_pattern(inp, w1, w2, w3):
        return (inp @ w1, inp @ w2, inp @ w3)

    def matmul_replacement(inp, w1, w2, w3):
        cat_t = torch.cat((w1, w2, w3), dim=1)
        mm = inp @ cat_t
        return mm.chunk(3, dim=1)

    register_replacement(
        matmul_fuse_pattern,
        matmul_replacement,
        [val(), val(), val(), val()],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2", "w3"),
    )

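    # The same fusion for the two-matmul case.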
    def matmul_fuse_pattern_two(inp, w1, w2):
        return (inp @ w1, inp @ w2)

    def matmul_replacement_two(inp, w1, w2):
        cat_t = torch.cat((w1, w2), dim=1)
        mm = inp @ cat_t
        return mm.chunk(2, dim=1)

    register_replacement(
        matmul_fuse_pattern_two,
        matmul_replacement_two,
        [val(), val(), val()],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2"),
    )

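    # The addmm variant: weights are concatenated along dim 1, biases along dim 0.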
    def addmm_fuse_pattern_second(inp, w1, w2, w3, b1, b2, b3):
        return (
            aten.addmm(b1, inp, w1),
            aten.addmm(b2, inp, w2),
            aten.addmm(b3, inp, w3),
        )

    def addmm_fuse_replacement_second(inp, w1, w2, w3, b1, b2, b3):
        cat_w = torch.cat((w1, w2, w3), dim=1)
        cat_b = torch.cat((b1, b2, b3))
        return aten.addmm(cat_b, inp, cat_w).chunk(3, dim=1)

    register_replacement(
        addmm_fuse_pattern_second,
        addmm_fuse_replacement_second,
        [val() for _ in range(7)],
        fwd_only,
        pass_patterns[0],
        extra_check=check_concat_weights,
        exclusive_arg_names=("w1", "w2", "w3", "b1", "b2", "b3"),
    )


def same_dtype(match):
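    """Return True when the conversion target dtype matches the input's existing dtype."""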
    return match.output_node().args[0].meta["val"].dtype == match.kwargs["dtype"]


@register_graph_pattern(
    CallFunction(
        torch.ops.prims.convert_element_type.default,
        Ignored(),
        KeywordArg("dtype"),
    ),
    pass_dict=pass_patterns[0],
    extra_check=same_dtype,
)
def unnecessary_dtype_convert(match: Match, **kwargs):
    """Remove an unnecessary dtype conversion op, probably left over from Conv-BN folding"""
    graph = match.graph
    node = match.output_node()
    node.replace_all_uses_with(node.args[0])  # type: ignore[arg-type]
    graph.erase_node(node)
228