# mypy: allow-untyped-defs
"""Freezing.

This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""

from typing import List, Optional

import torch
from torch.jit._script import RecursiveScriptModule, ScriptModule


def freeze(
    mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics: bool = True
):
    r"""Freeze a ScriptModule, inlining submodules and attributes as constants.

    Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
    module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
    By default, `forward` will be preserved, as well as attributes & methods specified in
    `preserved_attrs`. Additionally, any attribute that is modified within a preserved
    method will be preserved.

    Freezing currently only accepts ScriptModules that are in eval mode.

    Freezing applies generic optimizations that will speed up your model regardless of machine.
    To further optimize using server-specific settings, run `optimize_for_inference` after
    freezing.

    Args:
        mod (:class:`ScriptModule`): a module to be frozen.
        preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
            Attributes modified in preserved methods will also be preserved.
        optimize_numerics (bool): If ``True``, a set of optimization passes that do not strictly
            preserve numerics will be run. Full details can be found in
            `torch.jit.run_frozen_optimizations`.

    Returns:
        Frozen :class:`ScriptModule`.

    Example (Freezing a simple module with a Parameter):

    .. testcode::
        import torch
        class MyModule(torch.nn.Module):
            def __init__(self, N, M):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(N, M))
                self.linear = torch.nn.Linear(N, M)

            def forward(self, input):
                output = self.weight.mm(input)
                output = self.linear(output)
                return output

        scripted_module = torch.jit.script(MyModule(2, 3).eval())
        frozen_module = torch.jit.freeze(scripted_module)
        # parameters have been removed and inlined into the Graph as constants
        assert len(list(frozen_module.named_parameters())) == 0
        # see the compiled graph as Python code
        print(frozen_module.code)

    Example (Freezing a module with preserved attributes):

    .. testcode::
        import torch
        class MyModule2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.modified_tensor = torch.tensor(10.)
                self.version = 1

            def forward(self, input):
                self.modified_tensor += 1
                return input + self.modified_tensor

        scripted_module = torch.jit.script(MyModule2().eval())
        frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
        # we've manually preserved `version`, so it still exists on the frozen module and can be modified
        assert frozen_module.version == 1
        frozen_module.version = 2
        # `modified_tensor` is detected as being mutated in the forward, so freezing preserves
        # it to retain model semantics
        assert frozen_module(torch.tensor(1)) == torch.tensor(12)
        # now that we've run it once, the next result will be incremented by one
        assert frozen_module(torch.tensor(1)) == torch.tensor(13)

    Note:
        Freezing submodule attributes is also supported:
        frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["submodule.version"])

    Note:
        If you're not sure why an attribute is not being inlined as a constant, you can run
        `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
        attribute is being modified.
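
        A minimal sketch (reusing ``scripted_module`` from the first example, and assuming
        ``dump_alias_db`` is exposed as a method on the graph object, per the note above)::

            frozen_module = torch.jit.freeze(scripted_module)
            # prints the alias database, including attributes detected as mutated
            frozen_module.forward.graph.dump_alias_db()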

    Note:
        Because freezing makes weights constants and removes module hierarchy, `to` and other
        ``nn.Module`` methods that manipulate device or dtype no longer work. As a workaround,
        you can remap devices by specifying `map_location` in `torch.jit.load`; however,
        device-specific logic may have been baked into the model.
    """
    if not isinstance(mod, ScriptModule):
        raise RuntimeError(
            "Freezing expects a ScriptModule as input. "
            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
        )

    if mod.training:
        raise RuntimeError(
            "Freezing is currently only implemented for modules in eval mode. "
            "Please call .eval() on your module before freezing."
        )

    preserved_attrs = preserved_attrs if preserved_attrs is not None else []

    # Freeze the underlying C++ module, then finalize the resulting Python wrapper.
    out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
    RecursiveScriptModule._finalize_scriptmodule(out)

    # Preserved attributes that are methods also get their graphs optimized.
    preserved_methods = [x for x in preserved_attrs if mod._c._has_method(x)]
    run_frozen_optimizations(out, optimize_numerics, preserved_methods)

    return out


def run_frozen_optimizations(
    mod, optimize_numerics: bool = True, preserved_methods: Optional[List[str]] = None
):
    r"""
    Run a series of optimizations looking for patterns that occur in frozen graphs.

    The current set of optimizations includes:

        - Dropout Removal
        - Pretranspose Linear Layers
        - Concat Linear Layers with same input Tensor
        - Conv -> Batchnorm folding
        - Conv -> Add/Sub folding
        - Conv -> Mul/Div folding

    Args:
        mod (:class:`ScriptModule`): a frozen module to be optimized.

        optimize_numerics (bool): If ``True``, a set of optimization passes that do not strictly
            preserve numerics will be run. These optimizations preserve the default rtol and atol of
            `torch.testing.assert_close` when applied on a single transformation; however, in a module
            where many transformations are applied, the rtol or atol may no longer fall within the
            default `assert_close` tolerance. Conv -> Batchnorm folding, Conv -> Add/Sub folding,
            and Conv -> Mul/Div folding may all alter numerics.

        preserved_methods (Optional[List[str]]): a list of methods, in addition to `forward`,
            whose graphs should also be optimized.

    Returns:
        None

    Note:
        On rare occasions, this can result in slower execution.

    Example (Freezing a module with Conv->Batchnorm):

    .. code-block:: python
        import torch
        in_channels, out_channels = 3, 32
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
        mod = torch.nn.Sequential(conv, bn)
        # pass optimize_numerics=False so that freezing, which runs run_frozen_optimizations
        # by default, skips the numerics-altering passes such as Conv -> Batchnorm folding
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize_numerics=False)
        # inspect the frozen module: batch_norm is still present
        assert "batch_norm" in str(frozen_mod.graph)
        torch.jit.run_frozen_optimizations(frozen_mod)
        assert "batch_norm" not in str(frozen_mod.graph)

    """
    if mod._c._has_method("forward"):
        torch._C._jit_pass_optimize_frozen_graph(mod.graph, optimize_numerics)

    if preserved_methods is None:
        preserved_methods = []

    for method in preserved_methods:
        torch._C._jit_pass_optimize_frozen_graph(
            mod.__getattr__(method).graph, optimize_numerics
        )


def optimize_for_inference(
    mod: ScriptModule, other_methods: Optional[List[str]] = None
) -> ScriptModule:
    """
    Perform a set of optimization passes to optimize a model for the purposes of inference.

    If the model is not already frozen, optimize_for_inference
    will invoke `torch.jit.freeze` automatically.

    In addition to generic optimizations that should speed up your model regardless
    of environment, `optimize_for_inference` will also bake in build-specific settings
    such as the presence of CUDNN or MKLDNN, and may in the future make transformations
    which speed things up on one machine but slow things down on another. Accordingly,
    serialization is not implemented after invoking `optimize_for_inference` and
    is not guaranteed.

    This is still in prototype and may slow down your model in some cases.
    The primary use cases targeted so far are vision models on CPU and, to a lesser
    extent, GPU.

    Args:
        mod (:class:`ScriptModule`): a module to be optimized.
        other_methods (Optional[List[str]]): a list of methods, in addition to `forward`,
            to preserve and optimize.

    Returns:
        Optimized :class:`ScriptModule`.

    Example (optimizing a module with Conv->Batchnorm)::

        import torch
        in_channels, out_channels = 3, 32
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
        mod = torch.nn.Sequential(conv, bn)
        frozen_mod = torch.jit.optimize_for_inference(torch.jit.script(mod.eval()))
        assert "batch_norm" not in str(frozen_mod.graph)
        # if built with MKLDNN, convolution will be run with MKLDNN weights
        assert "MKLDNN" in str(frozen_mod.graph)
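
    Example (preserving an extra exported method via ``other_methods``; a minimal sketch where
    ``ModuleWithHelper`` and ``helper`` are hypothetical names used for illustration)::

        import torch

        class ModuleWithHelper(torch.nn.Module):
            def forward(self, x):
                return x + 1

            @torch.jit.export
            def helper(self, x):
                return x * 2

        scripted = torch.jit.script(ModuleWithHelper().eval())
        # `helper` is preserved through freezing and optimized alongside `forward`
        opt_mod = torch.jit.optimize_for_inference(scripted, other_methods=["helper"])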
    """
    if not isinstance(mod, ScriptModule):
        raise RuntimeError(
            "optimize_for_inference expects a ScriptModule as input. "
            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
        )

    if other_methods is None:
        other_methods = []

    # `training` is dropped during freezing unless preserved, so its presence
    # means the module has not been frozen yet.
    if hasattr(mod, "training"):
        mod = freeze(mod.eval(), preserved_attrs=other_methods)

    torch._C._jit_pass_optimize_for_inference(mod._c, other_methods)

    return mod