# mypy: allow-untyped-defs
import types

import torch
import torch.nn.functional as F
from torch.ao.quantization.utils import _assert_and_get_unique_device


__all__ = [
    "model_is_exported",
]


class _WrapperModule(torch.nn.Module):
    """Class to wrap a callable in a :class:`torch.nn.Module`. Use this if you
    are trying to export a callable.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        """Simple forward that just calls the ``fn`` provided to :meth:`_WrapperModule.__init__`."""
        return self.fn(*args, **kwargs)

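# A minimal usage sketch (illustrative only; `add_one` is hypothetical):
# `torch.export.export` expects an `nn.Module`, so a bare callable must be
# wrapped before it can be exported.
#
#     def add_one(x):
#         return x + 1
#
#     wrapped = _WrapperModule(add_one)
#     exported = torch.export.export(wrapped, (torch.randn(3),))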

def model_is_exported(m: torch.nn.Module) -> bool:
    """
    Return True if the `torch.nn.Module` was exported, False otherwise
    (e.g. if the model was FX symbolically traced or not traced at all).
    """
    # Exported graphs carry example values ("val") in their node metadata,
    # whereas FX symbolic tracing does not populate node.meta this way.
    return isinstance(m, torch.fx.GraphModule) and any(
        "val" in n.meta for n in m.graph.nodes
    )

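# Illustrative check (hypothetical `MyModel` and `example_args`): a model
# produced by `torch.export.export` satisfies this predicate, while an FX
# symbolically traced one does not.
#
#     exported = torch.export.export(MyModel(), example_args).module()
#     assert model_is_exported(exported)
#     assert not model_is_exported(torch.fx.symbolic_trace(MyModel()))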

def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
    """
    Switch dropout patterns in the model between train and eval modes.

    Dropout has different behavior in train vs eval mode. For exported models,
    however, calling `model.train()` or `model.eval()` does not automatically switch
    the dropout behavior between the two modes, so here we need to rewrite the aten
    dropout patterns manually to achieve the same effect.

    See https://github.com/pytorch/pytorch/issues/103681.
    """
    # Avoid circular dependencies
    from .utils import _get_aten_graph_module_for_pattern

    from torch.fx.subgraph_rewriter import replace_pattern_with_filters

    # Needed to ensure subgraph matches are self-contained
    m.graph.eliminate_dead_code()
    m.recompile()

    # Handle both the out-of-place and the in-place dropout variants.
    for inplace in [False, True]:

        def dropout_train(x):
            return F.dropout(x, p=0.5, training=True, inplace=inplace)

        def dropout_eval(x):
            return F.dropout(x, p=0.5, training=False, inplace=inplace)

        example_inputs = (torch.randn(1),)
        if train_to_eval:
            match_pattern = _get_aten_graph_module_for_pattern(
                _WrapperModule(dropout_train), example_inputs
            )
            replacement_pattern = _get_aten_graph_module_for_pattern(
                _WrapperModule(dropout_eval), example_inputs
            )
        else:
            match_pattern = _get_aten_graph_module_for_pattern(
                _WrapperModule(dropout_eval), example_inputs
            )
            replacement_pattern = _get_aten_graph_module_for_pattern(
                _WrapperModule(dropout_train), example_inputs
            )

        replace_pattern_with_filters(
            m,
            match_pattern,
            replacement_pattern,
            match_filters=[],
            ignore_literals=True,
        )
        m.recompile()

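# Sketch of the effect (aten op shown schematically): in an exported graph,
# train-mode dropout appears as a call along the lines of
#
#     torch.ops.aten.dropout.default(x, 0.5, True)
#
# and `_replace_dropout(m, train_to_eval=True)` rewrites it to the same call
# with `training=False`. Passing `ignore_literals=True` to the rewriter is
# what lets a model's actual dropout probability (e.g. p=0.3) match the
# p=0.5 literal baked into the patterns above.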

def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
    """
    Switch batchnorm patterns in the model between train and eval modes.

    Batchnorm has different behavior in train vs eval mode. For exported models,
    however, calling `model.train()` or `model.eval()` does not automatically switch
    the batchnorm behavior between the two modes, so here we need to rewrite the aten
    batchnorm patterns manually to achieve the same effect.
    """
    # TODO(Leslie): This function does not yet support custom momentum and eps
    # values. Enable this support in future updates.

    # Avoid circular dependencies
    from .utils import _get_aten_graph_module_for_pattern

    from torch.fx.subgraph_rewriter import replace_pattern_with_filters

    # Needed to ensure subgraph matches are self-contained
    m.graph.eliminate_dead_code()
    m.recompile()

    # Note: `F.batch_norm` takes the running stats *before* weight and bias,
    # unlike the parameter order used by these pattern helpers.
    def bn_train(
        x: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
    ):
        return F.batch_norm(
            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True
        )

    def bn_eval(
        x: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
    ):
        return F.batch_norm(
            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=False
        )

    example_inputs = (
        torch.randn(1, 1, 3, 3),  # x
        torch.randn(1),  # bn_weight
        torch.randn(1),  # bn_bias
        torch.randn(1),  # bn_running_mean
        torch.randn(1),  # bn_running_var
    )

    # Trace the patterns on the same device as the model so they match.
    device = _assert_and_get_unique_device(m)
    is_cuda = device is not None and device.type == "cuda"
    bn_train_aten = _get_aten_graph_module_for_pattern(
        _WrapperModule(bn_train),
        example_inputs,
        is_cuda,
    )
    bn_eval_aten = _get_aten_graph_module_for_pattern(
        _WrapperModule(bn_eval),
        example_inputs,
        is_cuda,
    )

    if train_to_eval:
        match_pattern = bn_train_aten
        replacement_pattern = bn_eval_aten
    else:
        match_pattern = bn_eval_aten
        replacement_pattern = bn_train_aten

    replace_pattern_with_filters(
        m,
        match_pattern,
        replacement_pattern,
        match_filters=[],
        ignore_literals=True,
    )
    m.recompile()

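# For reference, the two modes differ as follows (standard batchnorm
# semantics): with `training=True`, `F.batch_norm` normalizes using batch
# statistics and updates `running_mean`/`running_var` in place; with
# `training=False`, it normalizes using the stored running statistics and
# leaves them untouched. Swapping the `training` literal in the matched aten
# pattern is what toggles this behavior.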

# TODO: expose these under this namespace?
def _move_exported_model_to_eval(model: torch.fx.GraphModule):
    """
    Move an exported GraphModule to eval mode.

    This is equivalent to model.eval() but only for certain special ops,
    currently dropout and batchnorm. QAT users should call this before
    performing inference on the model.
    """
    _replace_dropout(model, train_to_eval=True)
    _replace_batchnorm(model, train_to_eval=True)
    return model


def _move_exported_model_to_train(model: torch.fx.GraphModule):
    """
    Move an exported GraphModule to train mode.

    This is equivalent to model.train() but only for certain special ops,
    currently dropout and batchnorm. QAT users should call this before
    performing training on the model.
    """
    _replace_dropout(model, train_to_eval=False)
    _replace_batchnorm(model, train_to_eval=False)
    return model

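# A sketch of the intended QAT flow (hypothetical `model` and `example_args`):
#
#     exported = torch.export.export(model, example_args).module()
#     _move_exported_model_to_train(exported)  # dropout active, BN uses batch stats
#     # ... train ...
#     _move_exported_model_to_eval(exported)   # dropout is identity, BN uses running stats
#     # ... run inference ...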

def _allow_exported_model_train_eval(model: torch.fx.GraphModule):
    """
    Allow users to call `model.train()` and `model.eval()` on an exported model,
    with the effect of switching behavior between the two modes limited to special
    ops only, which are currently dropout and batchnorm.

    Note: This does not achieve the same effect as what `model.train()` and
    `model.eval()` do on eager models, but only provides an approximation. In
    particular, user code branching on the `training` flag will generally not
    function correctly, because the branch is already specialized at export time.
    Additionally, ops beyond dropout and batchnorm that have different train/eval
    behavior will not be converted properly.
    """

    def _train(self, mode: bool = True):
        if mode:
            _move_exported_model_to_train(self)
        else:
            _move_exported_model_to_eval(self)

    def _eval(self):
        _move_exported_model_to_eval(self)

    model.train = types.MethodType(_train, model)  # type: ignore[method-assign]
    model.eval = types.MethodType(_eval, model)  # type: ignore[method-assign]
    return model

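# Illustrative usage (hypothetical `model` and `example_args`), showing the
# patched methods in action:
#
#     gm = torch.export.export(model, example_args).module()
#     gm = _allow_exported_model_train_eval(gm)
#     gm.eval()   # rewrites dropout/batchnorm patterns to eval behavior
#     gm.train()  # rewrites them back to train behavior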