# mypy: allow-untyped-defs
"""Quantized counterparts of the ``torch.nn`` normalization modules.

Each module subclasses the float module of the same name, keeps the
``weight``/``bias`` it is handed, and stores the output quantization
parameters in the ``scale`` and ``zero_point`` buffers that ``forward``
passes to the matching ``torch.ops.quantized`` kernel.
"""
import torch


__all__ = [
    "LayerNorm",
    "GroupNorm",
    "InstanceNorm1d",
    "InstanceNorm2d",
    "InstanceNorm3d",
]


def _instance_norm_from_module(cls, mod, scale, zero_point):
    """Build a quantized InstanceNorm of type ``cls`` from float module ``mod``.

    Args:
        cls: one of the quantized ``InstanceNorm{1,2,3}d`` classes.
        mod: the float InstanceNorm module being converted.
        scale: output quantization scale (converted with ``float``).
        zero_point: output quantization zero point (converted with ``int``).

    Returns:
        A new ``cls`` instance sharing ``mod.weight`` and ``mod.bias``.
    """
    # Keywords are required here: the quantized constructor takes
    # ``momentum`` before ``affine``, so passing ``mod.affine`` positionally
    # right after ``mod.eps`` would store it in ``momentum`` and silently
    # leave ``affine`` at its default of False.
    # ``track_running_stats`` is intentionally not forwarded: ``forward``
    # only hands weight, bias, eps, scale and zero_point to the kernel, so
    # enabling it would just make the parent register running-stat buffers
    # that are never read.
    return cls(
        mod.num_features,
        mod.weight,
        mod.bias,
        float(scale),
        int(zero_point),
        eps=mod.eps,
        momentum=mod.momentum,
        affine=mod.affine,
    )


class LayerNorm(torch.nn.LayerNorm):
    r"""This is the quantized version of :class:`~torch.nn.LayerNorm`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """

    def __init__(
        self,
        normalized_shape,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        elementwise_affine=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            **factory_kwargs,
        )
        # Replace whatever affine parameters the parent created with the
        # (already trained) ones supplied by the caller.
        self.weight = weight
        self.bias = bias
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        """Apply ``quantized::layer_norm``, requantizing with ``scale``/``zero_point``."""
        return torch.ops.quantized.layer_norm(
            input,
            self.normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
            output_scale=self.scale,
            output_zero_point=self.zero_point,
        )

    def _get_name(self):
        return "QuantizedLayerNorm"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Convert an observed float ``LayerNorm`` using its output observer's qparams.

        ``use_precomputed_fake_quant`` is accepted for API compatibility and unused.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return cls(
            mod.normalized_shape,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            elementwise_affine=mod.elementwise_affine,
        )

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Convert a reference/float ``LayerNorm`` with explicit output qparams."""
        return cls(
            mod.normalized_shape,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            elementwise_affine=mod.elementwise_affine,
        )


class GroupNorm(torch.nn.GroupNorm):
    r"""This is the quantized version of :class:`~torch.nn.GroupNorm`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """
    __constants__ = ["num_groups", "num_channels", "eps", "affine"]

    def __init__(
        self,
        num_groups,
        num_channels,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        affine=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
        # Replace whatever affine parameters the parent created with the
        # (already trained) ones supplied by the caller.
        self.weight = weight
        self.bias = bias
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        """Apply ``quantized::group_norm``, requantizing with ``scale``/``zero_point``."""
        return torch.ops.quantized.group_norm(
            input,
            self.num_groups,
            self.weight,
            self.bias,
            self.eps,
            self.scale,
            self.zero_point,
        )

    def _get_name(self):
        return "QuantizedGroupNorm"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Convert an observed float ``GroupNorm`` using its output observer's qparams.

        ``use_precomputed_fake_quant`` is accepted for API compatibility and unused.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return cls(
            mod.num_groups,
            mod.num_channels,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            affine=mod.affine,
        )

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Convert a reference/float ``GroupNorm`` with explicit output qparams."""
        return cls(
            mod.num_groups,
            mod.num_channels,
            mod.weight,
            mod.bias,
            float(scale),
            int(zero_point),
            eps=mod.eps,
            affine=mod.affine,
        )


class InstanceNorm1d(torch.nn.InstanceNorm1d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm1d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """

    def __init__(
        self,
        num_features,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        momentum=0.1,
        affine=False,
        track_running_stats=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        # Replace whatever affine parameters the parent created with the
        # (already trained) ones supplied by the caller.
        self.weight = weight
        self.bias = bias
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        """Apply ``quantized::instance_norm``, requantizing with ``scale``/``zero_point``."""
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedInstanceNorm1d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Convert an observed float module using its output observer's qparams.

        ``use_precomputed_fake_quant`` is accepted for API compatibility and unused.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return _instance_norm_from_module(cls, mod, scale, zero_point)

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Convert a reference/float module with explicit output qparams."""
        return _instance_norm_from_module(cls, mod, scale, zero_point)


class InstanceNorm2d(torch.nn.InstanceNorm2d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm2d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """

    def __init__(
        self,
        num_features,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        momentum=0.1,
        affine=False,
        track_running_stats=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        # Replace whatever affine parameters the parent created with the
        # (already trained) ones supplied by the caller.
        self.weight = weight
        self.bias = bias
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        """Apply ``quantized::instance_norm``, requantizing with ``scale``/``zero_point``."""
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedInstanceNorm2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Convert an observed float module using its output observer's qparams.

        ``use_precomputed_fake_quant`` is accepted for API compatibility and unused.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return _instance_norm_from_module(cls, mod, scale, zero_point)

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Convert a reference/float module with explicit output qparams."""
        return _instance_norm_from_module(cls, mod, scale, zero_point)


class InstanceNorm3d(torch.nn.InstanceNorm3d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm3d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """

    def __init__(
        self,
        num_features,
        weight,
        bias,
        scale,
        zero_point,
        eps=1e-5,
        momentum=0.1,
        affine=False,
        track_running_stats=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        # Replace whatever affine parameters the parent created with the
        # (already trained) ones supplied by the caller.
        self.weight = weight
        self.bias = bias
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        """Apply ``quantized::instance_norm``, requantizing with ``scale``/``zero_point``."""
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedInstanceNorm3d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Convert an observed float module using its output observer's qparams.

        ``use_precomputed_fake_quant`` is accepted for API compatibility and unused.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return _instance_norm_from_module(cls, mod, scale, zero_point)

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Convert a reference/float module with explicit output qparams."""
        return _instance_norm_from_module(cls, mod, scale, zero_point)