1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3import numbers
4import warnings
5from typing_extensions import deprecated
6
7import torch
8import torch.nn as nn
9from torch import Tensor  # noqa: F401
10from torch._jit_internal import Dict, List, Optional, Tuple, Union  # noqa: F401
11from torch.ao.nn.quantized.modules.utils import _quantize_weight
12from torch.nn.utils.rnn import PackedSequence
13
14
15__all__ = [
16    "pack_weight_bias",
17    "PackedParameter",
18    "RNNBase",
19    "LSTM",
20    "GRU",
21    "RNNCellBase",
22    "RNNCell",
23    "LSTMCell",
24    "GRUCell",
25    "apply_permutation",
26]
27
28
29def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
30    return tensor.index_select(dim, permutation)
31
32
33@deprecated(
34    "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead",
35    category=FutureWarning,
36)
37def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
38    return _apply_permutation(tensor, permutation, dim)
39
40
41def pack_weight_bias(qweight, bias, dtype):
42    if dtype == torch.qint8:
43        # for each layer, for each direction we need to quantize and pack
44        # weights and pack parameters in this order:
45        #
46        #   w_ih, w_hh
47        packed_weight = torch.ops.quantized.linear_prepack(qweight, bias)
48
49        return packed_weight
50    else:
51        # for each layer, for each direction we need to quantize and pack
52        # weights and pack parameters in this order:
53        #
54        #   packed_ih, packed_hh, b_ih, b_hh
55        packed_weight = torch.ops.quantized.linear_prepack_fp16(qweight, bias)
56
57        return packed_weight
58
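# Usage sketch for `pack_weight_bias` (kept as a comment so importing this module has no
# side effects; the shapes, scale and zero_point below are illustrative assumptions, not
# values used elsewhere in this file):
#
#     w = torch.randn(12, 4)                       # (out_features, in_features)
#     b = torch.zeros(12)
#     # qint8 path: expects an already-quantized weight, packs via quantized.linear_prepack
#     qw = torch.quantize_per_tensor(w, scale=0.1, zero_point=0, dtype=torch.qint8)
#     packed_q = pack_weight_bias(qw, b, torch.qint8)
#     # any other dtype takes the fp16 path, which expects a float weight
#     packed_fp16 = pack_weight_bias(w, b, torch.float16)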
59
60class PackedParameter(torch.nn.Module):
61    def __init__(self, param):
62        super().__init__()
63        self.param = param
64
65    def _save_to_state_dict(self, destination, prefix, keep_vars):
66        super()._save_to_state_dict(destination, prefix, keep_vars)
67        destination[prefix + "param"] = self.param
68
69    def _load_from_state_dict(
70        self,
71        state_dict,
72        prefix,
73        local_metadata,
74        strict,
75        missing_keys,
76        unexpected_keys,
77        error_msgs,
78    ):
79        self.param = state_dict[prefix + "param"]
80        super()._load_from_state_dict(
81            state_dict,
82            prefix,
83            local_metadata,
84            False,
85            missing_keys,
86            unexpected_keys,
87            error_msgs,
88        )
89
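# Note: `PackedParameter` wraps the opaque TorchBind cell-params object returned by the
# quantized prepack ops. It is a Module rather than a Parameter/Buffer so the packed
# weights can still travel through state_dict via the custom _save_to_state_dict /
# _load_from_state_dict hooks above.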
90
91class RNNBase(torch.nn.Module):
92    _FLOAT_MODULE = nn.RNNBase
93
94    _version = 2
95
96    def __init__(
97        self,
98        mode,
99        input_size,
100        hidden_size,
101        num_layers=1,
102        bias=True,
103        batch_first=False,
104        dropout=0.0,
105        bidirectional=False,
106        dtype=torch.qint8,
107    ):
108        super().__init__()
109
110        self.mode = mode
111        self.input_size = input_size
112        self.hidden_size = hidden_size
113        self.num_layers = num_layers
114        self.bias = bias
115        self.batch_first = batch_first
116        self.dropout = float(dropout)
117        self.bidirectional = bidirectional
118        self.dtype = dtype
119        self.version = 2
120        self.training = False
121        num_directions = 2 if bidirectional else 1
122
123        # "type: ignore" is required since ints and Numbers are not fully comparable
124        # https://github.com/python/mypy/issues/8566
125        if (
126            not isinstance(dropout, numbers.Number)
127            or not 0 <= dropout <= 1  # type: ignore[operator]
128            or isinstance(dropout, bool)
129        ):
130            raise ValueError(
131                "dropout should be a number in range [0, 1] "
132                "representing the probability of an element being "
133                "zeroed"
134            )
135        if dropout > 0 and num_layers == 1:  # type: ignore[operator]
136            warnings.warn(
137                "dropout option adds dropout after all but last "
138                "recurrent layer, so non-zero dropout expects "
139                f"num_layers greater than 1, but got dropout={dropout} and "
140                f"num_layers={num_layers}"
141            )
142
143        if mode == "LSTM":
144            gate_size = 4 * hidden_size
145        elif mode == "GRU":
146            gate_size = 3 * hidden_size
147        else:
148            raise ValueError("Unrecognized RNN mode: " + mode)
149
150        _all_weight_values = []
151        for layer in range(num_layers):
152            for direction in range(num_directions):
153                layer_input_size = (
154                    input_size if layer == 0 else hidden_size * num_directions
155                )
156
157                w_ih = torch.randn(gate_size, layer_input_size).to(torch.float)
158                w_hh = torch.randn(gate_size, hidden_size).to(torch.float)
159                b_ih = torch.randn(gate_size).to(torch.float)
160                b_hh = torch.randn(gate_size).to(torch.float)
161                if dtype == torch.qint8:
162                    w_ih = torch.quantize_per_tensor(
163                        w_ih, scale=0.1, zero_point=0, dtype=torch.qint8
164                    )
165                    w_hh = torch.quantize_per_tensor(
166                        w_hh, scale=0.1, zero_point=0, dtype=torch.qint8
167                    )
168                    packed_ih = torch.ops.quantized.linear_prepack(w_ih, b_ih)
169                    packed_hh = torch.ops.quantized.linear_prepack(w_hh, b_hh)
170                    if self.version is None or self.version < 2:
171                        cell_params = (
172                            torch.ops.quantized.make_quantized_cell_params_dynamic(
173                                packed_ih, packed_hh, b_ih, b_hh
174                            )
175                        )
176                    else:
177                        cell_params = (
178                            torch.ops.quantized.make_quantized_cell_params_dynamic(
179                                packed_ih, packed_hh, b_ih, b_hh, True
180                            )
181                        )
182                else:
183                    packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih)
184                    packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh)
185                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
186                        packed_ih, packed_hh
187                    )
188
189                _all_weight_values.append(PackedParameter(cell_params))
190        self._all_weight_values = torch.nn.ModuleList(_all_weight_values)
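    # Note: the loop above fills `_all_weight_values` with randomly initialized,
    # prepacked placeholder weights (the scale/zero_point used are arbitrary). Real
    # weights are installed later by `from_float` / `from_reference` / `set_weight_bias`
    # or by loading a state_dict.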
191
192    def _get_name(self):
193        return "DynamicQuantizedRNN"
194
195    def extra_repr(self):
196        s = "{input_size}, {hidden_size}"
197        if self.num_layers != 1:
198            s += ", num_layers={num_layers}"
199        if self.bias is not True:
200            s += ", bias={bias}"
201        if self.batch_first is not False:
202            s += ", batch_first={batch_first}"
203        if self.dropout != 0:
204            s += ", dropout={dropout}"
205        if self.bidirectional is not False:
206            s += ", bidirectional={bidirectional}"
207        return s.format(**self.__dict__)
208
209    def __repr__(self):
210        # We don't want to show `ModuleList` children, hence custom
211        # `__repr__`. This is the same as nn.Module.__repr__, except the check
212        # for the `PackedParameter` and `nn.ModuleList`.
213        # You should still override `extra_repr` to add more info.
214        extra_lines = []
215        extra_repr = self.extra_repr()
216        # empty string will be split into list ['']
217        if extra_repr:
218            extra_lines = extra_repr.split("\n")
219        child_lines = []
220        for key, module in self._modules.items():
221            if isinstance(module, (PackedParameter, nn.ModuleList)):
222                continue
223            mod_str = repr(module)
224            mod_str = nn.modules.module._addindent(mod_str, 2)
225            child_lines.append("(" + key + "): " + mod_str)
226        lines = extra_lines + child_lines
227
228        main_str = self._get_name() + "("
229        if lines:
230            # simple one-liner info, which most builtin Modules will use
231            if len(extra_lines) == 1 and not child_lines:
232                main_str += extra_lines[0]
233            else:
234                main_str += "\n  " + "\n  ".join(lines) + "\n"
235
236        main_str += ")"
237        return main_str
238
239    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
240        expected_input_dim = 2 if batch_sizes is not None else 3
241        if input.dim() != expected_input_dim:
242            raise RuntimeError(
243                f"input must have {expected_input_dim} dimensions, got {input.dim()}"
244            )
245        if self.input_size != input.size(-1):
246            raise RuntimeError(
247                f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}"
248            )
249
250    def get_expected_hidden_size(
251        self, input: Tensor, batch_sizes: Optional[Tensor]
252    ) -> Tuple[int, int, int]:
253        if batch_sizes is not None:
254            mini_batch = int(batch_sizes[0])
255        else:
256            mini_batch = input.size(0) if self.batch_first else input.size(1)
257        num_directions = 2 if self.bidirectional else 1
258        expected_hidden_size = (
259            self.num_layers * num_directions,
260            mini_batch,
261            self.hidden_size,
262        )
263        return expected_hidden_size
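    # Worked example (illustrative): with num_layers=2, bidirectional=True, a mini-batch
    # of 3 and hidden_size=20, the expected hidden state shape is (2 * 2, 3, 20).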
264
265    def check_hidden_size(
266        self,
267        hx: Tensor,
268        expected_hidden_size: Tuple[int, int, int],
269        msg: str = "Expected hidden size {}, got {}",
270    ) -> None:
271        if hx.size() != expected_hidden_size:
272            raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
273
274    def check_forward_args(
275        self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]
276    ) -> None:
277        self.check_input(input, batch_sizes)
278        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
279        self.check_hidden_size(
280            hidden, expected_hidden_size, msg="Expected hidden size {}, got {}"
281        )
282
283    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
284        if permutation is None:
285            return hx
286        return _apply_permutation(hx, permutation)
287
288    def _load_from_state_dict(
289        self,
290        state_dict,
291        prefix,
292        local_metadata,
293        strict,
294        missing_keys,
295        unexpected_keys,
296        error_msgs,
297    ):
298        version = local_metadata.get("version", None)
299        self.version = version
300        super()._load_from_state_dict(
301            state_dict,
302            prefix,
303            local_metadata,
304            False,
305            missing_keys,
306            unexpected_keys,
307            error_msgs,
308        )
309
310    def set_weight_bias(self, weight_bias_dict):
311        def weight_bias_name(ihhh, layer, suffix):
312            weight_name = f"weight_{ihhh}_l{layer}{suffix}"
313            bias_name = f"bias_{ihhh}_l{layer}{suffix}"
314            return weight_name, bias_name
315
316        num_directions = 2 if self.bidirectional else 1
317        # TODO: dedup with __init__ of RNNBase
318        _all_weight_values = []
319        for layer in range(self.num_layers):
320            for direction in range(num_directions):
321                suffix = "_reverse" if direction == 1 else ""
322                w_ih_name, b_ih_name = weight_bias_name("ih", layer, suffix)
323                w_hh_name, b_hh_name = weight_bias_name("hh", layer, suffix)
324                w_ih = weight_bias_dict[w_ih_name]
325                b_ih = weight_bias_dict[b_ih_name]
326                w_hh = weight_bias_dict[w_hh_name]
327                b_hh = weight_bias_dict[b_hh_name]
328                if w_ih.dtype == torch.qint8:
329                    packed_ih = torch.ops.quantized.linear_prepack(w_ih, b_ih)
330                    packed_hh = torch.ops.quantized.linear_prepack(w_hh, b_hh)
331                    if self.version is None or self.version < 2:
332                        cell_params = (
333                            torch.ops.quantized.make_quantized_cell_params_dynamic(
334                                packed_ih, packed_hh, b_ih, b_hh
335                            )
336                        )
337                    else:
338                        cell_params = (
339                            torch.ops.quantized.make_quantized_cell_params_dynamic(
340                                packed_ih, packed_hh, b_ih, b_hh, True
341                            )
342                        )
343                else:
344                    packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih)
345                    packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh)
346                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
347                        packed_ih, packed_hh
348                    )
349
350                _all_weight_values.append(PackedParameter(cell_params))
351        self._all_weight_values = torch.nn.ModuleList(_all_weight_values)
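    # The dict expected by `set_weight_bias` uses the usual RNN parameter names, e.g. for
    # a single-layer unidirectional module (sketch; tensor values elided):
    #
    #     {"weight_ih_l0": ..., "bias_ih_l0": ..., "weight_hh_l0": ..., "bias_hh_l0": ...}
    #
    # Reverse-direction entries take a "_reverse" suffix (e.g. "weight_ih_l0_reverse").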
352
353    @classmethod
354    def from_float(cls, mod, use_precomputed_fake_quant=False):
355        assert type(mod) in {
356            torch.nn.LSTM,
357            torch.nn.GRU,
358        }, "nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU"
359        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
360
361        if mod.qconfig is not None and mod.qconfig.weight is not None:
362            weight_observer_method = mod.qconfig.weight
363        else:
364            # We have circular import issues if we import the qconfig at the beginning of this file:
365            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
366            # import until we need it.
367            from torch.ao.quantization.qconfig import default_dynamic_qconfig
368
369            weight_observer_method = default_dynamic_qconfig.weight
370
371        dtype = weight_observer_method().dtype
372        supported_scalar_types = [torch.qint8, torch.float16]
373        if dtype not in supported_scalar_types:
374            raise RuntimeError(
375                f"Unsupported dtype for dynamic RNN quantization: {dtype}"
376            )
377        # RNNBase can be either LSTM or GRU
378        qRNNBase: Union[LSTM, GRU]
379        if mod.mode == "LSTM":
380            qRNNBase = LSTM(
381                mod.input_size,
382                mod.hidden_size,
383                mod.num_layers,
384                mod.bias,
385                mod.batch_first,
386                mod.dropout,
387                mod.bidirectional,
388                dtype,
389            )
390        elif mod.mode == "GRU":
391            qRNNBase = GRU(
392                mod.input_size,
393                mod.hidden_size,
394                mod.num_layers,
395                mod.bias,
396                mod.batch_first,
397                mod.dropout,
398                mod.bidirectional,
399                dtype,
400            )
401        else:
402            raise NotImplementedError(
403                "Only LSTM/GRU is supported for QuantizedRNN for now"
404            )
405
406        num_directions = 2 if mod.bidirectional else 1
407
408        assert mod.bias
409
410        _all_weight_values = []
411        for layer in range(qRNNBase.num_layers):
412            for direction in range(num_directions):
413                suffix = "_reverse" if direction == 1 else ""
414
415                def retrieve_weight_bias(ihhh):
416                    weight_name = f"weight_{ihhh}_l{layer}{suffix}"
417                    bias_name = f"bias_{ihhh}_l{layer}{suffix}"
418                    weight = getattr(mod, weight_name)
419                    bias = getattr(mod, bias_name)
420                    return weight, bias
421
422                weight_ih, bias_ih = retrieve_weight_bias("ih")
423                weight_hh, bias_hh = retrieve_weight_bias("hh")
424
425                if dtype == torch.qint8:
426
427                    def quantize_and_pack(w, b):
428                        weight_observer = weight_observer_method()
429                        weight_observer(w)
430                        qweight = _quantize_weight(w.float(), weight_observer)
431                        packed_weight = torch.ops.quantized.linear_prepack(qweight, b)
432                        return packed_weight
433
434                    packed_ih = quantize_and_pack(weight_ih, bias_ih)
435                    packed_hh = quantize_and_pack(weight_hh, bias_hh)
436                    if qRNNBase.version is None or qRNNBase.version < 2:
437                        cell_params = (
438                            torch.ops.quantized.make_quantized_cell_params_dynamic(
439                                packed_ih, packed_hh, bias_ih, bias_hh
440                            )
441                        )
442                    else:
443                        cell_params = (
444                            torch.ops.quantized.make_quantized_cell_params_dynamic(
445                                packed_ih, packed_hh, bias_ih, bias_hh, True
446                            )
447                        )
448
449                elif dtype == torch.float16:
450                    packed_ih = torch.ops.quantized.linear_prepack_fp16(
451                        weight_ih.float(), bias_ih
452                    )
453                    packed_hh = torch.ops.quantized.linear_prepack_fp16(
454                        weight_hh.float(), bias_hh
455                    )
456
457                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
458                        packed_ih, packed_hh
459                    )
460                else:
461                    raise RuntimeError(
462                        "Unsupported dtype specified for dynamic quantized LSTM!"
463                    )
464
465                _all_weight_values.append(PackedParameter(cell_params))
466        qRNNBase._all_weight_values = torch.nn.ModuleList(_all_weight_values)
467
468        return qRNNBase
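    # Typical entry point (a sketch, not the only route): rather than calling `from_float`
    # directly, dynamic quantization is usually driven by
    # torch.ao.quantization.quantize_dynamic, which attaches a qconfig and swaps supported
    # submodules for their dynamic quantized counterparts. `M` below is a hypothetical
    # container module used only for illustration:
    #
    #     import torch
    #
    #     class M(torch.nn.Module):
    #         def __init__(self):
    #             super().__init__()
    #             self.lstm = torch.nn.LSTM(10, 20, 2)
    #
    #         def forward(self, x):
    #             return self.lstm(x)
    #
    #     qmodel = torch.ao.quantization.quantize_dynamic(M(), {torch.nn.LSTM}, dtype=torch.qint8)
    #     # qmodel.lstm is now a DynamicQuantizedLSTM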
469
470    def _weight_bias(self):
471        # Returns a dict of weights and biases
472        weight_bias_dict: Dict[str, Dict] = {"weight": {}, "bias": {}}
473        count = 0
474        num_directions = 2 if self.bidirectional else 1
475        for layer in range(self.num_layers):
476            for direction in range(num_directions):
477                suffix = "_reverse" if direction == 1 else ""
478                key_name1 = f"weight_ih_l{layer}{suffix}"
479                key_name2 = f"weight_hh_l{layer}{suffix}"
480                # packed weights are part of torchbind class, CellParamsSerializationType
481                # Within the packed weight class, the weight and bias are accessible as Tensors
482                packed_weight_bias = self._all_weight_values[
483                    count
484                ].param.__getstate__()[0][4]
485                weight_bias_dict["weight"][key_name1] = packed_weight_bias[
486                    0
487                ].__getstate__()[0][0]
488                weight_bias_dict["weight"][key_name2] = packed_weight_bias[
489                    1
490                ].__getstate__()[0][0]
491                key_name1 = f"bias_ih_l{layer}{suffix}"
492                key_name2 = f"bias_hh_l{layer}{suffix}"
493                weight_bias_dict["bias"][key_name1] = packed_weight_bias[
494                    0
495                ].__getstate__()[0][1]
496                weight_bias_dict["bias"][key_name2] = packed_weight_bias[
497                    1
498                ].__getstate__()[0][1]
499                count = count + 1
500        return weight_bias_dict
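    # Note on the indexing above: `param.__getstate__()[0]` is the serialized state of the
    # TorchBind CellParams object; element [4] of that state holds the list of packed
    # linear params ([0] = input-hidden, [1] = hidden-hidden), and each of those unpacks
    # via __getstate__()[0] into (weight, bias). See CellParamsSerializationType on the
    # C++ side for the layout.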
501
502    def get_weight(self):
503        return self._weight_bias()["weight"]
504
505    def get_bias(self):
506        return self._weight_bias()["bias"]
507
508
509class LSTM(RNNBase):
510    r"""
511    A dynamic quantized LSTM module with floating point tensors as inputs and outputs.
512    We adopt the same interface as `torch.nn.LSTM`; please see
513    https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM for documentation.
514
515    Examples::
516
517        >>> # xdoctest: +SKIP
518        >>> rnn = nn.LSTM(10, 20, 2)
519        >>> input = torch.randn(5, 3, 10)
520        >>> h0 = torch.randn(2, 3, 20)
521        >>> c0 = torch.randn(2, 3, 20)
522        >>> output, (hn, cn) = rnn(input, (h0, c0))
523    """
524    _FLOAT_MODULE = nn.LSTM
525
526    __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
527
528    def __init__(self, *args, **kwargs):
529        super().__init__("LSTM", *args, **kwargs)
530
531    def _get_name(self):
532        return "DynamicQuantizedLSTM"
533
534    def forward_impl(
535        self,
536        input: Tensor,
537        hx: Optional[Tuple[Tensor, Tensor]],
538        batch_sizes: Optional[Tensor],
539        max_batch_size: int,
540        sorted_indices: Optional[Tensor],
541    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
542        if hx is None:
543            num_directions = 2 if self.bidirectional else 1
544            zeros = torch.zeros(
545                self.num_layers * num_directions,
546                max_batch_size,
547                self.hidden_size,
548                dtype=input.dtype,
549                device=input.device,
550            )
551            hx = (zeros, zeros)
552        else:
553            # Each batch of the hidden state should match the input sequence that
554            # the user believes they are passing in.
555            hx = self.permute_hidden(hx, sorted_indices)
556
557        self.check_forward_args(input, hx, batch_sizes)
558
559        _all_params = [m.param for m in self._all_weight_values]
560        if batch_sizes is None:
561            result = torch.quantized_lstm(
562                input,
563                hx,
564                _all_params,
565                self.bias,
566                self.num_layers,
567                float(self.dropout),
568                self.training,
569                self.bidirectional,
570                self.batch_first,
571                dtype=self.dtype,
572                use_dynamic=True,
573            )
574        else:
575            result = torch.quantized_lstm(
576                input,
577                batch_sizes,
578                hx,
579                _all_params,
580                self.bias,
581                self.num_layers,
582                float(self.dropout),
583                self.training,
584                self.bidirectional,
585                dtype=self.dtype,
586                use_dynamic=True,
587            )
588        output = result[0]
589        hidden = result[1:]
590
591        return output, hidden
592
593    @torch.jit.export
594    def forward_tensor(
595        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
596    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
597        batch_sizes = None
598        max_batch_size = input.size(0) if self.batch_first else input.size(1)
599        sorted_indices = None
600        unsorted_indices = None
601
602        output, hidden = self.forward_impl(
603            input, hx, batch_sizes, max_batch_size, sorted_indices
604        )
605
606        return output, self.permute_hidden(hidden, unsorted_indices)
607
608    @torch.jit.export
609    def forward_packed(
610        self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
611    ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:
612        input_, batch_sizes, sorted_indices, unsorted_indices = input
613        max_batch_size = int(batch_sizes[0])
614
615        output_, hidden = self.forward_impl(
616            input_, hx, batch_sizes, max_batch_size, sorted_indices
617        )
618
619        output = PackedSequence(output_, batch_sizes, sorted_indices, unsorted_indices)
620        return output, self.permute_hidden(hidden, unsorted_indices)
621
622    # "type: ignore" is required due to issue #43072
623    def permute_hidden(  # type: ignore[override]
624        self,
625        hx: Tuple[Tensor, Tensor],
626        permutation: Optional[Tensor],
627    ) -> Tuple[Tensor, Tensor]:
628        if permutation is None:
629            return hx
630        return _apply_permutation(hx[0], permutation), _apply_permutation(
631            hx[1], permutation
632        )
633
634    # "type: ignore" is required due to issue #43072
635    def check_forward_args(  # type: ignore[override]
636        self,
637        input: Tensor,
638        hidden: Tuple[Tensor, Tensor],
639        batch_sizes: Optional[Tensor],
640    ) -> None:
641        self.check_input(input, batch_sizes)
642        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
643
644        self.check_hidden_size(
645            hidden[0], expected_hidden_size, "Expected hidden[0] size {}, got {}"
646        )
647        self.check_hidden_size(
648            hidden[1], expected_hidden_size, "Expected hidden[1] size {}, got {}"
649        )
650
651    @torch.jit.ignore
652    def forward(self, input, hx=None):
653        if isinstance(input, PackedSequence):
654            return self.forward_packed(input, hx)
655        else:
656            return self.forward_tensor(input, hx)
657
658    @classmethod
659    def from_float(cls, mod, use_precomputed_fake_quant=False):
660        return super().from_float(
661            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
662        )
663
664    @classmethod
665    def from_reference(cls, ref_mod):
666        assert hasattr(ref_mod, "weight_ih_l0_dtype"), (
667            "We are assuming weight_ih_l0 exists in LSTM; may need to relax the assumption to support the use case")
668        qmod = cls(
669            ref_mod.input_size,
670            ref_mod.hidden_size,
671            ref_mod.num_layers,
672            ref_mod.bias,
673            ref_mod.batch_first,
674            ref_mod.dropout,
675            ref_mod.bidirectional,
676            # assuming there is layer 0, which should be OK
677            ref_mod.weight_ih_l0_dtype,
678        )
679        qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict())
680        return qmod
681
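# Variable-length sequences: `DynamicQuantizedLSTM.forward` dispatches on the input type,
# so packed inputs work just as with the float nn.LSTM. A sketch, assuming `qlstm` is a
# converted DynamicQuantizedLSTM with input_size=10 and hidden_size=20 (sizes below are
# illustrative):
#
#     import torch
#     from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
#
#     padded = torch.randn(5, 3, 10)          # (seq_len, batch, input_size)
#     lengths = torch.tensor([5, 4, 2])
#     packed = pack_padded_sequence(padded, lengths)
#     out_packed, (hn, cn) = qlstm(packed)    # routed to forward_packed
#     out, out_lengths = pad_packed_sequence(out_packed)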
682
683class GRU(RNNBase):
684    r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
685
686
687    For each element in the input sequence, each layer computes the following
688    function:
689
690    .. math::
691        \begin{array}{ll}
692            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
693            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
694            n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\
695            h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)}
696        \end{array}
697
698    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
699    at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
700    at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
701    :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
702    :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
703
704    In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
705    (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
706    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
707    variable which is :math:`0` with probability :attr:`dropout`.
708
709    Args:
710        input_size: The number of expected features in the input `x`
711        hidden_size: The number of features in the hidden state `h`
712        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
713            would mean stacking two GRUs together to form a `stacked GRU`,
714            with the second GRU taking in outputs of the first GRU and
715            computing the final results. Default: 1
716        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
717            Default: ``True``
718        batch_first: If ``True``, then the input and output tensors are provided
719            as (batch, seq, feature). Default: ``False``
720        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
721            GRU layer except the last layer, with dropout probability equal to
722            :attr:`dropout`. Default: 0
723        bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``
724
725    Inputs: input, h_0
726        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
727          of the input sequence. The input can also be a packed variable length
728          sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
729          for details.
730        - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
731          containing the initial hidden state for each element in the batch.
732          Defaults to zero if not provided. If the RNN is bidirectional,
733          num_directions should be 2, else it should be 1.
734
735    Outputs: output, h_n
736        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
737          containing the output features h_t from the last layer of the GRU,
738          for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
739          given as the input, the output will also be a packed sequence.
740          For the unpacked case, the directions can be separated
741          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
742          with forward and backward being direction `0` and `1` respectively.
743
744          Similarly, the directions can be separated in the packed case.
745        - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
746          containing the hidden state for `t = seq_len`
747
748          Like *output*, the layers can be separated using
749          ``h_n.view(num_layers, num_directions, batch, hidden_size)``.
750
751    Shape:
752        - Input1: :math:`(L, N, H_{in})` tensor containing input features where
753          :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length.
754        - Input2: :math:`(S, N, H_{out})` tensor
755          containing the initial hidden state for each element in the batch, where
756          :math:`S=\text{num\_layers} * \text{num\_directions}` and :math:`H_{out}=\text{hidden\_size}`.
757          Defaults to zero if not provided.
758          If the RNN is bidirectional, num_directions should be 2, else it should be 1.
759        - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}`
760        - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state
761          for each element in the batch
762
763    Attributes:
764        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
765            (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
766            Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
767        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
768            (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
769        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
770            (b_ir|b_iz|b_in), of shape `(3*hidden_size)`
771        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
772            (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`
773
774    .. note::
775        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
776        where :math:`k = \frac{1}{\text{hidden\_size}}`
777
778    .. note::
779        The calculation of new gate :math:`n_t` subtly differs from the original paper and other frameworks.
780        In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the
781        previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix
782        `W` and addition of bias:
783
784        .. math::
785            \begin{aligned}
786                n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn})
787            \end{aligned}
788
789        This is in contrast to the PyTorch implementation, which is done after :math:`W_{hn} h_{(t-1)}`
790
791        .. math::
792            \begin{aligned}
793                n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn}))
794            \end{aligned}
795
796        This implementation differs on purpose for efficiency.
797
798    .. include:: ../cudnn_persistent_rnn.rst
799
800    Examples::
801
802        >>> # xdoctest: +SKIP
803        >>> rnn = nn.GRU(10, 20, 2)
804        >>> input = torch.randn(5, 3, 10)
805        >>> h0 = torch.randn(2, 3, 20)
806        >>> output, hn = rnn(input, h0)
807    """
808    _FLOAT_MODULE = nn.GRU
809
810    __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
811
812    def __init__(self, *args, **kwargs):
813        super().__init__("GRU", *args, **kwargs)
814
815    def _get_name(self):
816        return "DynamicQuantizedGRU"
817
818    def check_forward_args(
819        self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]
820    ) -> None:
821        self.check_input(input, batch_sizes)
822        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
823
824        self.check_hidden_size(
825            hidden, expected_hidden_size, "Expected hidden size {}, got {}"
826        )
827
828    def forward_impl(
829        self,
830        input: Tensor,
831        hx: Optional[Tensor],
832        batch_sizes: Optional[Tensor],
833        max_batch_size: int,
834        sorted_indices: Optional[Tensor],
835    ) -> Tuple[Tensor, Tensor]:
836        if hx is None:
837            num_directions = 2 if self.bidirectional else 1
838            zeros = torch.zeros(
839                self.num_layers * num_directions,
840                max_batch_size,
841                self.hidden_size,
842                dtype=input.dtype,
843                device=input.device,
844            )
845            hx = zeros
846        else:
847            # Each batch of the hidden state should match the input sequence that
848            # the user believes they are passing in.
849            hx = self.permute_hidden(hx, sorted_indices)
850
851        self.check_forward_args(input, hx, batch_sizes)
852
853        _all_params = [m.param for m in self._all_weight_values]
854        if batch_sizes is None:
855            result = torch.quantized_gru(
856                input,
857                hx,
858                _all_params,
859                self.bias,
860                self.num_layers,
861                self.dropout,
862                self.training,
863                self.bidirectional,
864                self.batch_first,
865            )
866        else:
867            result = torch.quantized_gru(
868                input,
869                batch_sizes,
870                hx,
871                _all_params,
872                self.bias,
873                self.num_layers,
874                self.dropout,
875                self.training,
876                self.bidirectional,
877            )
878        output = result[0]
879        hidden = result[1]
880
881        return output, hidden
882
883    @torch.jit.export
884    def forward_tensor(
885        self, input: Tensor, hx: Optional[Tensor] = None
886    ) -> Tuple[Tensor, Tensor]:
887        batch_sizes = None
888        max_batch_size = input.size(0) if self.batch_first else input.size(1)
889        sorted_indices = None
890        unsorted_indices = None
891
892        output, hidden = self.forward_impl(
893            input, hx, batch_sizes, max_batch_size, sorted_indices
894        )
895
896        return output, self.permute_hidden(hidden, unsorted_indices)
897
898    @torch.jit.export
899    def forward_packed(
900        self, input: PackedSequence, hx: Optional[Tensor] = None
901    ) -> Tuple[PackedSequence, Tensor]:
902        input_, batch_sizes, sorted_indices, unsorted_indices = input
903        max_batch_size = int(batch_sizes[0])
904        output_, hidden = self.forward_impl(
905            input_, hx, batch_sizes, max_batch_size, sorted_indices
906        )
907
908        output = PackedSequence(output_, batch_sizes, sorted_indices, unsorted_indices)
909        return output, self.permute_hidden(hidden, unsorted_indices)
910
911    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
912        if permutation is None:
913            return hx
914        return _apply_permutation(hx, permutation)
915
916    @torch.jit.ignore
917    def forward(self, input, hx=None):
918        if isinstance(input, PackedSequence):
919            return self.forward_packed(input, hx)
920        else:
921            return self.forward_tensor(input, hx)
922
923    @classmethod
924    def from_float(cls, mod, use_precomputed_fake_quant=False):
925        return super().from_float(
926            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
927        )
928
929    @classmethod
930    def from_reference(cls, ref_mod):
931        assert hasattr(ref_mod, "weight_ih_l0_dtype"), (
932            "We are assuming weight_ih_l0 exists in GRU; may need to relax the assumption to support the use case")
933        qmod = cls(
934            ref_mod.input_size,
935            ref_mod.hidden_size,
936            ref_mod.num_layers,
937            ref_mod.bias,
938            ref_mod.batch_first,
939            ref_mod.dropout,
940            ref_mod.bidirectional,
941            # assuming there is layer 0, which should be OK
942            ref_mod.weight_ih_l0_dtype,
943        )
944        qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict())
945        return qmod
946
947
948class RNNCellBase(torch.nn.Module):
949    # _FLOAT_MODULE = nn.CellRNNBase
950    __constants__ = ["input_size", "hidden_size", "bias"]
951
952    def __init__(
953        self, input_size, hidden_size, bias=True, num_chunks=4, dtype=torch.qint8
954    ):
955        super().__init__()
956        self.input_size = input_size
957        self.hidden_size = hidden_size
958        self.bias = bias
959        self.weight_dtype = dtype
960        if bias:
961            self.bias_ih = torch.randn(num_chunks * hidden_size).to(dtype=torch.float)
962            self.bias_hh = torch.randn(num_chunks * hidden_size).to(dtype=torch.float)
963        else:
964            self.register_parameter("bias_ih", None)
965            self.register_parameter("bias_hh", None)
966
967        weight_ih = torch.randn(num_chunks * hidden_size, input_size).to(torch.float)
968        weight_hh = torch.randn(num_chunks * hidden_size, hidden_size).to(torch.float)
969        if dtype == torch.qint8:
970            weight_ih = torch.quantize_per_tensor(
971                weight_ih, scale=1, zero_point=0, dtype=torch.qint8
972            )
973            weight_hh = torch.quantize_per_tensor(
974                weight_hh, scale=1, zero_point=0, dtype=torch.qint8
975            )
976
977        if dtype == torch.qint8:
978            # for each layer, for each direction we need to quantize and pack
979            # weights and pack parameters in this order:
980            #
981            #   w_ih, w_hh
982            packed_weight_ih = torch.ops.quantized.linear_prepack(
983                weight_ih, self.bias_ih
984            )
985            packed_weight_hh = torch.ops.quantized.linear_prepack(
986                weight_hh, self.bias_hh
987            )
988        else:
989            # for each layer, for each direction we need to quantize and pack
990            # weights and pack parameters in this order:
991            #
992            #   packed_ih, packed_hh, b_ih, b_hh
993            packed_weight_ih = torch.ops.quantized.linear_prepack_fp16(
994                weight_ih, self.bias_ih
995            )
996            packed_weight_hh = torch.ops.quantized.linear_prepack_fp16(
997                weight_hh, self.bias_hh
998            )
999
1000        self._packed_weight_ih = packed_weight_ih
1001        self._packed_weight_hh = packed_weight_hh
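    # Note: as in RNNBase.__init__, the weights packed above are random placeholders that
    # get overwritten by from_float / from_reference / set_weight_bias. `num_chunks` is
    # the number of gates stacked in each weight matrix: 4 for LSTMCell, 3 for GRUCell
    # and 1 for RNNCell (see the subclasses below).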
1002
1003    def _get_name(self):
1004        return "DynamicQuantizedRNNBase"
1005
1006    def extra_repr(self):
1007        s = "{input_size}, {hidden_size}"
1008        if "bias" in self.__dict__ and self.bias is not True:
1009            s += ", bias={bias}"
1010        if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh":
1011            s += ", nonlinearity={nonlinearity}"
1012        return s.format(**self.__dict__)
1013
1014    def check_forward_input(self, input):
1015        if input.size(1) != self.input_size:
1016            raise RuntimeError(
1017                f"input has inconsistent input_size: got {input.size(1)}, expected {self.input_size}"
1018            )
1019
1020    def check_forward_hidden(
1021        self, input: Tensor, hx: Tensor, hidden_label: str = ""
1022    ) -> None:
1023        if input.size(0) != hx.size(0):
1024            raise RuntimeError(
1025                f"Input batch size {input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}"
1026            )
1027
1028        if hx.size(1) != self.hidden_size:
1029            raise RuntimeError(
1030                f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}"
1031            )
1032
1033    @classmethod
1034    def from_float(cls, mod, use_precomputed_fake_quant=False):
1035        assert type(mod) in {
1036            torch.nn.LSTMCell,
1037            torch.nn.GRUCell,
1038            torch.nn.RNNCell,
1039        }, ("nn.quantized.dynamic.RNNCellBase.from_float only works for "
1040            "nn.LSTMCell, nn.GRUCell and nn.RNNCell")
1041        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
1042
1043        if mod.qconfig is not None and mod.qconfig.weight is not None:
1044            weight_observer_method = mod.qconfig.weight
1045        else:
1046            # We have circular import issues if we import the qconfig at the beginning of this file:
1047            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
1048            # import until we need it.
1049            from torch.ao.quantization.qconfig import default_dynamic_qconfig
1050
1051            weight_observer_method = default_dynamic_qconfig.weight
1052
1053        dtype = weight_observer_method().dtype
1054        supported_scalar_types = [torch.qint8, torch.float16]
1055        if dtype not in supported_scalar_types:
1056            raise RuntimeError(
1057                f"Unsupported dtype for dynamic RNN quantization: {dtype}"
1058            )
1059
1060        qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]
1061
1062        if type(mod) == torch.nn.LSTMCell:
1063            qRNNCellBase = LSTMCell(
1064                mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
1065            )
1066        elif type(mod) == torch.nn.GRUCell:
1067            qRNNCellBase = GRUCell(
1068                mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
1069            )
1070        elif type(mod) == torch.nn.RNNCell:
1071            qRNNCellBase = RNNCell(
1072                mod.input_size,
1073                mod.hidden_size,
1074                bias=mod.bias,
1075                nonlinearity=mod.nonlinearity,
1076                dtype=dtype,
1077            )
1078        else:
1079            raise NotImplementedError(
1080                "Only LSTMCell, GRUCell and RNNCell "
1081                "are supported for QuantizedRNN for now"
1082            )
1083
1084        assert mod.bias
1085
1086        def _observe_and_quantize_weight(weight):
1087            if dtype == torch.qint8:
1088                weight_observer = weight_observer_method()
1089                weight_observer(weight)
1090                qweight = _quantize_weight(weight.float(), weight_observer)
1091                return qweight
1092            else:
1093                return weight.float()
1094
1095        qRNNCellBase._packed_weight_ih = pack_weight_bias(
1096            _observe_and_quantize_weight(mod.weight_ih), mod.bias_ih, dtype
1097        )
1098        qRNNCellBase._packed_weight_hh = pack_weight_bias(
1099            _observe_and_quantize_weight(mod.weight_hh), mod.bias_hh, dtype
1100        )
1101        return qRNNCellBase
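    # Unlike RNNBase (which wraps TorchBind CellParams objects in PackedParameter
    # modules), the cell classes keep the prepacked ih/hh linear params directly as
    # `_packed_weight_ih` / `_packed_weight_hh`, which is why `from_float` packs with
    # `pack_weight_bias`. The `assert mod.bias` above reflects that the dynamic cell
    # kernels expect bias tensors to be present.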
1102
1103    @classmethod
1104    def from_reference(cls, ref_mod):
1105        assert hasattr(ref_mod, "weight_ih_dtype"), (
1106            "We are assuming weight_ih exists in the reference module; may need to relax the assumption to support the use case")
1107        if hasattr(ref_mod, "nonlinearity"):
1108            qmod = cls(
1109                ref_mod.input_size,
1110                ref_mod.hidden_size,
1111                ref_mod.bias,
1112                ref_mod.nonlinearity,
1113                dtype=ref_mod.weight_ih_dtype,
1114            )
1115        else:
1116            qmod = cls(
1117                ref_mod.input_size,
1118                ref_mod.hidden_size,
1119                ref_mod.bias,
1120                dtype=ref_mod.weight_ih_dtype,
1121            )
1122        weight_bias_dict = {
1123            "weight": {
1124                "weight_ih": ref_mod.get_quantized_weight_ih(),
1125                "weight_hh": ref_mod.get_quantized_weight_hh(),
1126            },
1127            "bias": {
1128                "bias_ih": ref_mod.bias_ih,
1129                "bias_hh": ref_mod.bias_hh,
1130            },
1131        }
1132        qmod.set_weight_bias(weight_bias_dict)
1133        return qmod
1134
1135    def _weight_bias(self):
1136        # Returns a dict of weights and biases
1137        weight_bias_dict: Dict[str, Dict] = {"weight": {}, "bias": {}}
1138        w1, b1 = self._packed_weight_ih.__getstate__()[0]
1139        w2, b2 = self._packed_weight_hh.__getstate__()[0]
1140        # TODO: these can be simplified to one level? e.g. using weight_ih as key
1141        # directly
1142        weight_bias_dict["weight"]["weight_ih"] = w1
1143        weight_bias_dict["weight"]["weight_hh"] = w2
1144        weight_bias_dict["bias"]["bias_ih"] = b1
1145        weight_bias_dict["bias"]["bias_hh"] = b2
1146        return weight_bias_dict
1147
1148    def get_weight(self):
1149        return self._weight_bias()["weight"]
1150
1151    def get_bias(self):
1152        return self._weight_bias()["bias"]
1153
1154    def set_weight_bias(self, weight_bias_dict):
1155        # TODO: these can be simplified to one level? e.g. using weight_ih as key
1156        # directly
1157        self._packed_weight_ih = pack_weight_bias(
1158            weight_bias_dict["weight"]["weight_ih"],
1159            weight_bias_dict["bias"]["bias_ih"],
1160            self.weight_dtype,
1161        )
1162        self._packed_weight_hh = pack_weight_bias(
1163            weight_bias_dict["weight"]["weight_hh"],
1164            weight_bias_dict["bias"]["bias_hh"],
1165            self.weight_dtype,
1166        )
1167
1168    def _save_to_state_dict(self, destination, prefix, keep_vars):
1169        super()._save_to_state_dict(destination, prefix, keep_vars)
1170        destination[prefix + "_packed_weight_ih"] = self._packed_weight_ih
1171        destination[prefix + "_packed_weight_hh"] = self._packed_weight_hh
1172
1173    def _load_from_state_dict(
1174        self,
1175        state_dict,
1176        prefix,
1177        local_metadata,
1178        strict,
1179        missing_keys,
1180        unexpected_keys,
1181        error_msgs,
1182    ):
1183        self._packed_weight_ih = state_dict.pop(prefix + "_packed_weight_ih")
1184        self._packed_weight_hh = state_dict.pop(prefix + "_packed_weight_hh")
1185        super()._load_from_state_dict(
1186            state_dict,
1187            prefix,
1188            local_metadata,
1189            False,
1190            missing_keys,
1191            unexpected_keys,
1192            error_msgs,
1193        )
1194
1195
1196class RNNCell(RNNCellBase):
1197    r"""An Elman RNN cell with tanh or ReLU non-linearity.
1198    A dynamic quantized RNNCell module with floating point tensors as inputs and outputs.
1199    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.RNNCell`;
1200    please see https://pytorch.org/docs/stable/nn.html#torch.nn.RNNCell for documentation.
1201
1202    Examples::
1203
1204        >>> # xdoctest: +SKIP
1205        >>> rnn = nn.RNNCell(10, 20)
1206        >>> input = torch.randn(6, 3, 10)
1207        >>> hx = torch.randn(3, 20)
1208        >>> output = []
1209        >>> for i in range(6):
1210        ...     hx = rnn(input[i], hx)
1211        ...     output.append(hx)
1212    """
1213    __constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"]
1214
1215    def __init__(
1216        self, input_size, hidden_size, bias=True, nonlinearity="tanh", dtype=torch.qint8
1217    ):
1218        super().__init__(input_size, hidden_size, bias, num_chunks=1, dtype=dtype)
1219        self.nonlinearity = nonlinearity
1220
1221    def _get_name(self):
1222        return "DynamicQuantizedRNNCell"
1223
1224    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
1225        self.check_forward_input(input)
1226        if hx is None:
1227            hx = torch.zeros(
1228                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
1229            )
1230        self.check_forward_hidden(input, hx, "")
1231        if self.nonlinearity == "tanh":
1232            ret = torch.ops.quantized.quantized_rnn_tanh_cell_dynamic(
1233                input,
1234                hx,
1235                self._packed_weight_ih,
1236                self._packed_weight_hh,
1237                self.bias_ih,
1238                self.bias_hh,
1239            )
1240        elif self.nonlinearity == "relu":
1241            ret = torch.ops.quantized.quantized_rnn_relu_cell_dynamic(
1242                input,
1243                hx,
1244                self._packed_weight_ih,
1245                self._packed_weight_hh,
1246                self.bias_ih,
1247                self.bias_hh,
1248            )
1249        else:
1250            ret = input  # TODO: remove when jit supports exception flow
1251            raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}")
1252        return ret
1253
1254    @classmethod
1255    def from_float(cls, mod, use_precomputed_fake_quant=False):
1256        return super().from_float(
1257            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
1258        )
1259
1260
1261class LSTMCell(RNNCellBase):
1262    r"""A long short-term memory (LSTM) cell.
1263
1264    A dynamic quantized LSTMCell module with floating point tensor as inputs and outputs.
1265    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.LSTMCell`,
1266    please see https://pytorch.org/docs/stable/nn.html#torch.nn.LSTMCell for documentation.
1267
1268    Examples::
1269
1270        >>> # xdoctest: +SKIP
1271        >>> rnn = nn.LSTMCell(10, 20)
1272        >>> input = torch.randn(6, 3, 10)
1273        >>> hx = torch.randn(3, 20)
1274        >>> cx = torch.randn(3, 20)
1275        >>> output = []
1276        >>> for i in range(6):
1277        ...     hx, cx = rnn(input[i], (hx, cx))
1278        ...     output.append(hx)
1279    """
1280
1281    def __init__(self, *args, **kwargs):
1282        super().__init__(*args, num_chunks=4, **kwargs)  # type: ignore[misc]
1283
1284    def _get_name(self):
1285        return "DynamicQuantizedLSTMCell"
1286
1287    def forward(
1288        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
1289    ) -> Tuple[Tensor, Tensor]:
1290        self.check_forward_input(input)
1291        if hx is None:
1292            zeros = torch.zeros(
1293                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
1294            )
1295            hx = (zeros, zeros)
1296        self.check_forward_hidden(input, hx[0], "[0]")
1297        self.check_forward_hidden(input, hx[1], "[1]")
1298        return torch.ops.quantized.quantized_lstm_cell_dynamic(
1299            input,
1300            hx,
1301            self._packed_weight_ih,
1302            self._packed_weight_hh,
1303            self.bias_ih,
1304            self.bias_hh,
1305        )
1306
1307    @classmethod
1308    def from_float(cls, mod, use_precomputed_fake_quant=False):
1309        return super().from_float(
1310            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
1311        )
1312
1313
1314class GRUCell(RNNCellBase):
1315    r"""A gated recurrent unit (GRU) cell
1316
1317    A dynamic quantized GRUCell module with floating point tensor as inputs and outputs.
1318    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.GRUCell`,
1319    please see https://pytorch.org/docs/stable/nn.html#torch.nn.GRUCell for documentation.
1320
1321    Examples::
1322
1323        >>> # xdoctest: +SKIP
1324        >>> rnn = nn.GRUCell(10, 20)
1325        >>> input = torch.randn(6, 3, 10)
1326        >>> hx = torch.randn(3, 20)
1327        >>> output = []
1328        >>> for i in range(6):
1329        ...     hx = rnn(input[i], hx)
1330        ...     output.append(hx)
1331    """
1332
1333    def __init__(self, input_size, hidden_size, bias=True, dtype=torch.qint8):
1334        super().__init__(input_size, hidden_size, bias, num_chunks=3, dtype=dtype)
1335
1336    def _get_name(self):
1337        return "DynamicQuantizedGRUCell"
1338
1339    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
1340        self.check_forward_input(input)
1341        if hx is None:
1342            hx = torch.zeros(
1343                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
1344            )
1345        self.check_forward_hidden(input, hx, "")
1346        return torch.ops.quantized.quantized_gru_cell_dynamic(
1347            input,
1348            hx,
1349            self._packed_weight_ih,
1350            self._packed_weight_hh,
1351            self.bias_ih,
1352            self.bias_hh,
1353        )
1354
1355    @classmethod
1356    def from_float(cls, mod, use_precomputed_fake_quant=False):
1357        return super().from_float(
1358            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
1359        )
1360