# mypy: allow-untyped-defs
"""Freezing.

This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""

from typing import List, Optional

import torch
from torch.jit._script import RecursiveScriptModule, ScriptModule


def freeze(
    mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics: bool = True
):
    r"""Freeze ScriptModule, inline submodules, and attributes as constants.

    Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
    module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
    By default, `forward` will be preserved, as well as attributes & methods specified in
    `preserved_attrs`. Additionally, any attribute that is modified within a preserved
    method will be preserved.

    Freezing currently only accepts ScriptModules that are in eval mode.

    Freezing applies generic optimizations that will speed up your model regardless of machine.
    To further optimize using server-specific settings, run `optimize_for_inference` after
    freezing.
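
    For example, a minimal sketch of chaining the two (``scripted_module`` stands in for
    any ScriptModule in eval mode; per the documentation of `optimize_for_inference`, an
    already-frozen module will not be re-frozen):

    .. code-block:: python

        frozen_module = torch.jit.freeze(scripted_module)
        optimized_module = torch.jit.optimize_for_inference(frozen_module)
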

    Args:
        mod (:class:`ScriptModule`): a module to be frozen
        preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
            Attributes modified in preserved methods will also be preserved.
        optimize_numerics (bool): If ``True``, a set of optimization passes will be run that do not strictly
            preserve numerics. Full details of the optimizations can be found in `torch.jit.run_frozen_optimizations`.

    Returns:
        Frozen :class:`ScriptModule`.

    Example (Freezing a simple module with a Parameter):

    .. testcode::

        import torch

        class MyModule(torch.nn.Module):
            def __init__(self, N, M):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(N, M))
                self.linear = torch.nn.Linear(N, M)

            def forward(self, input):
                output = self.weight.mm(input)
                output = self.linear(output)
                return output

        scripted_module = torch.jit.script(MyModule(2, 3).eval())
        frozen_module = torch.jit.freeze(scripted_module)
        # parameters have been removed and inlined into the Graph as constants
        assert len(list(frozen_module.named_parameters())) == 0
        # See the compiled graph as Python code
        print(frozen_module.code)

    Example (Freezing a module with preserved attributes):

    .. testcode::

        import torch

        class MyModule2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.modified_tensor = torch.tensor(10.)
                self.version = 1

            def forward(self, input):
                self.modified_tensor += 1
                return input + self.modified_tensor

        scripted_module = torch.jit.script(MyModule2().eval())
        frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
        # we've manually preserved `version`, so it still exists on the frozen module and can be modified
        assert frozen_module.version == 1
        frozen_module.version = 2
        # `modified_tensor` is detected as being mutated in the forward, so freezing preserves
        # it to retain model semantics
        assert frozen_module(torch.tensor(1)) == torch.tensor(12)
        # now that we've run it once, the next result will be incremented by one
        assert frozen_module(torch.tensor(1)) == torch.tensor(13)

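    Methods named in `preserved_attrs` are kept as well. A minimal sketch, assuming
    `MyModule2` above also defined a `get_version` method decorated with
    `@torch.jit.export` (the method name is illustrative):

    .. code-block:: python

        frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["get_version"])
        print(frozen_module.get_version())
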
    Note:
        Freezing submodule attributes is also supported:
        frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["submodule.version"])

    Note:
        If you're not sure why an attribute is not being inlined as a constant, you can run
        `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
        attribute is being modified.
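
        A sketch of that inspection, assuming `dump_alias_db` is exposed as a method
        on the Graph object as the wording above suggests:

        .. code-block:: python

            frozen_module = torch.jit.freeze(scripted_module)
            # prints the alias analysis, including mutated attributes
            frozen_module.forward.graph.dump_alias_db()
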

    Note:
        Because freezing makes weights constants and removes module hierarchy, `to` and other
        nn.Module methods that manipulate device or dtype no longer work. As a workaround,
        you can remap devices by specifying `map_location` in `torch.jit.load`; however,
        device-specific logic may have been baked into the model.
    """
    if not isinstance(mod, ScriptModule):
        raise RuntimeError(
            "Freezing expects a ScriptModule as input. "
            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
        )

    if mod.training:
        raise RuntimeError(
            "Freezing is currently only implemented for modules in eval mode. "
            "Please call .eval() on your module before freezing."
        )

    preserved_attrs = preserved_attrs if preserved_attrs is not None else []

    out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
    RecursiveScriptModule._finalize_scriptmodule(out)

    # Entries in `preserved_attrs` that name methods (rather than plain attributes)
    # need their graphs optimized too; only `forward` is handled by default.
    preserved_methods = [x for x in preserved_attrs if mod._c._has_method(x)]
    run_frozen_optimizations(out, optimize_numerics, preserved_methods)

    return out


def run_frozen_optimizations(
    mod, optimize_numerics: bool = True, preserved_methods: Optional[List[str]] = None
):
    r"""
    Run a series of optimizations looking for patterns that occur in frozen graphs.

    The current set of optimizations includes (a sketch of one pass follows the list):

        - Dropout Removal
        - Pretranspose Linear Layers
        - Concat Linear Layers with same input Tensor
        - Conv -> Batchnorm folding
        - Conv -> Add/Sub folding
        - Conv -> Mul/Div folding

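    As a minimal sketch of Dropout Removal: in eval mode dropout is the identity, so
    the node can be dropped from the graph (the module below is illustrative):

    .. code-block:: python

        import torch

        mod = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.5))
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
        # freezing runs this optimization by default, so the dropout node
        # should no longer appear in the frozen graph
        assert "dropout" not in str(frozen_mod.graph)
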
    Args:
        mod (:class:`ScriptModule`): a frozen module to be optimized

        optimize_numerics (bool): If ``True``, a set of optimization passes will be run that do not strictly
        preserve numerics. These optimizations preserve the default rtol and atol of `torch.testing.assert_close`
        when applied on a single transformation; however, in a module where many transformations are applied,
        the result may no longer fall within the default `assert_close` tolerance. Conv -> Batchnorm folding,
        Conv -> Add/Sub folding, and Conv -> Mul/Div folding all may alter numerics.
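
        As a sketch, one way to check numerics against the original, unfrozen module
        (``mod_eager`` and ``inp`` are illustrative placeholders):

        .. code-block:: python

            torch.testing.assert_close(mod_eager(inp), frozen_mod(inp))
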

    Returns:
        None

    Note:
        On rare occasions, this can result in slower execution.

    Example (Freezing a module with Conv->Batchnorm):

    .. code-block:: python

        import torch

        in_channels, out_channels = 3, 32
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
        mod = torch.nn.Sequential(conv, bn)
        # set optimize_numerics to False here; by default, freezing runs
        # run_frozen_optimizations with the numerics-altering passes enabled
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize_numerics=False)
        # inspect frozen mod
        assert "batch_norm" in str(frozen_mod.graph)
        torch.jit.run_frozen_optimizations(frozen_mod)
        assert "batch_norm" not in str(frozen_mod.graph)

    """
    if mod._c._has_method("forward"):
        torch._C._jit_pass_optimize_frozen_graph(mod.graph, optimize_numerics)

    if preserved_methods is None:
        preserved_methods = []

    for method in preserved_methods:
        torch._C._jit_pass_optimize_frozen_graph(
            mod.__getattr__(method).graph, optimize_numerics
        )


def optimize_for_inference(
    mod: ScriptModule, other_methods: Optional[List[str]] = None
) -> ScriptModule:
    """
    Perform a set of optimization passes to optimize a model for the purposes of inference.

    If the model is not already frozen, optimize_for_inference
    will invoke `torch.jit.freeze` automatically.

    In addition to generic optimizations that should speed up your model regardless
    of environment, optimize_for_inference will also bake in build-specific settings
    such as the presence of CUDNN or MKLDNN, and may in the future make transformations
    which speed things up on one machine but slow things down on another. Accordingly,
    serialization after invoking `optimize_for_inference` is not implemented and is
    not guaranteed to work.
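
    If you need a serializable artifact, one workaround (a sketch; ``scripted_mod``
    is an illustrative placeholder) is to save the module before optimizing:

    .. code-block:: python

        torch.jit.save(scripted_mod, "model.pt")  # portable, serialized artifact
        opt_mod = torch.jit.optimize_for_inference(scripted_mod)  # use locally only
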

    This API is still in prototype and may have the potential to slow down your model.
    The primary use cases targeted so far are vision models on CPU and, to a lesser
    extent, GPU.

    Example (optimizing a module with Conv->Batchnorm)::

        import torch

        in_channels, out_channels = 3, 32
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
        mod = torch.nn.Sequential(conv, bn)
        frozen_mod = torch.jit.optimize_for_inference(torch.jit.script(mod.eval()))
        assert "batch_norm" not in str(frozen_mod.graph)
        # if built with MKLDNN, convolution will be run with MKLDNN weights
        assert "MKLDNN" in str(frozen_mod.graph)
    """
    if not isinstance(mod, ScriptModule):
        raise RuntimeError(
            "optimize_for_inference expects a ScriptModule as input. "
            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
        )

    if other_methods is None:
        other_methods = []

    # A frozen module no longer has a `training` attribute, so this check
    # doubles as "freeze only if not already frozen".
    if hasattr(mod, "training"):
        mod = freeze(mod.eval(), preserved_attrs=other_methods)

    torch._C._jit_pass_optimize_for_inference(mod._c, other_methods)

    return mod