# mypy: allow-untyped-defs
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch import _VF, Tensor
from torch.nn.utils.rnn import PackedSequence

from .utils import _quantize_and_dequantize_weight, _quantize_weight


__all__ = [
    "RNNCellBase",
    "RNNCell",
    "LSTMCell",
    "GRUCell",
    "RNNBase",
    "LSTM",
    "GRU",
    "get_quantized_weight",
]


def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return tensor.index_select(dim, permutation)
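
# Illustrative example (placeholder tensors, a minimal sketch): for a hidden
# state of shape (num_layers * num_directions, batch, hidden_size),
# _apply_permutation reorders the batch dimension (dim=1):
#
#   hx = torch.randn(2, 3, 4)
#   perm = torch.tensor([2, 0, 1])
#   torch.equal(_apply_permutation(hx, perm)[:, 0], hx[:, 2])  # True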


def _get_weight_and_quantization_params(module, wn):
    weight = getattr(module, wn)
    params = [weight]
    for param_name in [
        wn + n for n in ["_qscheme", "_dtype", "_scale", "_zero_point", "_axis_int"]
    ]:
        if hasattr(module, param_name):
            param = getattr(module, param_name)
        else:
            param = None
        params.append(param)
    return params


def get_quantized_weight(module, wn):
    if not hasattr(module, wn):
        return None
    params = _get_weight_and_quantization_params(module, wn)
    weight = _quantize_weight(*params)
    return weight


def _get_quantize_and_dequantized_weight(module, wn):
    if not hasattr(module, wn):
        return None
    params = _get_weight_and_quantization_params(module, wn)
    weight = _quantize_and_dequantize_weight(*params)
    return weight
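
# Sketch of how the helpers above fit together for a module that registered
# per-tensor qparams under the name "weight_ih" (attribute names as created by
# _init_weight_qparams_dict below):
#
#   params = _get_weight_and_quantization_params(module, "weight_ih")
#   # -> [weight, qscheme, dtype, scale, zero_point, axis_int]
#   qw = get_quantized_weight(module, "weight_ih")
#   # -> quantized tensor (e.g. torch.quint8)
#   fqw = _get_quantize_and_dequantized_weight(module, "weight_ih")
#   # -> float tensor after a quantize -> dequantize round trip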


class RNNCellBase(nn.RNNCellBase):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool,
        num_chunks: int,
        device=None,
        dtype=None,
        weight_qparams_dict=None,
    ) -> None:
        super().__init__(
            input_size, hidden_size, bias, num_chunks, device=device, dtype=dtype
        )
        # TODO(jerryzh168): maybe make this arg a required arg
        if weight_qparams_dict is None:
            weight_qparams = {
                "qscheme": torch.per_tensor_affine,
                "dtype": torch.quint8,
                "scale": 1.0,
                "zero_point": 0,
            }
            weight_qparams_dict = {
                "weight_ih": weight_qparams,
                "weight_hh": weight_qparams,
                "is_decomposed": False,
            }
        assert (
            len(weight_qparams_dict) == 3
        ), "Expected weight_qparams_dict to have exactly 3 entries for QuantizedRNNCellBase(Reference)"
        self._init_weight_qparams_dict(weight_qparams_dict, device)

    def _init_weight_qparams_dict(self, weight_qparams_dict, device):
        assert weight_qparams_dict is not None
        self.is_decomposed = weight_qparams_dict["is_decomposed"]
        for key, weight_qparams in weight_qparams_dict.items():
            if key == "is_decomposed":
                continue
            # TODO: refactor the duplicated code to utils.py
            weight_qscheme = weight_qparams["qscheme"]
            weight_dtype = weight_qparams["dtype"]
            setattr(self, key + "_qscheme", weight_qscheme)
            setattr(self, key + "_dtype", weight_dtype)
            assert weight_qscheme in [
                None,
                torch.per_tensor_affine,
                torch.per_channel_affine,
            ], Exception(
                f"qscheme: {weight_qscheme} is not supported in {self._get_name()}"
            )
            if weight_qscheme is not None:
                scale = weight_qparams["scale"]
                scale_tensor = (
                    scale.clone().detach()
                    if isinstance(scale, torch.Tensor)
                    else torch.tensor(scale, dtype=torch.float, device=device)
                )
                self.register_buffer(key + "_scale", scale_tensor)
                zp = weight_qparams["zero_point"]
                zp_tensor = (
                    zp.clone().detach()
                    if isinstance(zp, torch.Tensor)
                    else torch.tensor(zp, dtype=torch.int, device=device)
                )
                self.register_buffer(key + "_zero_point", zp_tensor)
                if weight_qscheme == torch.per_channel_affine:
                    axis = weight_qparams["axis"]
                    axis_tensor = (
                        axis.clone().detach()
                        if isinstance(axis, torch.Tensor)
                        else torch.tensor(axis, dtype=torch.int, device=device)
                    )
                    self.register_buffer(key + "_axis", axis_tensor)
                else:
                    # added for TorchScriptability, not used
                    self.register_buffer(
                        key + "_axis", torch.tensor(0, dtype=torch.int, device=device)
                    )
                setattr(self, key + "_axis_int", getattr(self, key + "_axis").item())
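
    # For each weight key (e.g. "weight_ih"), the loop above creates
    # <key>_qscheme and <key>_dtype attributes plus <key>_scale,
    # <key>_zero_point and <key>_axis buffers and a <key>_axis_int attribute;
    # these are exactly the names _get_weight_and_quantization_params looks up.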

    def _get_name(self):
        return "QuantizedRNNCellBase(Reference)"

    def get_quantized_weight_ih(self):
        return get_quantized_weight(self, "weight_ih")

    def get_quantized_weight_hh(self):
        return get_quantized_weight(self, "weight_hh")

    def get_weight_ih(self):
        return _get_quantize_and_dequantized_weight(self, "weight_ih")

    def get_weight_hh(self):
        return _get_quantize_and_dequantized_weight(self, "weight_hh")
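
    # Naming convention used throughout this file: get_quantized_weight_*
    # returns the quantized (integer) tensor, while get_weight_* returns the
    # float tensor after a quantize -> dequantize round trip; the cell forwards
    # below use the latter, which is what makes these "reference" modules that
    # simulate quantized numerics in float.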


class RNNCell(RNNCellBase):
    """
    We store weight_qparams for all the weights (weight_ih and weight_hh),
    so a `weight_qparams_dict` needs to be passed in, mapping from weight name,
    e.g. weight_ih, to the weight_qparams for that weight.
    """
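
    # Illustrative usage (a minimal sketch; placeholder sizes and qparam values):
    #
    #   qparams = {
    #       "qscheme": torch.per_tensor_affine,
    #       "dtype": torch.quint8,
    #       "scale": 0.02,
    #       "zero_point": 0,
    #   }
    #   cell = RNNCell(
    #       10, 20,
    #       weight_qparams_dict={
    #           "weight_ih": qparams,
    #           "weight_hh": qparams,
    #           "is_decomposed": False,
    #       },
    #   )
    #   h = cell(torch.randn(3, 10))  # -> Tensor of shape (3, 20)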

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        nonlinearity: str = "tanh",
        device=None,
        dtype=None,
        weight_qparams_dict: Optional[Dict[str, Any]] = None,
    ) -> None:
        factory_kwargs = {
            "device": device,
            "dtype": dtype,
            "weight_qparams_dict": weight_qparams_dict,
        }
        super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
        self.nonlinearity = nonlinearity

    def _get_name(self):
        return "QuantizedRNNCell(Reference)"

    # TODO: refactor nn.RNNCell to have a _forward that takes weight_ih and weight_hh as input
    # and remove duplicated code, same for the other two Cell modules
    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        assert input.dim() in (
            1,
            2,
        ), f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)

        if hx is None:
            hx = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
        else:
            hx = hx.unsqueeze(0) if not is_batched else hx

        if self.nonlinearity == "tanh":
            ret = _VF.rnn_tanh_cell(
                input,
                hx,
                self.get_weight_ih(),
                self.get_weight_hh(),
                self.bias_ih,
                self.bias_hh,
            )
        elif self.nonlinearity == "relu":
            ret = _VF.rnn_relu_cell(
                input,
                hx,
                self.get_weight_ih(),
                self.get_weight_hh(),
                self.bias_ih,
                self.bias_hh,
            )
        else:
            ret = input  # TODO: remove when jit supports exception flow
            raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}")

        if not is_batched:
            ret = ret.squeeze(0)

        return ret

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.bias,
            mod.nonlinearity,
            mod.weight_ih.device,
            mod.weight_ih.dtype,
            weight_qparams_dict,
        )
        ref_mod.weight_ih = mod.weight_ih
        ref_mod.weight_hh = mod.weight_hh
        ref_mod.bias_ih = mod.bias_ih
        ref_mod.bias_hh = mod.bias_hh
        return ref_mod


class LSTMCell(RNNCellBase):
    """
    We store weight_qparams for all the weights (weight_ih and weight_hh),
    so a `weight_qparams_dict` needs to be passed in, mapping from weight name,
    e.g. weight_ih, to the weight_qparams for that weight.
    """
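
    # Illustrative usage (a minimal sketch; placeholder names): with the
    # default per-tensor qparams, forward takes and returns an (h, c) pair:
    #
    #   cell = LSTMCell(10, 20)
    #   h, c = cell(torch.randn(3, 10))  # both of shape (3, 20)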

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        device=None,
        dtype=None,
        weight_qparams_dict: Optional[Dict[str, Any]] = None,
    ) -> None:
        factory_kwargs = {
            "device": device,
            "dtype": dtype,
            "weight_qparams_dict": weight_qparams_dict,
        }
        super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)

    def _get_name(self):
        return "QuantizedLSTMCell(Reference)"

    def forward(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tensor]:
        assert input.dim() in (
            1,
            2,
        ), f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)

        if hx is None:
            zeros = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
            hx = (zeros, zeros)
        else:
            hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx

        ret = _VF.lstm_cell(
            input,
            hx,
            self.get_weight_ih(),
            self.get_weight_hh(),
            self.bias_ih,
            self.bias_hh,
        )

        if not is_batched:
            ret = (ret[0].squeeze(0), ret[1].squeeze(0))
        return ret

    @classmethod
    def from_float(cls, mod, weight_qparams_dict, use_precomputed_fake_quant=False):
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.bias,
            mod.weight_ih.device,
            mod.weight_ih.dtype,
            weight_qparams_dict,
        )
        ref_mod.weight_ih = mod.weight_ih
        ref_mod.weight_hh = mod.weight_hh
        ref_mod.bias_ih = mod.bias_ih
        ref_mod.bias_hh = mod.bias_hh
        return ref_mod


class GRUCell(RNNCellBase):
    """
    We store weight_qparams for all the weights (weight_ih and weight_hh),
    so a `weight_qparams_dict` needs to be passed in, mapping from weight name,
    e.g. weight_ih, to the weight_qparams for that weight.
    """
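
    # Illustrative usage (a minimal sketch; placeholder names):
    #
    #   cell = GRUCell(10, 20)
    #   h = cell(torch.randn(3, 10))  # -> Tensor of shape (3, 20)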

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        device=None,
        dtype=None,
        weight_qparams_dict: Optional[Dict[str, Any]] = None,
    ) -> None:
        factory_kwargs = {
            "device": device,
            "dtype": dtype,
            "weight_qparams_dict": weight_qparams_dict,
        }
        super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)

    def _get_name(self):
        return "QuantizedGRUCell(Reference)"

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        assert input.dim() in (
            1,
            2,
        ), f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)

        if hx is None:
            hx = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
        else:
            hx = hx.unsqueeze(0) if not is_batched else hx

        ret = _VF.gru_cell(
            input,
            hx,
            self.get_weight_ih(),
            self.get_weight_hh(),
            self.bias_ih,
            self.bias_hh,
        )

        if not is_batched:
            ret = ret.squeeze(0)

        return ret

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.bias,
            mod.weight_ih.device,
            mod.weight_ih.dtype,
            weight_qparams_dict,
        )
        ref_mod.weight_ih = mod.weight_ih
        ref_mod.weight_hh = mod.weight_hh
        ref_mod.bias_ih = mod.bias_ih
        ref_mod.bias_hh = mod.bias_hh
        return ref_mod


class RNNBase(nn.RNNBase):
    def __init__(
        self,
        mode: str,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        proj_size: int = 0,
        device=None,
        dtype=None,
        weight_qparams_dict: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(
            mode,
            input_size,
            hidden_size,
            num_layers,
            bias,
            batch_first,
            dropout,
            bidirectional,
            proj_size,
            device,
            dtype,
        )
        # TODO(jerryzh168): maybe make this arg a required arg
        if weight_qparams_dict is None:
            weight_qparams = {
                "qscheme": torch.per_tensor_affine,
                "dtype": torch.quint8,
                "scale": 1.0,
                "zero_point": 0,
            }
            weight_qparams_dict = {"is_decomposed": False}  # type: ignore[dict-item]
            for wn in self._flat_weights_names:
                if wn.startswith("weight"):
                    weight_qparams_dict[wn] = weight_qparams
        self._init_weight_qparams_dict(weight_qparams_dict, device)
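
    # Note: _flat_weights_names follows nn.RNNBase's naming convention:
    # weight_ih_l{k}, weight_hh_l{k}, bias_ih_l{k} and bias_hh_l{k} for layer k,
    # with a "_reverse" suffix for the backward direction of a bidirectional
    # network; only the weight_* entries are given qparams above.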

    def _init_weight_qparams_dict(self, weight_qparams_dict, device):
        self.is_decomposed = weight_qparams_dict["is_decomposed"]
        for key, weight_qparams in weight_qparams_dict.items():
            if key == "is_decomposed":
                continue
            weight_qscheme = weight_qparams["qscheme"]
            weight_dtype = weight_qparams["dtype"]
            setattr(self, key + "_qscheme", weight_qscheme)
            setattr(self, key + "_dtype", weight_dtype)
            assert weight_qscheme in [
                None,
                torch.per_tensor_affine,
                torch.per_channel_affine,
            ], Exception(
                f"qscheme: {weight_qscheme} is not supported in {self._get_name()}"
            )
            if weight_qscheme is not None:
                self.register_buffer(
                    key + "_scale",
                    torch.tensor(
                        weight_qparams["scale"], dtype=torch.float, device=device
                    ),
                )
                self.register_buffer(
                    key + "_zero_point",
                    torch.tensor(
                        weight_qparams["zero_point"], dtype=torch.int, device=device
                    ),
                )
                if weight_qscheme == torch.per_channel_affine:
                    self.register_buffer(
                        key + "_axis",
                        torch.tensor(
                            weight_qparams["axis"], dtype=torch.int, device=device
                        ),
                    )
                else:
                    # added for TorchScriptability, not used
                    self.register_buffer(
                        key + "_axis", torch.tensor(0, dtype=torch.int, device=device)
                    )
                setattr(self, key + "_axis_int", getattr(self, key + "_axis").item())


class LSTM(RNNBase):
    """Reference Quantized LSTM Module
    We store weight_qparams for all the weights in _flat_weights, so a
    `weight_qparams_dict` needs to be passed in, mapping from weight name,
    e.g. weight_ih_l0, to the weight_qparams for that weight.
    """
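
    # Illustrative usage (a minimal sketch; placeholder sizes and qparam values):
    #
    #   qparams = {
    #       "qscheme": torch.per_tensor_affine,
    #       "dtype": torch.quint8,
    #       "scale": 0.02,
    #       "zero_point": 0,
    #   }
    #   lstm = LSTM(
    #       input_size=10,
    #       hidden_size=20,
    #       num_layers=1,
    #       weight_qparams_dict={
    #           "is_decomposed": False,
    #           "weight_ih_l0": qparams,
    #           "weight_hh_l0": qparams,
    #       },
    #   )
    #   out, (h, c) = lstm(torch.randn(5, 3, 10))  # seq_len=5, batch=3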

    def __init__(self, *args, **kwargs):
        super().__init__("LSTM", *args, **kwargs)

    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
    def permute_hidden(  # type: ignore[override]
        self,
        hx: Tuple[Tensor, Tensor],
        permutation: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor]:
        if permutation is None:
            return hx
        return _apply_permutation(hx[0], permutation), _apply_permutation(
            hx[1], permutation
        )

    def get_expected_cell_size(
        self, input: Tensor, batch_sizes: Optional[Tensor]
    ) -> Tuple[int, int, int]:
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (
            self.num_layers * num_directions,
            mini_batch,
            self.hidden_size,
        )
        return expected_hidden_size

    # In the future, we should prevent mypy from applying contravariance rules here.
    # See torch/nn/modules/module.py::_forward_unimplemented
    def check_forward_args(  # type: ignore[override]
        self,
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        batch_sizes: Optional[Tensor],
    ):
        self.check_input(input, batch_sizes)
        self.check_hidden_size(
            hidden[0],
            self.get_expected_hidden_size(input, batch_sizes),
            "Expected hidden[0] size {}, got {}",
        )
        self.check_hidden_size(
            hidden[1],
            self.get_expected_cell_size(input, batch_sizes),
            "Expected hidden[1] size {}, got {}",
        )

    def get_quantized_weight_bias_dict(self):
        """Returns a dictionary mapping flat_weight_name to quantized weight
        or (unquantized) bias, e.g.
        {
          "weight_ih_l0": quantized_weight,
          "bias_ih_l0": unquantized_bias,
          ...
        }
        """
        quantized_weight_bias_dict = {}
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                if wn.startswith("weight"):
                    weight_or_bias = get_quantized_weight(self, wn)
                else:
                    weight_or_bias = getattr(self, wn)
            else:
                weight_or_bias = None
            quantized_weight_bias_dict[wn] = weight_or_bias
        return quantized_weight_bias_dict

    def get_flat_weights(self):
        flat_weights = []
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                weight = getattr(self, wn)
                if wn.startswith("weight"):
                    params = _get_weight_and_quantization_params(self, wn)
                    weight = _quantize_and_dequantize_weight(*params)
            else:
                weight = None
            flat_weights.append(weight)
        return flat_weights
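
    # get_flat_weights mirrors nn.LSTM's _flat_weights, except that each
    # weight_* entry is replaced by its quantize -> dequantize simulation, so
    # the _VF.lstm call below runs in float while modeling quantized numerics.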

    def forward(self, input, hx=None):  # noqa: F811
        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        batch_sizes = None
        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = int(batch_sizes[0])
        else:
            batch_sizes = None
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                input = input.unsqueeze(batch_dim)
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            real_hidden_size = (
                self.proj_size if self.proj_size > 0 else self.hidden_size
            )
            h_zeros = torch.zeros(
                self.num_layers * num_directions,
                max_batch_size,
                real_hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            c_zeros = torch.zeros(
                self.num_layers * num_directions,
                max_batch_size,
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            hx = (h_zeros, c_zeros)
        else:
            if batch_sizes is None:  # If not PackedSequence input.
                if is_batched:  # type: ignore[possibly-undefined]
                    if hx[0].dim() != 3 or hx[1].dim() != 3:
                        msg = (
                            "For batched 3-D input, hx and cx should "
                            f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors"
                        )
                        raise RuntimeError(msg)
                else:
                    if hx[0].dim() != 2 or hx[1].dim() != 2:
                        msg = (
                            "For unbatched 2-D input, hx and cx should "
                            f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors"
                        )
                        raise RuntimeError(msg)
                    hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))

            # Each batch of the hidden state should match the input sequence that
            # the user believes they are passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            result = _VF.lstm(
                input,
                hx,
                self.get_flat_weights(),
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
                self.batch_first,
            )
        else:
            result = _VF.lstm(
                input,
                batch_sizes,
                hx,
                self.get_flat_weights(),
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
            )
        output = result[0]
        hidden = result[1:]
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(
                output, batch_sizes, sorted_indices, unsorted_indices
            )
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            if not is_batched:  # type: ignore[possibly-undefined]
                output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
                hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
            return output, self.permute_hidden(hidden, unsorted_indices)

    def _get_name(self):
        return "QuantizedLSTM(Reference)"

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.num_layers,
            mod.bias,
            mod.batch_first,
            mod.dropout,
            mod.bidirectional,
            weight_qparams_dict=weight_qparams_dict,
        )
        for wn in mod._flat_weights_names:
            setattr(ref_mod, wn, getattr(mod, wn))
        return ref_mod
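
    # Illustrative conversion sketch (placeholder names and values; in the
    # actual quantization flow the qparams would come from observers):
    #
    #   float_lstm = nn.LSTM(10, 20)
    #   qparams_dict = {"is_decomposed": False}
    #   for wn in float_lstm._flat_weights_names:
    #       if wn.startswith("weight"):
    #           qparams_dict[wn] = {
    #               "qscheme": torch.per_tensor_affine,
    #               "dtype": torch.quint8,
    #               "scale": 0.02,
    #               "zero_point": 0,
    #           }
    #   ref_lstm = LSTM.from_float(float_lstm, qparams_dict)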


class GRU(RNNBase):
    """Reference Quantized GRU Module
    We store weight_qparams for all the weights in _flat_weights, so a
    `weight_qparams_dict` needs to be passed in, mapping from weight name,
    e.g. weight_ih_l0, to the weight_qparams for that weight.
    """
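
    # Illustrative usage (a minimal sketch; with no weight_qparams_dict the
    # default per-tensor qparams from RNNBase apply):
    #
    #   gru = GRU(input_size=10, hidden_size=20, num_layers=1)
    #   out, h = gru(torch.randn(5, 3, 10))  # out: (5, 3, 20), h: (1, 3, 20)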

    def __init__(self, *args, **kwargs):
        if "proj_size" in kwargs:
            raise ValueError(
                "proj_size argument is only supported for LSTM, not RNN or GRU"
            )
        super().__init__("GRU", *args, **kwargs)

    def get_quantized_weight_bias_dict(self):
        """Returns a dictionary mapping flat_weight_name to quantized weight
        or (unquantized) bias, e.g.
        {
          "weight_ih_l0": quantized_weight,
          "bias_ih_l0": unquantized_bias,
          ...
        }
        """
        quantized_weight_bias_dict = {}
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                if wn.startswith("weight"):
                    weight_or_bias = get_quantized_weight(self, wn)
                else:
                    weight_or_bias = getattr(self, wn)
            else:
                weight_or_bias = None
            quantized_weight_bias_dict[wn] = weight_or_bias
        return quantized_weight_bias_dict

    def get_flat_weights(self):
        flat_weights = []
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                weight = getattr(self, wn)
                if wn.startswith("weight"):
                    params = _get_weight_and_quantization_params(self, wn)
                    weight = _quantize_and_dequantize_weight(*params)
            else:
                weight = None
            flat_weights.append(weight)
        return flat_weights

    def forward(self, input, hx=None):  # noqa: F811
        # Note: this is copied from the forward of GRU in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py;
        # the only change is self._flat_weights -> self.get_flat_weights().
        # TODO: maybe we can try inheriting from that class and defining
        # get_flat_weights as a @property? That might interfere with TorchScript;
        # if we remove that requirement in the future we should be able to do this.
        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = int(batch_sizes[0])
        else:
            batch_sizes = None
            assert input.dim() in (
                2,
                3,
            ), f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                input = input.unsqueeze(batch_dim)
                if hx is not None:
                    if hx.dim() != 2:
                        raise RuntimeError(
                            f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor"
                        )
                    hx = hx.unsqueeze(1)
            else:
                if hx is not None and hx.dim() != 3:
                    raise RuntimeError(
                        f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor"
                    )
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = torch.zeros(
                self.num_layers * num_directions,
                max_batch_size,
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes they are passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            result = _VF.gru(
                input,
                hx,
                self.get_flat_weights(),
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
                self.batch_first,
            )
        else:
            result = _VF.gru(
                input,
                batch_sizes,
                hx,
                self.get_flat_weights(),
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
            )
        output = result[0]
        hidden = result[1]

        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(
                output, batch_sizes, sorted_indices, unsorted_indices
            )
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            if not is_batched:  # type: ignore[possibly-undefined]
                output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
                hidden = hidden.squeeze(1)

            return output, self.permute_hidden(hidden, unsorted_indices)

    def _get_name(self):
        return "QuantizedGRU(Reference)"

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.num_layers,
            mod.bias,
            mod.batch_first,
            mod.dropout,
            mod.bidirectional,
            weight_qparams_dict=weight_qparams_dict,
        )
        for wn in mod._flat_weights_names:
            setattr(ref_mod, wn, getattr(mod, wn))
        return ref_mod