# mypy: allow-untyped-defs
import types

import torch
import torch.nn.functional as F
from torch.ao.quantization.utils import _assert_and_get_unique_device


__all__ = [
    "model_is_exported",
]


class _WrapperModule(torch.nn.Module):
    """Class to wrap a callable in an :class:`torch.nn.Module`. Use this if you
    are trying to export a callable.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        """Simple forward that just calls the ``fn`` provided to :meth:`_WrapperModule.__init__`."""
        return self.fn(*args, **kwargs)


def model_is_exported(m: torch.nn.Module) -> bool:
    """
    Return True if the `torch.nn.Module` was exported, False otherwise
    (e.g. if the model was FX symbolically traced or not traced at all).
    """
    return isinstance(m, torch.fx.GraphModule) and any(
        "val" in n.meta for n in m.graph.nodes
    )


def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
    """
    Switch dropout patterns in the model between train and eval modes.

    Dropout has different behavior in train vs eval mode. For exported models,
    however, calling `model.train()` or `model.eval()` does not automatically switch
    the dropout behavior between the two modes, so here we need to rewrite the aten
    dropout patterns manually to achieve the same effect.

    See https://github.com/pytorch/pytorch/issues/103681.
    """
    # Avoid circular dependencies
    from .utils import _get_aten_graph_module_for_pattern

    # Needed to ensure subgraph matches are self-contained
    m.graph.eliminate_dead_code()
    m.recompile()

    for inplace in [False, True]:

        def dropout_train(x):
            return F.dropout(x, p=0.5, training=True, inplace=inplace)

        def dropout_eval(x):
            return F.dropout(x, p=0.5, training=False, inplace=inplace)

        example_inputs = (torch.randn(1),)
        if train_to_eval:
            match_pattern = _get_aten_graph_module_for_pattern(
                _WrapperModule(dropout_train), example_inputs
            )
            replacement_pattern = _get_aten_graph_module_for_pattern(
                _WrapperModule(dropout_eval), example_inputs
            )
        else:
            match_pattern = _get_aten_graph_module_for_pattern(
                _WrapperModule(dropout_eval), example_inputs
            )
            replacement_pattern = _get_aten_graph_module_for_pattern(
                _WrapperModule(dropout_train), example_inputs
            )

        from torch.fx.subgraph_rewriter import replace_pattern_with_filters

        replace_pattern_with_filters(
            m,
            match_pattern,
            replacement_pattern,
            match_filters=[],
            ignore_literals=True,
        )
        m.recompile()


def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
    """
    Switch batchnorm patterns in the model between train and eval modes.

    Batchnorm has different behavior in train vs eval mode. For exported models,
    however, calling `model.train()` or `model.eval()` does not automatically switch
    the batchnorm behavior between the two modes, so here we need to rewrite the aten
    batchnorm patterns manually to achieve the same effect.
    """
    # TODO(Leslie): This function does not yet support custom momentum and eps values.
    # Enable this support in future updates.

    # Avoid circular dependencies
    from .utils import _get_aten_graph_module_for_pattern

    # Needed to ensure subgraph matches are self-contained
    m.graph.eliminate_dead_code()
    m.recompile()

    def bn_train(
        x: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
    ):
        return F.batch_norm(
            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True
        )

    def bn_eval(
        x: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
    ):
        return F.batch_norm(
            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=False
        )

    example_inputs = (
        torch.randn(1, 1, 3, 3),  # x
        torch.randn(1),  # bn_weight
        torch.randn(1),  # bn_bias
        torch.randn(1),  # bn_running_mean
        torch.randn(1),  # bn_running_var
    )

    device = _assert_and_get_unique_device(m)
    is_cuda = device is not None and device.type == "cuda"
    bn_train_aten = _get_aten_graph_module_for_pattern(
        _WrapperModule(bn_train),
        example_inputs,
        is_cuda,
    )
    bn_eval_aten = _get_aten_graph_module_for_pattern(
        _WrapperModule(bn_eval),
        example_inputs,
        is_cuda,
    )

    if train_to_eval:
        match_pattern = bn_train_aten
        replacement_pattern = bn_eval_aten
    else:
        match_pattern = bn_eval_aten
        replacement_pattern = bn_train_aten

    from torch.fx.subgraph_rewriter import replace_pattern_with_filters

    replace_pattern_with_filters(
        m,
        match_pattern,
        replacement_pattern,
        match_filters=[],
        ignore_literals=True,
    )
    m.recompile()


# TODO: expose these under this namespace?
def _move_exported_model_to_eval(model: torch.fx.GraphModule):
    """
    Move an exported GraphModule to eval mode.

    This is equivalent to model.eval() but only for certain special ops like dropout and batchnorm.
    QAT users should call this before performing inference on the model.
    """
    _replace_dropout(model, train_to_eval=True)
    _replace_batchnorm(model, train_to_eval=True)
    return model


def _move_exported_model_to_train(model: torch.fx.GraphModule):
    """
    Move an exported GraphModule to train mode.

    This is equivalent to model.train() but only for certain special ops like dropout and batchnorm.
    QAT users should call this before performing training on the model.
    """
    _replace_dropout(model, train_to_eval=False)
    _replace_batchnorm(model, train_to_eval=False)
    return model


def _allow_exported_model_train_eval(model: torch.fx.GraphModule):
    """
    Allow users to call `model.train()` and `model.eval()` on an exported model,
    but with the effect of changing behavior between the two modes limited to special
    ops only, which are currently dropout and batchnorm.

    Note: This does not achieve the same effect as what `model.train()` and `model.eval()`
    do in eager models, but only provides an approximation. In particular, user code
    branching on the `training` flag will not function correctly in general because the branch
    is already specialized at export time. Additionally, other ops beyond dropout and batchnorm
    that have different train/eval behavior will also not be converted properly.
209 """ 210 211 def _train(self, mode: bool = True): 212 if mode: 213 _move_exported_model_to_train(self) 214 else: 215 _move_exported_model_to_eval(self) 216 217 def _eval(self): 218 _move_exported_model_to_eval(self) 219 220 model.train = types.MethodType(_train, model) # type: ignore[method-assign] 221 model.eval = types.MethodType(_eval, model) # type: ignore[method-assign] 222 return model 223