# xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/experimental/adaround_optimization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import copy
from typing import Any, Callable, List, Optional, Tuple, Type, Union

import torch
from torch.ao.quantization.experimental.adaround_fake_quantize import (
    AdaroundFakeQuantizer,
)
from torch.ao.quantization.experimental.adaround_loss import AdaptiveRoundingLoss
from torch.ao.quantization.observer import MinMaxObserver
from torch.nn import functional as F
from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader, TensorDataset


class AdaptiveRoundingOptimizer:
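    """Optimize weight rounding decisions with AdaRound (adaptive rounding).

    For each supported layer (Conv1d / Linear), a learnable rounding tensor is
    trained so that the fake-quantized layer reproduces the floating-point
    layer's output on the provided calibration data, instead of simply rounding
    each weight to the nearest quantized value.
    """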
    def __init__(
        self,
        model: Union[torch.nn.Module, torch.nn.DataParallel],
        callback: Callable[
            [
                Union[torch.nn.Module, torch.nn.DataParallel],
                Any,
                Optional[torch.nn.Module],
            ],
            None,
        ],
        forward_hook_wrapper: Callable[[List[torch.Tensor]], Callable],
        data: Any,
        observer: Type[torch.ao.quantization.observer.ObserverBase] = MinMaxObserver,
        max_iter=10000,
        dtype: torch.dtype = torch.qint8,
        quant_min=-128,
        quant_max=127,
        qscheme: torch.qscheme = torch.per_tensor_symmetric,
        batch_size: int = 256,
        feed_forward_wrapper: Optional[torch.nn.Module] = None,
    ):
        if torch.cuda.is_available():
            self.model = model.cuda()
            if torch.cuda.device_count() > 1:
                self.model = torch.nn.DataParallel(model)
        else:
            self.model = model
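        # self.model stays as the floating-point reference; its deep copy
        # (self.q_model) is the model whose weights get adarounded.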
        self.q_model = copy.deepcopy(self.model)
        self.device = torch.device("cuda") if torch.cuda.is_available() else None
        self.callback = callback
        self.forward_hook_wrapper = forward_hook_wrapper
        # TODO: rather than taking `data` as a list, accept an *iterator* instead
        self.data = data
        self.batch_size = min(batch_size, len(data))
        self.max_iter = max_iter
        self.adaptive_round_loss_fn = AdaptiveRoundingLoss(
            max_iter=self.max_iter, warm_start=0.2
        )
        self.dtype = dtype
        self.observer = observer
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.qscheme = qscheme
        self.feed_forward_wrapper = feed_forward_wrapper

    def run_adaround(self) -> torch.nn.Module:
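        """Walk the FP32 model and its copy in lockstep, collect every
        Conv1d/Linear pair, optimize each pair's rounding, and return the
        adarounded model."""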
        layer_list: List[Tuple[str, torch.nn.Module, torch.nn.Module]] = []
        for (name, module), q_module in zip(
            self.model.named_modules(), self.q_model.modules()
        ):
            if isinstance(module, torch.nn.ReLU):
                # Disable all inplace operations
                module.inplace = False
            if isinstance(q_module, torch.nn.ReLU):
                # Disable all inplace operations
                q_module.inplace = False
            if isinstance(module, (torch.nn.Conv1d, torch.nn.Linear)):
                # Knowing the activation ahead of time would help with an asymmetric
                # formulation, but that is hard to do in eager mode (unlike with a graph module).
                layer_list.append((name, module, q_module))
        print(f"Total number of layers : {len(layer_list)}")  # noqa: G004

        for name, module, q_module in layer_list:
            print(
                f"Kick start adaptive rounding on {name} module {module}"  # noqa: G004
            )
            self.optimize_adaptive_rounding(
                module,
                q_module,
                None,
            )

        return (
            self.q_model.module
            if isinstance(self.q_model, DataParallel)
            else self.q_model
        )

    def get_data_inp_out(
        self, module: torch.nn.Module, q_module: torch.nn.Module, data: List[Any]
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
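        """Run the calibration data through both models and, via forward hooks,
        capture the FP32 module's input/output and the quantized module's input."""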
        fp_out: List[torch.Tensor] = []
        q_input: List[torch.Tensor] = []
        fp_input: List[torch.Tensor] = []
        fp32_fetcher: List[torch.Tensor] = []
        quant_fetcher: List[torch.Tensor] = []
        handler1 = module.register_forward_hook(self.forward_hook_wrapper(fp32_fetcher))
        handler2 = q_module.register_forward_hook(
            self.forward_hook_wrapper(quant_fetcher)
        )
        if torch.cuda.is_available():
            # The models have to be moved back to the GPU on every call;
            # otherwise they mysteriously end up back on the CPU.
            self.model = self.model.cuda()
            self.q_model = self.q_model.cuda()
        for data_ in data:
            with torch.no_grad():
                self.callback(self.model, data_, self.feed_forward_wrapper)
                self.callback(self.q_model, data_, self.feed_forward_wrapper)
            fp32_output = fp32_fetcher[1]
            quant_input = quant_fetcher[0]
            fp_out.append(fp32_output)
            q_input.append(quant_input)
            fp_input.append(fp32_fetcher[0])
        handler1.remove()
        handler2.remove()
        return q_input, fp_out, fp_input

    @torch.no_grad()
    def feed_forward(self, x, weight, module):
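        """Apply the module's operation with the given (fake-quantized) weight.

        Note that the Conv1d branch does not apply module.bias, while the
        Linear branch does.
        """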
        if isinstance(module, torch.nn.Conv1d):
            out = torch.nn.functional.conv1d(
                x,
                weight,
                stride=module.stride,
                padding=module.padding,
                dilation=module.dilation,
                groups=module.groups,
            )
        elif isinstance(module, torch.nn.Linear):
            out = torch.nn.functional.linear(
                x,
                weight,
                bias=module.bias,
            )
        else:
            raise NotImplementedError
        return out

    def _compute_and_display_local_losses(
        self,
        ada_quantizer: AdaroundFakeQuantizer,
        q_module: torch.nn.Module,
        q_inp: torch.Tensor,
        fp_out: torch.Tensor,
    ):
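        """Report the MSE between the FP32 output and the layer output obtained
        with hard-rounded and soft-rounded fake-quantized weights."""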
        with torch.no_grad():
            ada_quantizer.use_soft_rounding = False
            q_w_hard_round = ada_quantizer(q_module.weight)
            out_hard_quant = self.feed_forward(q_inp, q_w_hard_round, q_module)
            ada_quantizer.use_soft_rounding = True
            q_w_soft_round = ada_quantizer(q_module.weight)
            out_soft_quant = self.feed_forward(q_inp, q_w_soft_round, q_module)
            soft_quant_loss = F.mse_loss(out_soft_quant, fp_out)
            hard_quant_loss = F.mse_loss(out_hard_quant, fp_out)
            print(
                f"soft quant loss: {soft_quant_loss.item()} hard quant loss: {hard_quant_loss.item()}"  # noqa: G004
            )

    def optimize_adaptive_rounding(
        self,
        module: torch.nn.Module,
        q_module: torch.nn.Module,
        activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    ) -> None:
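        """Learn the rounding for one (module, q_module) pair by minimizing the
        adaptive rounding loss on cached activations, then write the adarounded
        weight back into q_module."""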
        ada_quantizer = AdaroundFakeQuantizer(
            dtype=self.dtype,
            observer=self.observer,
            qscheme=self.qscheme,
            quant_min=self.quant_min,
            quant_max=self.quant_max,
            reduce_range=False,
        )
        ada_quantizer.enable_observer()
        ada_quantizer(q_module.weight)
        ada_quantizer.disable_observer()
        ada_quantizer.enable_fake_quant()
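        # Only the continuous rounding parameter V is optimized; the module
        # weights themselves are left untouched until the final copy below.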
        optimizer = torch.optim.Adam([ada_quantizer.V])
        inp, out, fp_in = self.get_data_inp_out(module, q_module, self.data)

        print("==================== Before adaround ====================")
        assert (
            torch.abs(out[0] - module(fp_in[0])).sum().item() == 0
191        ), "In-placed activation is detected, please do not use activation in-placed"
192        # Stack the tensors in each list into a single tensor
193        # Assuming inp and out are your lists of tensors
        inp_tensor = torch.vstack(inp)
        out_tensor = torch.vstack(out)
        dataset = TensorDataset(inp_tensor, out_tensor)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        self._compute_and_display_local_losses(ada_quantizer, q_module, inp[0], out[0])
        global_idx = 0
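        # Number of batches per pass over the cached activations; batch_size is
        # clamped to len(data) in __init__, so this is at least 1.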
        one_iter = len(out) // self.batch_size
        for iteration in range(self.max_iter // one_iter):
            reconstruction_loss = regularization_loss = torch.tensor(0)
            for q_inp, fp_out in dataloader:
                optimizer.zero_grad()
                q_weight = ada_quantizer(q_module.weight)
                if isinstance(module, torch.nn.Conv1d):
                    q_out = torch.nn.functional.conv1d(
                        q_inp,
                        q_weight,
                        bias=q_module.bias,
                        stride=q_module.stride,
                        padding=q_module.padding,
                        dilation=q_module.dilation,
                        groups=q_module.groups,
                    )
                elif isinstance(q_module, torch.nn.Linear):
                    q_out = torch.nn.functional.linear(
                        q_inp,
                        q_weight,
                        bias=q_module.bias,
                    )
                else:
                    raise NotImplementedError
                regularization_loss, reconstruction_loss = self.adaptive_round_loss_fn(
                    fp_out,
                    q_out,
                    ada_quantizer.V,
                    curr_iter=global_idx,
                )
                loss = regularization_loss + reconstruction_loss
                loss.backward()
                optimizer.step()
                global_idx += 1
                if global_idx >= self.max_iter:
                    break
            if global_idx >= self.max_iter:
                break
            if iteration % 30 == 0:
                print(
                    f"glob iter {global_idx} regularization_loss {regularization_loss.item()} "  # noqa: G004
                    f"reconstruction_loss {reconstruction_loss.item()}"  # noqa: G004
                )
        print("==================== After adaround ====================")
        self._compute_and_display_local_losses(ada_quantizer, q_module, inp[0], out[0])

        ada_quantizer.use_soft_rounding = True
        ada_quantizer.V.requires_grad = False
        ada_quantizer = ada_quantizer.eval()
        q_weight = ada_quantizer(q_module.weight)
        # After optimization, copy the adarounded weight back into the quantized module
        q_module.weight.data.copy_(q_weight)
        # Eager mode expects the observer to be attached as "weight_fake_quant" so it is picked up downstream
        q_module.weight_fake_quant = ada_quantizer.activation_post_process
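

# A minimal usage sketch (not part of the module). `MyModel` and `calib_batches`
# are hypothetical stand-ins, and the two helper callables below are only
# assumptions about what a caller might pass: the hook wrapper follows the
# convention used above, where index 0 of the fetcher list holds the layer
# input and index 1 holds the layer output.
#
#     def forward_hook_wrapper(fetcher):
#         def hook(module, inputs, output):
#             # Keep only the most recent input/output at fixed positions.
#             fetcher[:] = [inputs[0].detach(), output.detach()]
#         return hook
#
#     def callback(model, batch, feed_forward_wrapper=None):
#         model(batch)  # one plain forward pass per calibration batch
#
#     optimizer = AdaptiveRoundingOptimizer(
#         model=MyModel().eval(),
#         callback=callback,
#         forward_hook_wrapper=forward_hook_wrapper,
#         data=calib_batches,  # e.g. a list of input tensors
#     )
#     adarounded_model = optimizer.run_adaround()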