# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import numbers
import warnings
from typing_extensions import deprecated

import torch
import torch.nn as nn
from torch import Tensor  # noqa: F401
from torch._jit_internal import Dict, List, Optional, Tuple, Union  # noqa: F401
from torch.ao.nn.quantized.modules.utils import _quantize_weight
from torch.nn.utils.rnn import PackedSequence


__all__ = [
    "pack_weight_bias",
    "PackedParameter",
    "RNNBase",
    "LSTM",
    "GRU",
    "RNNCellBase",
    "RNNCell",
    "LSTMCell",
    "GRUCell",
    "apply_permutation",
]


def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return tensor.index_select(dim, permutation)


@deprecated(
    "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead",
    category=FutureWarning,
)
def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return _apply_permutation(tensor, permutation, dim)


def pack_weight_bias(qweight, bias, dtype):
    if dtype == torch.qint8:
        # for each layer, for each direction we need to quantize and pack
        # weights and pack parameters in this order:
        #
        #   w_ih, w_hh
        packed_weight = torch.ops.quantized.linear_prepack(qweight, bias)

        return packed_weight
    else:
        # for each layer, for each direction we need to quantize and pack
        # weights and pack parameters in this order:
        #
        #   packed_ih, packed_hh, b_ih, b_hh
        packed_weight = torch.ops.quantized.linear_prepack_fp16(qweight, bias)

        return packed_weight
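
# A minimal usage sketch for pack_weight_bias (hypothetical shapes, chosen only
# for illustration): the torch.qint8 path expects an already-quantized weight,
# while the torch.float16 path expects a float weight.
#
#     w = torch.quantize_per_tensor(
#         torch.randn(4 * 20, 10), scale=0.1, zero_point=0, dtype=torch.qint8
#     )
#     b = torch.zeros(4 * 20)
#     packed = pack_weight_bias(w, b, torch.qint8)  # opaque prepacked handle
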

class PackedParameter(torch.nn.Module):
    def __init__(self, param):
        super().__init__()
        self.param = param

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "param"] = self.param

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        self.param = state_dict[prefix + "param"]
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class RNNBase(torch.nn.Module):
    _FLOAT_MODULE = nn.RNNBase

    _version = 2

    def __init__(
        self,
        mode,
        input_size,
        hidden_size,
        num_layers=1,
        bias=True,
        batch_first=False,
        dropout=0.0,
        bidirectional=False,
        dtype=torch.qint8,
    ):
        super().__init__()

        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.dtype = dtype
        self.version = 2
        self.training = False
        num_directions = 2 if bidirectional else 1

        # "type: ignore" is required since ints and Numbers are not fully comparable
        # https://github.com/python/mypy/issues/8566
        if (
            not isinstance(dropout, numbers.Number)
            or not 0 <= dropout <= 1  # type: ignore[operator]
            or isinstance(dropout, bool)
        ):
            raise ValueError(
                "dropout should be a number in range [0, 1] "
                "representing the probability of an element being "
                "zeroed"
            )
        if dropout > 0 and num_layers == 1:  # type: ignore[operator]
            warnings.warn(
                "dropout option adds dropout after all but last "
                "recurrent layer, so non-zero dropout expects "
                f"num_layers greater than 1, but got dropout={dropout} and "
                f"num_layers={num_layers}"
            )

        if mode == "LSTM":
            gate_size = 4 * hidden_size
        elif mode == "GRU":
            gate_size = 3 * hidden_size
        else:
            raise ValueError("Unrecognized RNN mode: " + mode)
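
        # Placeholder parameters: the randomly initialized, coarsely quantized
        # tensors created below only give the module a valid packed-parameter
        # structure. Real weights are installed later via from_float(),
        # set_weight_bias(), or when loading a state dict.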
        _all_weight_values = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                layer_input_size = (
                    input_size if layer == 0 else hidden_size * num_directions
                )

                w_ih = torch.randn(gate_size, layer_input_size).to(torch.float)
                w_hh = torch.randn(gate_size, hidden_size).to(torch.float)
                b_ih = torch.randn(gate_size).to(torch.float)
                b_hh = torch.randn(gate_size).to(torch.float)
                if dtype == torch.qint8:
                    w_ih = torch.quantize_per_tensor(
                        w_ih, scale=0.1, zero_point=0, dtype=torch.qint8
                    )
                    w_hh = torch.quantize_per_tensor(
                        w_hh, scale=0.1, zero_point=0, dtype=torch.qint8
                    )
                    packed_ih = torch.ops.quantized.linear_prepack(w_ih, b_ih)
                    packed_hh = torch.ops.quantized.linear_prepack(w_hh, b_hh)
                    if self.version is None or self.version < 2:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, b_ih, b_hh
                            )
                        )
                    else:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, b_ih, b_hh, True
                            )
                        )
                else:
                    packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih)
                    packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh)
                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
                        packed_ih, packed_hh
                    )

                _all_weight_values.append(PackedParameter(cell_params))
        self._all_weight_values = torch.nn.ModuleList(_all_weight_values)

    def _get_name(self):
        return "DynamicQuantizedRNN"

    def extra_repr(self):
        s = "{input_size}, {hidden_size}"
        if self.num_layers != 1:
            s += ", num_layers={num_layers}"
        if self.bias is not True:
            s += ", bias={bias}"
        if self.batch_first is not False:
            s += ", batch_first={batch_first}"
        if self.dropout != 0:
            s += ", dropout={dropout}"
        if self.bidirectional is not False:
            s += ", bidirectional={bidirectional}"
        return s.format(**self.__dict__)

    def __repr__(self):
        # We don't want to show `ModuleList` children, hence custom
        # `__repr__`. This is the same as nn.Module.__repr__, except the check
        # for the `PackedParameter` and `nn.ModuleList`.
        # You should still override `extra_repr` to add more info.
        extra_lines = []
        extra_repr = self.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
            extra_lines = extra_repr.split("\n")
        child_lines = []
        for key, module in self._modules.items():
            if isinstance(module, (PackedParameter, nn.ModuleList)):
                continue
            mod_str = repr(module)
            mod_str = nn.modules.module._addindent(mod_str, 2)
            child_lines.append("(" + key + "): " + mod_str)
        lines = extra_lines + child_lines

        main_str = self._get_name() + "("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += "\n  " + "\n  ".join(lines) + "\n"

        main_str += ")"
        return main_str

    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
        expected_input_dim = 2 if batch_sizes is not None else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                f"input must have {expected_input_dim} dimensions, got {input.dim()}"
            )
        if self.input_size != input.size(-1):
            raise RuntimeError(
                f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}"
            )

    def get_expected_hidden_size(
        self, input: Tensor, batch_sizes: Optional[Tensor]
    ) -> Tuple[int, int, int]:
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (
            self.num_layers * num_directions,
            mini_batch,
            self.hidden_size,
        )
        return expected_hidden_size

    def check_hidden_size(
        self,
        hx: Tensor,
        expected_hidden_size: Tuple[int, int, int],
        msg: str = "Expected hidden size {}, got {}",
    ) -> None:
        if hx.size() != expected_hidden_size:
            raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))

    def check_forward_args(
        self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]
    ) -> None:
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
        self.check_hidden_size(
            hidden, expected_hidden_size, msg="Expected hidden size {}, got {}"
        )

    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
        if permutation is None:
            return hx
        return _apply_permutation(hx, permutation)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)
        self.version = version
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def set_weight_bias(self, weight_bias_dict):
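        # weight_bias_dict is a flat mapping from float-module parameter names,
        # e.g. "weight_ih_l0" or "bias_hh_l1_reverse", to the corresponding
        # (quantized) weight / bias tensors.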
        def weight_bias_name(ihhh, layer, suffix):
            weight_name = f"weight_{ihhh}_l{layer}{suffix}"
            bias_name = f"bias_{ihhh}_l{layer}{suffix}"
            return weight_name, bias_name

        num_directions = 2 if self.bidirectional else 1
        # TODO: dedup with __init__ of RNNBase
        _all_weight_values = []
        for layer in range(self.num_layers):
            for direction in range(num_directions):
                suffix = "_reverse" if direction == 1 else ""
                w_ih_name, b_ih_name = weight_bias_name("ih", layer, suffix)
                w_hh_name, b_hh_name = weight_bias_name("hh", layer, suffix)
                w_ih = weight_bias_dict[w_ih_name]
                b_ih = weight_bias_dict[b_ih_name]
                w_hh = weight_bias_dict[w_hh_name]
                b_hh = weight_bias_dict[b_hh_name]
                if w_ih.dtype == torch.qint8:
                    packed_ih = torch.ops.quantized.linear_prepack(w_ih, b_ih)
                    packed_hh = torch.ops.quantized.linear_prepack(w_hh, b_hh)
                    if self.version is None or self.version < 2:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, b_ih, b_hh
                            )
                        )
                    else:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, b_ih, b_hh, True
                            )
                        )
                else:
                    packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih)
                    packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh)
                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
                        packed_ih, packed_hh
                    )

                _all_weight_values.append(PackedParameter(cell_params))
        self._all_weight_values = torch.nn.ModuleList(_all_weight_values)
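
    # from_float is invoked by the eager-mode dynamic quantization flow
    # (e.g. torch.ao.quantization.quantize_dynamic) to swap a float nn.LSTM /
    # nn.GRU for its dynamically quantized counterpart.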
    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        assert type(mod) in {
            torch.nn.LSTM,
            torch.nn.GRU,
        }, "nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU"
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"

        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer_method = mod.qconfig.weight
        else:
            # We have circular import issues if we import the qconfig at the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig

            weight_observer_method = default_dynamic_qconfig.weight

        dtype = weight_observer_method().dtype
        supported_scalar_types = [torch.qint8, torch.float16]
        if dtype not in supported_scalar_types:
            raise RuntimeError(
                f"Unsupported dtype for dynamic RNN quantization: {dtype}"
            )
        # RNNBase can be either LSTM or GRU
        qRNNBase: Union[LSTM, GRU]
        if mod.mode == "LSTM":
            qRNNBase = LSTM(
                mod.input_size,
                mod.hidden_size,
                mod.num_layers,
                mod.bias,
                mod.batch_first,
                mod.dropout,
                mod.bidirectional,
                dtype,
            )
        elif mod.mode == "GRU":
            qRNNBase = GRU(
                mod.input_size,
                mod.hidden_size,
                mod.num_layers,
                mod.bias,
                mod.batch_first,
                mod.dropout,
                mod.bidirectional,
                dtype,
            )
        else:
            raise NotImplementedError(
                "Only LSTM/GRU is supported for QuantizedRNN for now"
            )

        num_directions = 2 if mod.bidirectional else 1

        assert mod.bias

        _all_weight_values = []
        for layer in range(qRNNBase.num_layers):
            for direction in range(num_directions):
                suffix = "_reverse" if direction == 1 else ""

                def retrieve_weight_bias(ihhh):
                    weight_name = f"weight_{ihhh}_l{layer}{suffix}"
                    bias_name = f"bias_{ihhh}_l{layer}{suffix}"
                    weight = getattr(mod, weight_name)
                    bias = getattr(mod, bias_name)
                    return weight, bias

                weight_ih, bias_ih = retrieve_weight_bias("ih")
                weight_hh, bias_hh = retrieve_weight_bias("hh")

                if dtype == torch.qint8:

                    def quantize_and_pack(w, b):
                        weight_observer = weight_observer_method()
                        weight_observer(w)
                        qweight = _quantize_weight(w.float(), weight_observer)
                        packed_weight = torch.ops.quantized.linear_prepack(qweight, b)
                        return packed_weight

                    packed_ih = quantize_and_pack(weight_ih, bias_ih)
                    packed_hh = quantize_and_pack(weight_hh, bias_hh)
                    if qRNNBase.version is None or qRNNBase.version < 2:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, bias_ih, bias_hh
                            )
                        )
                    else:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, bias_ih, bias_hh, True
                            )
                        )

                elif dtype == torch.float16:
                    packed_ih = torch.ops.quantized.linear_prepack_fp16(
                        weight_ih.float(), bias_ih
                    )
                    packed_hh = torch.ops.quantized.linear_prepack_fp16(
                        weight_hh.float(), bias_hh
                    )

                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
                        packed_ih, packed_hh
                    )
                else:
                    raise RuntimeError(
                        "Unsupported dtype specified for dynamic quantized LSTM!"
                    )

                _all_weight_values.append(PackedParameter(cell_params))
        qRNNBase._all_weight_values = torch.nn.ModuleList(_all_weight_values)

        return qRNNBase

    def _weight_bias(self):
        # Returns a dict of weights and biases
        weight_bias_dict: Dict[str, Dict] = {"weight": {}, "bias": {}}
        count = 0
        num_directions = 2 if self.bidirectional else 1
        for layer in range(self.num_layers):
            for direction in range(num_directions):
                suffix = "_reverse" if direction == 1 else ""
                key_name1 = f"weight_ih_l{layer}{suffix}"
                key_name2 = f"weight_hh_l{layer}{suffix}"
                # packed weights are part of torchbind class, CellParamsSerializationType
                # Within the packed weight class, the weight and bias are accessible as Tensors
                packed_weight_bias = self._all_weight_values[
                    count
                ].param.__getstate__()[0][4]
                weight_bias_dict["weight"][key_name1] = packed_weight_bias[
                    0
                ].__getstate__()[0][0]
                weight_bias_dict["weight"][key_name2] = packed_weight_bias[
                    1
                ].__getstate__()[0][0]
                key_name1 = f"bias_ih_l{layer}{suffix}"
                key_name2 = f"bias_hh_l{layer}{suffix}"
                weight_bias_dict["bias"][key_name1] = packed_weight_bias[
                    0
                ].__getstate__()[0][1]
                weight_bias_dict["bias"][key_name2] = packed_weight_bias[
                    1
                ].__getstate__()[0][1]
                count = count + 1
        return weight_bias_dict

    def get_weight(self):
        return self._weight_bias()["weight"]

    def get_bias(self):
        return self._weight_bias()["bias"]


class LSTM(RNNBase):
    r"""
    A dynamic quantized LSTM module with floating point tensors as inputs and outputs.
    We adopt the same interface as `torch.nn.LSTM`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM for documentation.

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
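
    In practice this module is usually created from a float model via dynamic
    quantization rather than constructed directly. A minimal sketch, assuming
    ``float_model`` is an ``nn.Module`` containing an ``nn.LSTM`` submodule and
    using the default dynamic qconfig:

        >>> # xdoctest: +SKIP
        >>> quantized_model = torch.ao.quantization.quantize_dynamic(
        ...     float_model, {nn.LSTM}, dtype=torch.qint8
        ... )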
    """
    _FLOAT_MODULE = nn.LSTM

    __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}

    def __init__(self, *args, **kwargs):
        super().__init__("LSTM", *args, **kwargs)

    def _get_name(self):
        return "DynamicQuantizedLSTM"
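
    # forward_impl runs the whole (multi-layer, possibly bidirectional) LSTM in
    # a single call to the fused torch.quantized_lstm op, passing the packed
    # CellParams collected in self._all_weight_values.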
    def forward_impl(
        self,
        input: Tensor,
        hx: Optional[Tuple[Tensor, Tensor]],
        batch_sizes: Optional[Tensor],
        max_batch_size: int,
        sorted_indices: Optional[Tensor],
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(
                self.num_layers * num_directions,
                max_batch_size,
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            hx = (zeros, zeros)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes they are passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)

        _all_params = [m.param for m in self._all_weight_values]
        if batch_sizes is None:
            result = torch.quantized_lstm(
                input,
                hx,
                _all_params,
                self.bias,
                self.num_layers,
                float(self.dropout),
                self.training,
                self.bidirectional,
                self.batch_first,
                dtype=self.dtype,
                use_dynamic=True,
            )
        else:
            result = torch.quantized_lstm(
                input,
                batch_sizes,
                hx,
                _all_params,
                self.bias,
                self.num_layers,
                float(self.dropout),
                self.training,
                self.bidirectional,
                dtype=self.dtype,
                use_dynamic=True,
            )
        output = result[0]
        hidden = result[1:]

        return output, hidden

    @torch.jit.export
    def forward_tensor(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

        output, hidden = self.forward_impl(
            input, hx, batch_sizes, max_batch_size, sorted_indices
        )

        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch.jit.export
    def forward_packed(
        self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:
        input_, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = int(batch_sizes[0])

        output_, hidden = self.forward_impl(
            input_, hx, batch_sizes, max_batch_size, sorted_indices
        )

        output = PackedSequence(output_, batch_sizes, sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    # "type: ignore" is required due to issue #43072
    def permute_hidden(  # type: ignore[override]
        self,
        hx: Tuple[Tensor, Tensor],
        permutation: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor]:
        if permutation is None:
            return hx
        return _apply_permutation(hx[0], permutation), _apply_permutation(
            hx[1], permutation
        )

    # "type: ignore" is required due to issue #43072
    def check_forward_args(  # type: ignore[override]
        self,
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        batch_sizes: Optional[Tensor],
    ) -> None:
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(
            hidden[0], expected_hidden_size, "Expected hidden[0] size {}, got {}"
        )
        self.check_hidden_size(
            hidden[1], expected_hidden_size, "Expected hidden[1] size {}, got {}"
        )

    @torch.jit.ignore
    def forward(self, input, hx=None):
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )
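
    # from_reference consumes a reference-quantized LSTM, i.e. a module that
    # carries per-parameter dtypes such as weight_ih_l0_dtype and exposes
    # get_quantized_weight_bias_dict(), typically produced by the reference
    # quantization flow.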
    @classmethod
    def from_reference(cls, ref_mod):
        assert hasattr(ref_mod, "weight_ih_l0_dtype"), (
            "We are assuming weight_ih_l0 exists in LSTM, may need to relax "
            "the assumption to support the use case"
        )
        qmod = cls(
            ref_mod.input_size,
            ref_mod.hidden_size,
            ref_mod.num_layers,
            ref_mod.bias,
            ref_mod.batch_first,
            ref_mod.dropout,
            ref_mod.bidirectional,
            # assuming there is layer 0, which should be OK
            ref_mod.weight_ih_l0_dtype,
        )
        qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict())
        return qmod


class GRU(RNNBase):
    r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.


    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll}
            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
            n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\
            h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)}
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
    at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
    at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
    :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
    :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.

    In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
    (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
    variable which is :math:`0` with probability :attr:`dropout`.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two GRUs together to form a `stacked GRU`,
            with the second GRU taking in outputs of the first GRU and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            GRU layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``

    Inputs: input, h_0
        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
          of the input sequence. The input can also be a packed variable length
          sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
          for details.
        - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial hidden state for each element in the batch.
          Defaults to zero if not provided. If the RNN is bidirectional,
          num_directions should be 2, else it should be 1.

    Outputs: output, h_n
        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
          containing the output features h_t from the last layer of the GRU,
          for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
          given as the input, the output will also be a packed sequence.
          For the unpacked case, the directions can be separated
          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
          with forward and backward being direction `0` and `1` respectively.

          Similarly, the directions can be separated in the packed case.
        - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the hidden state for `t = seq_len`

          Like *output*, the layers can be separated using
          ``h_n.view(num_layers, num_directions, batch, hidden_size)``.

    Shape:
        - Input1: :math:`(L, N, H_{in})` tensor containing input features where
          :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length.
        - Input2: :math:`(S, N, H_{out})` tensor
          containing the initial hidden state for each element in the batch.
          :math:`H_{out}=\text{hidden\_size}`
          Defaults to zero if not provided, where :math:`S=\text{num\_layers} * \text{num\_directions}`.
          If the RNN is bidirectional, num_directions should be 2, else it should be 1.
        - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}`
        - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state
          for each element in the batch

    Attributes:
        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
            (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
            Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
            (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
            (b_ir|b_iz|b_in), of shape `(3*hidden_size)`
        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
            (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. note::
        The calculation of new gate :math:`n_t` subtly differs from the original paper and other frameworks.
        In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the
        previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix
        `W` and addition of bias:

        .. math::
            \begin{aligned}
                n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn})
            \end{aligned}

        This is in contrast to the PyTorch implementation, which applies the product after
        :math:`W_{hn} h_{(t-1)}`:

        .. math::
            \begin{aligned}
                n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn}))
            \end{aligned}

        This implementation differs on purpose for efficiency.

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.GRU(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> output, hn = rnn(input, h0)
    """
    _FLOAT_MODULE = nn.GRU

    __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}

    def __init__(self, *args, **kwargs):
        super().__init__("GRU", *args, **kwargs)

    def _get_name(self):
        return "DynamicQuantizedGRU"

    def check_forward_args(
        self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]
    ) -> None:
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(
            hidden, expected_hidden_size, "Expected hidden size {}, got {}"
        )

    def forward_impl(
        self,
        input: Tensor,
        hx: Optional[Tensor],
        batch_sizes: Optional[Tensor],
        max_batch_size: int,
        sorted_indices: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor]:
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(
                self.num_layers * num_directions,
                max_batch_size,
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            hx = zeros
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes they are passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)

        _all_params = [m.param for m in self._all_weight_values]
        if batch_sizes is None:
            result = torch.quantized_gru(
                input,
                hx,
                _all_params,
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
                self.batch_first,
            )
        else:
            result = torch.quantized_gru(
                input,
                batch_sizes,
                hx,
                _all_params,
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
            )
        output = result[0]
        hidden = result[1]

        return output, hidden

    @torch.jit.export
    def forward_tensor(
        self, input: Tensor, hx: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

        output, hidden = self.forward_impl(
            input, hx, batch_sizes, max_batch_size, sorted_indices
        )

        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch.jit.export
    def forward_packed(
        self, input: PackedSequence, hx: Optional[Tensor] = None
    ) -> Tuple[PackedSequence, Tensor]:
        input_, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = int(batch_sizes[0])
        output_, hidden = self.forward_impl(
            input_, hx, batch_sizes, max_batch_size, sorted_indices
        )

        output = PackedSequence(output_, batch_sizes, sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
        if permutation is None:
            return hx
        return _apply_permutation(hx, permutation)

    @torch.jit.ignore
    def forward(self, input, hx=None):
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_mod):
        assert hasattr(ref_mod, "weight_ih_l0_dtype"), (
            "We are assuming weight_ih_l0 exists in GRU, may need to relax "
            "the assumption to support the use case"
        )
        qmod = cls(
            ref_mod.input_size,
            ref_mod.hidden_size,
            ref_mod.num_layers,
            ref_mod.bias,
            ref_mod.batch_first,
            ref_mod.dropout,
            ref_mod.bidirectional,
            # assuming there is layer 0, which should be OK
            ref_mod.weight_ih_l0_dtype,
        )
        qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict())
        return qmod


class RNNCellBase(torch.nn.Module):
    # _FLOAT_MODULE = nn.CellRNNBase
    __constants__ = ["input_size", "hidden_size", "bias"]

    def __init__(
        self, input_size, hidden_size, bias=True, num_chunks=4, dtype=torch.qint8
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.weight_dtype = dtype
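        # As in RNNBase.__init__, the tensors created below are randomly
        # initialized placeholders; real values are installed by from_float(),
        # set_weight_bias(), or when loading a state dict.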
        if bias:
            self.bias_ih = torch.randn(num_chunks * hidden_size).to(dtype=torch.float)
            self.bias_hh = torch.randn(num_chunks * hidden_size).to(dtype=torch.float)
        else:
            self.register_parameter("bias_ih", None)
            self.register_parameter("bias_hh", None)

        weight_ih = torch.randn(num_chunks * hidden_size, input_size).to(torch.float)
        weight_hh = torch.randn(num_chunks * hidden_size, hidden_size).to(torch.float)
        if dtype == torch.qint8:
            weight_ih = torch.quantize_per_tensor(
                weight_ih, scale=1, zero_point=0, dtype=torch.qint8
            )
            weight_hh = torch.quantize_per_tensor(
                weight_hh, scale=1, zero_point=0, dtype=torch.qint8
            )

        if dtype == torch.qint8:
            # for each layer, for each direction we need to quantize and pack
            # weights and pack parameters in this order:
            #
            #   w_ih, w_hh
            packed_weight_ih = torch.ops.quantized.linear_prepack(
                weight_ih, self.bias_ih
            )
            packed_weight_hh = torch.ops.quantized.linear_prepack(
                weight_hh, self.bias_hh
            )
        else:
            # for each layer, for each direction we need to quantize and pack
            # weights and pack parameters in this order:
            #
            #   packed_ih, packed_hh, b_ih, b_hh
            packed_weight_ih = torch.ops.quantized.linear_prepack_fp16(
                weight_ih, self.bias_ih
            )
            packed_weight_hh = torch.ops.quantized.linear_prepack_fp16(
                weight_hh, self.bias_hh
            )

        self._packed_weight_ih = packed_weight_ih
        self._packed_weight_hh = packed_weight_hh

    def _get_name(self):
        return "DynamicQuantizedRNNBase"

    def extra_repr(self):
        s = "{input_size}, {hidden_size}"
        if "bias" in self.__dict__ and self.bias is not True:
            s += ", bias={bias}"
        if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh":
            s += ", nonlinearity={nonlinearity}"
        return s.format(**self.__dict__)

    def check_forward_input(self, input):
        if input.size(1) != self.input_size:
            raise RuntimeError(
                f"input has inconsistent input_size: got {input.size(1)}, expected {self.input_size}"
            )

    def check_forward_hidden(
        self, input: Tensor, hx: Tensor, hidden_label: str = ""
    ) -> None:
        if input.size(0) != hx.size(0):
            raise RuntimeError(
                f"Input batch size {input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}"
            )

        if hx.size(1) != self.hidden_size:
            raise RuntimeError(
                f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}"
            )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        assert type(mod) in {
            torch.nn.LSTMCell,
            torch.nn.GRUCell,
            torch.nn.RNNCell,
        }, (
            "nn.quantized.dynamic.RNNCellBase.from_float only works for "
            "nn.LSTMCell, nn.GRUCell and nn.RNNCell"
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"

        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer_method = mod.qconfig.weight
        else:
            # We have circular import issues if we import the qconfig at the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig

            weight_observer_method = default_dynamic_qconfig.weight

        dtype = weight_observer_method().dtype
        supported_scalar_types = [torch.qint8, torch.float16]
        if dtype not in supported_scalar_types:
            raise RuntimeError(
                f"Unsupported dtype for dynamic RNN quantization: {dtype}"
            )

        qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]

        if type(mod) == torch.nn.LSTMCell:
            qRNNCellBase = LSTMCell(
                mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
            )
        elif type(mod) == torch.nn.GRUCell:
            qRNNCellBase = GRUCell(
                mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
            )
        elif type(mod) == torch.nn.RNNCell:
            qRNNCellBase = RNNCell(
                mod.input_size,
                mod.hidden_size,
                bias=mod.bias,
                nonlinearity=mod.nonlinearity,
                dtype=dtype,
            )
        else:
            raise NotImplementedError(
                "Only LSTMCell, GRUCell and RNNCell "
                "are supported for QuantizedRNN for now"
            )

        assert mod.bias

        def _observe_and_quantize_weight(weight):
            if dtype == torch.qint8:
                weight_observer = weight_observer_method()
                weight_observer(weight)
                qweight = _quantize_weight(weight.float(), weight_observer)
                return qweight
            else:
                return weight.float()

        qRNNCellBase._packed_weight_ih = pack_weight_bias(
            _observe_and_quantize_weight(mod.weight_ih), mod.bias_ih, dtype
        )
        qRNNCellBase._packed_weight_hh = pack_weight_bias(
            _observe_and_quantize_weight(mod.weight_hh), mod.bias_hh, dtype
        )
        return qRNNCellBase

    @classmethod
    def from_reference(cls, ref_mod):
        assert hasattr(ref_mod, "weight_ih_dtype"), (
            "We are assuming weight_ih exists in the reference module, may need "
            "to relax the assumption to support the use case"
        )
        if hasattr(ref_mod, "nonlinearity"):
            qmod = cls(
                ref_mod.input_size,
                ref_mod.hidden_size,
                ref_mod.bias,
                ref_mod.nonlinearity,
                dtype=ref_mod.weight_ih_dtype,
            )
        else:
            qmod = cls(
                ref_mod.input_size,
                ref_mod.hidden_size,
                ref_mod.bias,
                dtype=ref_mod.weight_ih_dtype,
            )
        weight_bias_dict = {
            "weight": {
                "weight_ih": ref_mod.get_quantized_weight_ih(),
                "weight_hh": ref_mod.get_quantized_weight_hh(),
            },
            "bias": {
                "bias_ih": ref_mod.bias_ih,
                "bias_hh": ref_mod.bias_hh,
            },
        }
        qmod.set_weight_bias(weight_bias_dict)
        return qmod

    def _weight_bias(self):
        # Returns a dict of weights and biases
        weight_bias_dict: Dict[str, Dict] = {"weight": {}, "bias": {}}
        w1, b1 = self._packed_weight_ih.__getstate__()[0]
        w2, b2 = self._packed_weight_hh.__getstate__()[0]
        # TODO: these can be simplified to one level? e.g. using weight_ih as key
        # directly
        weight_bias_dict["weight"]["weight_ih"] = w1
        weight_bias_dict["weight"]["weight_hh"] = w2
        weight_bias_dict["bias"]["bias_ih"] = b1
        weight_bias_dict["bias"]["bias_hh"] = b2
        return weight_bias_dict

    def get_weight(self):
        return self._weight_bias()["weight"]

    def get_bias(self):
        return self._weight_bias()["bias"]

    def set_weight_bias(self, weight_bias_dict):
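        # Unlike RNNBase.set_weight_bias, this expects a nested dict of the form
        # {"weight": {"weight_ih": ..., "weight_hh": ...},
        #  "bias": {"bias_ih": ..., "bias_hh": ...}},
        # i.e. the layout returned by _weight_bias().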
        # TODO: these can be simplified to one level? e.g. using weight_ih as key
        # directly
        self._packed_weight_ih = pack_weight_bias(
            weight_bias_dict["weight"]["weight_ih"],
            weight_bias_dict["bias"]["bias_ih"],
            self.weight_dtype,
        )
        self._packed_weight_hh = pack_weight_bias(
            weight_bias_dict["weight"]["weight_hh"],
            weight_bias_dict["bias"]["bias_hh"],
            self.weight_dtype,
        )

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "_packed_weight_ih"] = self._packed_weight_ih
        destination[prefix + "_packed_weight_hh"] = self._packed_weight_hh

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        self._packed_weight_ih = state_dict.pop(prefix + "_packed_weight_ih")
        self._packed_weight_hh = state_dict.pop(prefix + "_packed_weight_hh")
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class RNNCell(RNNCellBase):
    r"""An Elman RNN cell with tanh or ReLU non-linearity.
    A dynamic quantized RNNCell module with floating point tensors as inputs and outputs.
    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.RNNCell`,
    please see https://pytorch.org/docs/stable/nn.html#torch.nn.RNNCell for documentation.

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.RNNCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx = rnn(input[i], hx)
        ...     output.append(hx)
    """
    __constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"]

    def __init__(
        self, input_size, hidden_size, bias=True, nonlinearity="tanh", dtype=torch.qint8
    ):
        super().__init__(input_size, hidden_size, bias, num_chunks=1, dtype=dtype)
        self.nonlinearity = nonlinearity

    def _get_name(self):
        return "DynamicQuantizedRNNCell"

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        self.check_forward_input(input)
        if hx is None:
            hx = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
        self.check_forward_hidden(input, hx, "")
        if self.nonlinearity == "tanh":
            ret = torch.ops.quantized.quantized_rnn_tanh_cell_dynamic(
                input,
                hx,
                self._packed_weight_ih,
                self._packed_weight_hh,
                self.bias_ih,
                self.bias_hh,
            )
        elif self.nonlinearity == "relu":
            ret = torch.ops.quantized.quantized_rnn_relu_cell_dynamic(
                input,
                hx,
                self._packed_weight_ih,
                self._packed_weight_hh,
                self.bias_ih,
                self.bias_hh,
            )
        else:
            ret = input  # TODO: remove when jit supports exception flow
            raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}")
        return ret

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


class LSTMCell(RNNCellBase):
    r"""A long short-term memory (LSTM) cell.

    A dynamic quantized LSTMCell module with floating point tensors as inputs and outputs.
    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.LSTMCell`,
    please see https://pytorch.org/docs/stable/nn.html#torch.nn.LSTMCell for documentation.

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.LSTMCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx, cx = rnn(input[i], (hx, cx))
        ...     output.append(hx)
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, num_chunks=4, **kwargs)  # type: ignore[misc]

    def _get_name(self):
        return "DynamicQuantizedLSTMCell"

    def forward(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tensor]:
        self.check_forward_input(input)
        if hx is None:
            zeros = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
            hx = (zeros, zeros)
        self.check_forward_hidden(input, hx[0], "[0]")
        self.check_forward_hidden(input, hx[1], "[1]")
        return torch.ops.quantized.quantized_lstm_cell_dynamic(
            input,
            hx,
            self._packed_weight_ih,
            self._packed_weight_hh,
            self.bias_ih,
            self.bias_hh,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


class GRUCell(RNNCellBase):
    r"""A gated recurrent unit (GRU) cell.

    A dynamic quantized GRUCell module with floating point tensors as inputs and outputs.
    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.GRUCell`,
    please see https://pytorch.org/docs/stable/nn.html#torch.nn.GRUCell for documentation.

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.GRUCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx = rnn(input[i], hx)
        ...     output.append(hx)
    """

    def __init__(self, input_size, hidden_size, bias=True, dtype=torch.qint8):
        super().__init__(input_size, hidden_size, bias, num_chunks=3, dtype=dtype)

    def _get_name(self):
        return "DynamicQuantizedGRUCell"

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        self.check_forward_input(input)
        if hx is None:
            hx = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
        self.check_forward_hidden(input, hx, "")
        return torch.ops.quantized.quantized_gru_cell_dynamic(
            input,
            hx,
            self._packed_weight_ih,
            self._packed_weight_hh,
            self.bias_ih,
            self.bias_hh,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )