# mypy: allow-untyped-defs

from torch import nn


class QuantStub(nn.Module):
    r"""Quantize stub module, before calibration, this is same as an observer,
    it will be swapped as `nnq.Quantize` in `convert`.

    Args:
        qconfig: quantization configuration for the tensor,
            if qconfig is not provided, we will get qconfig from parent modules
    """

    def __init__(self, qconfig=None):
        super().__init__()
        # Only attach qconfig when one was supplied (truthy); otherwise the
        # attribute is left unset so it can be inherited from a parent module.
        if qconfig:
            self.qconfig = qconfig

    def forward(self, x):
        # Identity passthrough until `convert` swaps this for nnq.Quantize.
        return x


class DeQuantStub(nn.Module):
    r"""Dequantize stub module, before calibration, this is same as identity,
    this will be swapped as `nnq.DeQuantize` in `convert`.

    Args:
        qconfig: quantization configuration for the tensor,
            if qconfig is not provided, we will get qconfig from parent modules
    """

    def __init__(self, qconfig=None):
        super().__init__()
        # Same inheritance convention as QuantStub: unset means "ask parent".
        if qconfig:
            self.qconfig = qconfig

    def forward(self, x):
        # Identity passthrough until `convert` swaps this for nnq.DeQuantize.
        return x


class QuantWrapper(nn.Module):
    r"""A wrapper class that wraps the input module, adds QuantStub and
    DeQuantStub and surround the call to module with call to quant and dequant
    modules.

    This is used by the `quantization` utility functions to add the quant and
    dequant modules, before `convert` function `QuantStub` will just be observer,
    it observes the input tensor, after `convert`, `QuantStub`
    will be swapped to `nnq.Quantize` which does actual quantization. Similarly
    for `DeQuantStub`.
    """

    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module

    def __init__(self, module):
        super().__init__()
        # Propagate the wrapped module's qconfig (if any) to both stubs so
        # observers/quantizers agree with the module's configuration.
        qconfig = getattr(module, "qconfig", None)
        self.add_module("quant", QuantStub(qconfig))
        self.add_module("dequant", DeQuantStub(qconfig))
        self.add_module("module", module)
        # Match the wrapped module's train/eval mode at construction time.
        self.train(module.training)

    def forward(self, X):
        # quantize -> run wrapped module -> dequantize
        return self.dequant(self.module(self.quant(X)))