# mypy: allow-untyped-defs
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch import _VF, Tensor
from torch.nn.utils.rnn import PackedSequence

from .utils import _quantize_and_dequantize_weight, _quantize_weight


__all__ = [
    "RNNCellBase",
    "RNNCell",
    "LSTMCell",
    "GRUCell",
    "RNNBase",
    "LSTM",
    "GRU",
    "get_quantized_weight",
]


def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return tensor.index_select(dim, permutation)


def _get_weight_and_quantization_params(module, wn):
    weight = getattr(module, wn)
    params = [weight]
    for param_name in [
        wn + n for n in ["_qscheme", "_dtype", "_scale", "_zero_point", "_axis_int"]
    ]:
        if hasattr(module, param_name):
            param = getattr(module, param_name)
        else:
            param = None
        params.append(param)
    return params


def get_quantized_weight(module, wn):
    if not hasattr(module, wn):
        return None
    params = _get_weight_and_quantization_params(module, wn)
    weight = _quantize_weight(*params)
    return weight


def _get_quantize_and_dequantized_weight(module, wn):
    if not hasattr(module, wn):
        return None
    params = _get_weight_and_quantization_params(module, wn)
    weight = _quantize_and_dequantize_weight(*params)
    return weight

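# Example (illustrative): the helpers above rely on a small attribute-naming
# protocol. For a weight named ``weight_ih``, the owning module carries
# ``weight_ih`` itself plus ``weight_ih_qscheme``, ``weight_ih_dtype``,
# ``weight_ih_scale``, ``weight_ih_zero_point`` and ``weight_ih_axis_int``
# (registered by ``_init_weight_qparams_dict`` below). A hypothetical
# per-tensor qparams entry might look like:
#
#     weight_qparams = {
#         "qscheme": torch.per_tensor_affine,
#         "dtype": torch.quint8,
#         "scale": 0.02,      # example value
#         "zero_point": 0,    # example value
#     }
#
# ``get_quantized_weight(module, "weight_ih")`` then returns the quantized
# (e.g. quint8) weight, while ``_get_quantize_and_dequantized_weight`` returns
# a float weight that has been round-tripped through quantization, which is
# what the reference forward implementations below consume.
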
class RNNCellBase(nn.RNNCellBase):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool,
        num_chunks: int,
        device=None,
        dtype=None,
        weight_qparams_dict=None,
    ) -> None:
        super().__init__(
            input_size, hidden_size, bias, num_chunks, device=device, dtype=dtype
        )
        # TODO(jerryzh168): maybe make this arg a required arg
        if weight_qparams_dict is None:
            weight_qparams = {
                "qscheme": torch.per_tensor_affine,
                "dtype": torch.quint8,
                "scale": 1.0,
                "zero_point": 0,
            }
            weight_qparams_dict = {
                "weight_ih": weight_qparams,
                "weight_hh": weight_qparams,
                "is_decomposed": False,
            }
        assert (
            len(weight_qparams_dict) == 3
        ), "Expected length for weight_qparams_dict to be 3 for QuantizedRNNCellBase(Reference)"
        self._init_weight_qparams_dict(weight_qparams_dict, device)

    def _init_weight_qparams_dict(self, weight_qparams_dict, device):
        assert weight_qparams_dict is not None
        self.is_decomposed = weight_qparams_dict["is_decomposed"]
        for key, weight_qparams in weight_qparams_dict.items():
            if key == "is_decomposed":
                continue
            # TODO: refactor the duplicated code to utils.py
            weight_qscheme = weight_qparams["qscheme"]
            weight_dtype = weight_qparams["dtype"]
            setattr(self, key + "_qscheme", weight_qscheme)
            setattr(self, key + "_dtype", weight_dtype)
            assert weight_qscheme in [
                None,
                torch.per_tensor_affine,
                torch.per_channel_affine,
            ], Exception(
                f"qscheme: {weight_qscheme} is not supported in {self._get_name()}"
            )
            if weight_qscheme is not None:
                scale = weight_qparams["scale"]
                scale_tensor = (
                    scale.clone().detach()
                    if isinstance(scale, torch.Tensor)
                    else torch.tensor(scale, dtype=torch.float, device=device)
                )
                self.register_buffer(key + "_scale", scale_tensor)
                zp = weight_qparams["zero_point"]
                zp_tensor = (
                    zp.clone().detach()
                    if isinstance(zp, torch.Tensor)
                    else torch.tensor(zp, dtype=torch.int, device=device)
                )
                self.register_buffer(key + "_zero_point", zp_tensor)
                if weight_qscheme == torch.per_channel_affine:
                    axis = weight_qparams["axis"]
                    axis_tensor = (
                        axis.clone().detach()
                        if isinstance(axis, torch.Tensor)
                        else torch.tensor(axis, dtype=torch.int, device=device)
                    )
                    self.register_buffer(key + "_axis", axis_tensor)
                else:
                    # added for TorchScriptability, not used
                    self.register_buffer(
                        key + "_axis", torch.tensor(0, dtype=torch.int, device=device)
                    )
                setattr(self, key + "_axis_int", getattr(self, key + "_axis").item())

    def _get_name(self):
        return "QuantizedRNNCellBase(Reference)"

    def get_quantized_weight_ih(self):
        return get_quantized_weight(self, "weight_ih")

    def get_quantized_weight_hh(self):
        return get_quantized_weight(self, "weight_hh")

    def get_weight_ih(self):
        return _get_quantize_and_dequantized_weight(self, "weight_ih")

    def get_weight_hh(self):
        return _get_quantize_and_dequantized_weight(self, "weight_hh")


class RNNCell(RNNCellBase):
    """
    We store weight_qparams for all the weights (weight_ih and weight_hh);
    a `weight_qparams_dict` that maps from a weight name, e.g. weight_ih,
    to the weight_qparams for that weight needs to be passed in.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        nonlinearity: str = "tanh",
        device=None,
        dtype=None,
        weight_qparams_dict: Optional[Dict[str, Any]] = None,
    ) -> None:
        factory_kwargs = {
            "device": device,
            "dtype": dtype,
            "weight_qparams_dict": weight_qparams_dict,
        }
        super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
        self.nonlinearity = nonlinearity

    def _get_name(self):
        return "QuantizedRNNCell(Reference)"

    # TODO: refactor nn.RNNCell to have a _forward that takes weight_ih and weight_hh
    # as input and remove the duplicated code; same for the other two Cell modules
    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        assert input.dim() in (
            1,
            2,
        ), f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)

        if hx is None:
            hx = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
        else:
            hx = hx.unsqueeze(0) if not is_batched else hx

        if self.nonlinearity == "tanh":
            ret = _VF.rnn_tanh_cell(
                input,
                hx,
                self.get_weight_ih(),
                self.get_weight_hh(),
                self.bias_ih,
                self.bias_hh,
            )
        elif self.nonlinearity == "relu":
            ret = _VF.rnn_relu_cell(
                input,
                hx,
                self.get_weight_ih(),
                self.get_weight_hh(),
                self.bias_ih,
                self.bias_hh,
            )
        else:
            ret = input  # TODO: remove when jit supports exception flow
            raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}")

        if not is_batched:
            ret = ret.squeeze(0)

        return ret

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.bias,
            mod.nonlinearity,
            mod.weight_ih.device,
            mod.weight_ih.dtype,
            weight_qparams_dict,
        )
        ref_mod.weight_ih = mod.weight_ih
        ref_mod.weight_hh = mod.weight_hh
        ref_mod.bias_ih = mod.bias_ih
        ref_mod.bias_hh = mod.bias_hh
        return ref_mod

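# Example (illustrative; scale/zero_point values are made up): converting a
# float cell into this reference module with ``from_float``:
#
#     float_cell = nn.RNNCell(input_size=10, hidden_size=20)
#     weight_qparams = {
#         "qscheme": torch.per_tensor_affine,
#         "dtype": torch.quint8,
#         "scale": 0.02,
#         "zero_point": 0,
#     }
#     weight_qparams_dict = {
#         "weight_ih": weight_qparams,
#         "weight_hh": weight_qparams,
#         "is_decomposed": False,
#     }
#     ref_cell = RNNCell.from_float(float_cell, weight_qparams_dict)
#     out = ref_cell(torch.randn(3, 10))  # runs with quantize-dequantized weights
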
class LSTMCell(RNNCellBase):
    """
    We store weight_qparams for all the weights (weight_ih and weight_hh);
    a `weight_qparams_dict` that maps from a weight name, e.g. weight_ih,
    to the weight_qparams for that weight needs to be passed in.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        device=None,
        dtype=None,
        weight_qparams_dict: Optional[Dict[str, Any]] = None,
    ) -> None:
        factory_kwargs = {
            "device": device,
            "dtype": dtype,
            "weight_qparams_dict": weight_qparams_dict,
        }
        super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)

    def _get_name(self):
        return "QuantizedLSTMCell(Reference)"

    def forward(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tensor]:
        assert input.dim() in (
            1,
            2,
        ), f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)

        if hx is None:
            zeros = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
            hx = (zeros, zeros)
        else:
            hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx

        ret = _VF.lstm_cell(
            input,
            hx,
            self.get_weight_ih(),
            self.get_weight_hh(),
            self.bias_ih,
            self.bias_hh,
        )

        if not is_batched:
            ret = (ret[0].squeeze(0), ret[1].squeeze(0))
        return ret

    @classmethod
    def from_float(cls, mod, weight_qparams_dict, use_precomputed_fake_quant=False):
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.bias,
            mod.weight_ih.device,
            mod.weight_ih.dtype,
            weight_qparams_dict,
        )
        ref_mod.weight_ih = mod.weight_ih
        ref_mod.weight_hh = mod.weight_hh
        ref_mod.bias_ih = mod.bias_ih
        ref_mod.bias_hh = mod.bias_hh
        return ref_mod

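# Example (illustrative): LSTMCell takes and returns an (h, c) tuple, and the
# forward above accepts both batched (2-D) and unbatched (1-D) inputs:
#
#     ref_cell = LSTMCell(10, 20)  # default per-tensor quint8 weight qparams
#     h, c = torch.zeros(3, 20), torch.zeros(3, 20)
#     h, c = ref_cell(torch.randn(3, 10), (h, c))  # batched step
#     h1, c1 = ref_cell(torch.randn(10))           # unbatched; hx defaults to zeros
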
class GRUCell(RNNCellBase):
    """
    We store weight_qparams for all the weights (weight_ih and weight_hh);
    a `weight_qparams_dict` that maps from a weight name, e.g. weight_ih,
    to the weight_qparams for that weight needs to be passed in.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        device=None,
        dtype=None,
        weight_qparams_dict: Optional[Dict[str, Any]] = None,
    ) -> None:
        factory_kwargs = {
            "device": device,
            "dtype": dtype,
            "weight_qparams_dict": weight_qparams_dict,
        }
        super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)

    def _get_name(self):
        return "QuantizedGRUCell(Reference)"

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        assert input.dim() in (
            1,
            2,
        ), f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)

        if hx is None:
            hx = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
        else:
            hx = hx.unsqueeze(0) if not is_batched else hx

        ret = _VF.gru_cell(
            input,
            hx,
            self.get_weight_ih(),
            self.get_weight_hh(),
            self.bias_ih,
            self.bias_hh,
        )

        if not is_batched:
            ret = ret.squeeze(0)

        return ret

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.bias,
            mod.weight_ih.device,
            mod.weight_ih.dtype,
            weight_qparams_dict,
        )
        ref_mod.weight_ih = mod.weight_ih
        ref_mod.weight_hh = mod.weight_hh
        ref_mod.bias_ih = mod.bias_ih
        ref_mod.bias_hh = mod.bias_hh
        return ref_mod

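# Example (illustrative): the two accessor families inherited from RNNCellBase
# return different things:
#
#     ref_cell = GRUCell(10, 20)                # default per-tensor quint8 qparams
#     q_w = ref_cell.get_quantized_weight_ih()  # quantized (e.g. quint8) tensor
#     fq_w = ref_cell.get_weight_ih()           # float tensor after quantize->dequantize
#
# The cell forwards consume the quantize-dequantized float weights, so the
# reference numerics mirror what a backend using the quantized weights would see.
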
class RNNBase(nn.RNNBase):
    def __init__(
        self,
        mode: str,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        proj_size: int = 0,
        device=None,
        dtype=None,
        weight_qparams_dict: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(
            mode,
            input_size,
            hidden_size,
            num_layers,
            bias,
            batch_first,
            dropout,
            bidirectional,
            proj_size,
            device,
            dtype,
        )
        # TODO(jerryzh168): maybe make this arg a required arg
        if weight_qparams_dict is None:
            weight_qparams = {
                "qscheme": torch.per_tensor_affine,
                "dtype": torch.quint8,
                "scale": 1.0,
                "zero_point": 0,
            }
            weight_qparams_dict = {"is_decomposed": False}  # type: ignore[dict-item]
            for wn in self._flat_weights_names:
                if wn.startswith("weight"):
                    weight_qparams_dict[wn] = weight_qparams
        self._init_weight_qparams_dict(weight_qparams_dict, device)

    def _init_weight_qparams_dict(self, weight_qparams_dict, device):
        self.is_decomposed = weight_qparams_dict["is_decomposed"]
        for key, weight_qparams in weight_qparams_dict.items():
            if key == "is_decomposed":
                continue
            weight_qscheme = weight_qparams["qscheme"]
            weight_dtype = weight_qparams["dtype"]
            setattr(self, key + "_qscheme", weight_qscheme)
            setattr(self, key + "_dtype", weight_dtype)
            assert weight_qscheme in [
                None,
                torch.per_tensor_affine,
                torch.per_channel_affine,
            ], Exception(
                f"qscheme: {weight_qscheme} is not supported in {self._get_name()}"
            )
            if weight_qscheme is not None:
                self.register_buffer(
                    key + "_scale",
                    torch.tensor(
                        weight_qparams["scale"], dtype=torch.float, device=device
                    ),
                )
                self.register_buffer(
                    key + "_zero_point",
                    torch.tensor(
                        weight_qparams["zero_point"], dtype=torch.int, device=device
                    ),
                )
                if weight_qscheme == torch.per_channel_affine:
                    self.register_buffer(
                        key + "_axis",
                        torch.tensor(
                            weight_qparams["axis"], dtype=torch.int, device=device
                        ),
                    )
                else:
                    # added for TorchScriptability, not used
                    self.register_buffer(
                        key + "_axis", torch.tensor(0, dtype=torch.int, device=device)
                    )
                setattr(self, key + "_axis_int", getattr(self, key + "_axis").item())

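# Example (illustrative): for the multi-layer modules below, weight_qparams_dict
# is keyed by the flat weight names, mirroring the default construction in
# ``RNNBase.__init__`` above. For a 2-layer unidirectional module:
#
#     weight_qparams_dict = {
#         "is_decomposed": False,
#         "weight_ih_l0": weight_qparams,
#         "weight_hh_l0": weight_qparams,
#         "weight_ih_l1": weight_qparams,
#         "weight_hh_l1": weight_qparams,
#     }
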
class LSTM(RNNBase):
    """Reference Quantized LSTM Module.

    We store weight_qparams for all the weights in _flat_weights; a
    `weight_qparams_dict` that maps from a weight name, e.g. weight_ih_l0,
    to the weight_qparams for that weight needs to be passed in.
    """

    def __init__(self, *args, **kwargs):
        super().__init__("LSTM", *args, **kwargs)

    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
    def permute_hidden(  # type: ignore[override]
        self,
        hx: Tuple[Tensor, Tensor],
        permutation: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor]:
        if permutation is None:
            return hx
        return _apply_permutation(hx[0], permutation), _apply_permutation(
            hx[1], permutation
        )

    def get_expected_cell_size(
        self, input: Tensor, batch_sizes: Optional[Tensor]
    ) -> Tuple[int, int, int]:
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (
            self.num_layers * num_directions,
            mini_batch,
            self.hidden_size,
        )
        return expected_hidden_size

    # In the future, we should prevent mypy from applying contravariance rules here.
    # See torch/nn/modules/module.py::_forward_unimplemented
    def check_forward_args(  # type: ignore[override]
        self,
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        batch_sizes: Optional[Tensor],
    ):
        self.check_input(input, batch_sizes)
        self.check_hidden_size(
            hidden[0],
            self.get_expected_hidden_size(input, batch_sizes),
            "Expected hidden[0] size {}, got {}",
        )
        self.check_hidden_size(
            hidden[1],
            self.get_expected_cell_size(input, batch_sizes),
            "Expected hidden[1] size {}, got {}",
        )

    def get_quantized_weight_bias_dict(self):
        """Returns a dictionary from flat_weight_name to quantized weight or
        (unquantized) bias, e.g.
        {
            "weight_ih_l0": quantized_weight,
            "bias_ih_l0": unquantized_bias,
            ...
        }
        """
        quantized_weight_bias_dict = {}
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                if wn.startswith("weight"):
                    weight_or_bias = get_quantized_weight(self, wn)
                else:
                    weight_or_bias = getattr(self, wn)
            else:
                weight_or_bias = None
            quantized_weight_bias_dict[wn] = weight_or_bias
        return quantized_weight_bias_dict

    def get_flat_weights(self):
        flat_weights = []
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                weight = getattr(self, wn)
                if wn.startswith("weight"):
                    params = _get_weight_and_quantization_params(self, wn)
                    weight = _quantize_and_dequantize_weight(*params)
            else:
                weight = None
            flat_weights.append(weight)
        return flat_weights

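    # Note: for an LSTM without projections, ``self._flat_weights_names`` is
    # ordered per layer (and per direction) as
    # ["weight_ih_l0", "weight_hh_l0", "bias_ih_l0", "bias_hh_l0", ...], so
    # ``get_flat_weights`` yields quantize-dequantized weights interleaved with
    # the (unquantized) biases in the layout that ``_VF.lstm`` expects.
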
    def forward(self, input, hx=None):  # noqa: F811
        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        batch_sizes = None
        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = int(batch_sizes[0])
        else:
            batch_sizes = None
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                input = input.unsqueeze(batch_dim)
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            real_hidden_size = (
                self.proj_size if self.proj_size > 0 else self.hidden_size
            )
            h_zeros = torch.zeros(
                self.num_layers * num_directions,
                max_batch_size,
                real_hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            c_zeros = torch.zeros(
                self.num_layers * num_directions,
                max_batch_size,
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            hx = (h_zeros, c_zeros)
        else:
            if batch_sizes is None:  # If not PackedSequence input.
                if is_batched:  # type: ignore[possibly-undefined]
                    if hx[0].dim() != 3 or hx[1].dim() != 3:
                        msg = (
                            "For batched 3-D input, hx and cx should "
                            f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors"
                        )
                        raise RuntimeError(msg)
                else:
                    if hx[0].dim() != 2 or hx[1].dim() != 2:
                        msg = (
                            "For unbatched 2-D input, hx and cx should "
                            f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors"
                        )
                        raise RuntimeError(msg)
                    hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))

            # Each batch of the hidden state should match the input sequence that
            # the user believes they are passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            result = _VF.lstm(
                input,
                hx,
                self.get_flat_weights(),
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
                self.batch_first,
            )
        else:
            result = _VF.lstm(
                input,
                batch_sizes,
                hx,
                self.get_flat_weights(),
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
            )
        output = result[0]
        hidden = result[1:]
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(
                output, batch_sizes, sorted_indices, unsorted_indices
            )
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            if not is_batched:  # type: ignore[possibly-undefined]
                output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
                hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
            return output, self.permute_hidden(hidden, unsorted_indices)

    def _get_name(self):
        return "QuantizedLSTM(Reference)"

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.num_layers,
            mod.bias,
            mod.batch_first,
            mod.dropout,
            mod.bidirectional,
            weight_qparams_dict=weight_qparams_dict,
        )
        for wn in mod._flat_weights_names:
            setattr(ref_mod, wn, getattr(mod, wn))
        return ref_mod

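# Example (illustrative; reusing the 2-layer ``weight_qparams_dict`` sketched
# after RNNBase above):
#
#     float_lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
#     ref_lstm = LSTM.from_float(float_lstm, weight_qparams_dict)
#     out, (h, c) = ref_lstm(torch.randn(5, 3, 10))  # (seq, batch, input_size)
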
class GRU(RNNBase):
    """Reference Quantized GRU Module.

    We store weight_qparams for all the weights in _flat_weights; a
    `weight_qparams_dict` that maps from a weight name, e.g. weight_ih_l0,
    to the weight_qparams for that weight needs to be passed in.
    """

    def __init__(self, *args, **kwargs):
        if "proj_size" in kwargs:
            raise ValueError(
                "proj_size argument is only supported for LSTM, not RNN or GRU"
            )
        super().__init__("GRU", *args, **kwargs)

    def get_quantized_weight_bias_dict(self):
        """Returns a dictionary from flat_weight_name to quantized weight or
        (unquantized) bias, e.g.
        {
            "weight_ih_l0": quantized_weight,
            "bias_ih_l0": unquantized_bias,
            ...
        }
        """
        quantized_weight_bias_dict = {}
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                if wn.startswith("weight"):
                    weight_or_bias = get_quantized_weight(self, wn)
                else:
                    weight_or_bias = getattr(self, wn)
            else:
                weight_or_bias = None
            quantized_weight_bias_dict[wn] = weight_or_bias
        return quantized_weight_bias_dict

    def get_flat_weights(self):
        flat_weights = []
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                weight = getattr(self, wn)
                if wn.startswith("weight"):
                    params = _get_weight_and_quantization_params(self, wn)
                    weight = _quantize_and_dequantize_weight(*params)
            else:
                weight = None
            flat_weights.append(weight)
        return flat_weights

    def forward(self, input, hx=None):  # noqa: F811
        # Note: this is copied from the forward of GRU in
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py,
        # with only self._flat_weights changed to self.get_flat_weights().
        # TODO: maybe we can try inheriting from that class and defining
        # get_flat_weights as a @property? This might interfere with TorchScript;
        # if we remove that requirement in the future we should be able to do this.
        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = int(batch_sizes[0])
        else:
            batch_sizes = None
            assert input.dim() in (
                2,
                3,
            ), f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                input = input.unsqueeze(batch_dim)
                if hx is not None:
                    if hx.dim() != 2:
                        raise RuntimeError(
                            f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor"
                        )
                    hx = hx.unsqueeze(1)
            else:
                if hx is not None and hx.dim() != 3:
                    raise RuntimeError(
                        f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor"
                    )
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = torch.zeros(
                self.num_layers * num_directions,
                max_batch_size,
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes they are passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            result = _VF.gru(
                input,
                hx,
                self.get_flat_weights(),
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
                self.batch_first,
            )
        else:
            result = _VF.gru(
                input,
                batch_sizes,
                hx,
                self.get_flat_weights(),
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
            )
        output = result[0]
        hidden = result[1]

        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(
                output, batch_sizes, sorted_indices, unsorted_indices
            )
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            if not is_batched:  # type: ignore[possibly-undefined]
                output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
                hidden = hidden.squeeze(1)

            return output, self.permute_hidden(hidden, unsorted_indices)

    def _get_name(self):
        return "QuantizedGRU(Reference)"

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.num_layers,
            mod.bias,
            mod.batch_first,
            mod.dropout,
            mod.bidirectional,
            weight_qparams_dict=weight_qparams_dict,
        )
        for wn in mod._flat_weights_names:
            setattr(ref_mod, wn, getattr(mod, wn))
        return ref_mod

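# Example (illustrative): extracting quantized weights from a converted module,
# e.g. to hand them to a lowering backend:
#
#     ref_gru = GRU.from_float(nn.GRU(10, 20, num_layers=2), weight_qparams_dict)
#     qw = ref_gru.get_quantized_weight_bias_dict()
#     # {"weight_ih_l0": <quantized weight>, "bias_ih_l0": <float bias>, ...}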