# mypy: allow-untyped-defs
import warnings
from typing import Optional, Tuple

import torch
import torch.jit  # this is needed to avoid a circular import
import torch.nn.functional as F
from torch import nn, Tensor


__all__ = ["MultiheadAttention"]


class MultiheadAttention(nn.MultiheadAttention):
    _FLOAT_MODULE = nn.MultiheadAttention

    r"""Quantizable implementation of the MultiheadAttention.

    Note::
        Please refer to :class:`~torch.nn.MultiheadAttention` for more
        information.

    Allows the model to jointly attend to information from different
    representation subspaces.
    See reference: Attention Is All You Need

    The original MHA module is not quantizable.
    This reimplements it by explicitly instantiating the linear layers.

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h)W^O
        \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
    to :attr:`embed_dim` such that query, key, and value have the same
    number of features.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    Note::
        Please follow the quantization flow to convert the quantizable MHA.
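
        A minimal, hedged sketch of the observation step (``embed_dim``,
        ``num_heads``, and the calibration tensors below are placeholders;
        how the final convert step is configured depends on your setup)::

            >>> float_mha = torch.nn.MultiheadAttention(embed_dim, num_heads)
            >>> float_mha.qconfig = torch.ao.quantization.default_qconfig
            >>> observed_mha = nnqa.MultiheadAttention.from_float(float_mha)
            >>> _ = observed_mha(query, key, value)  # calibrate on representative data
            >>> # then complete the flow with torch.ao.quantization.convert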
    """
    __constants__ = ["batch_first"]

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
        batch_first: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            embed_dim,
            num_heads,
            dropout,
            bias,
            add_bias_kv,
            add_zero_attn,
            kdim,
            vdim,
            batch_first,
            **factory_kwargs,
        )
        self.linear_Q = nn.Linear(
            self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs
        )
        self.linear_K = nn.Linear(
            self.kdim, self.embed_dim, bias=bias, **factory_kwargs
        )
        self.linear_V = nn.Linear(
            self.vdim, self.embed_dim, bias=bias, **factory_kwargs
        )
        # for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)  # type: ignore[assignment]

        # Functionals
        self.q_scaling_product = torch.ao.nn.quantized.FloatFunctional()
        # note: importing torch.ao.nn.quantized at top creates a circular import

        # Quant/Dequant
        self.quant_attn_output = torch.ao.quantization.QuantStub()
        self.quant_attn_output_weights = torch.ao.quantization.QuantStub()
        self.dequant_q = torch.ao.quantization.DeQuantStub()
        self.dequant_k = torch.ao.quantization.DeQuantStub()
        self.dequant_v = torch.ao.quantization.DeQuantStub()

    def _get_name(self):
        return "QuantizableMultiheadAttention"

    @classmethod
    def from_float(cls, other):
        assert type(other) == cls._FLOAT_MODULE
        assert hasattr(other, "qconfig"), "The float module must have 'qconfig'"
        # NOTE: the dropout probability is copied from the float module; the
        # observed module is put into eval mode below, which disables dropout.
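        # The observed module mirrors the float module's configuration; the
        # `bias` and `add_bias_kv` flags are inferred from which parameters
        # exist on the float module.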
        observed = cls(
            other.embed_dim,
            other.num_heads,
            other.dropout,
            (other.in_proj_bias is not None),
            (other.bias_k is not None),
            other.add_zero_attn,
            other.kdim,
            other.vdim,
            other.batch_first,
        )
        observed.bias_k = other.bias_k
        observed.bias_v = other.bias_v
        observed.qconfig = other.qconfig

        # Set the linear weights
        # for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969
        observed.out_proj.weight = other.out_proj.weight  # type: ignore[has-type]
        observed.out_proj.bias = other.out_proj.bias  # type: ignore[has-type]
        if other._qkv_same_embed_dim:
            # Use separate params
            bias = other.in_proj_bias
            _start = 0
            _end = _start + other.embed_dim
            weight = other.in_proj_weight[_start:_end, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
            observed.linear_Q.weight = torch.nn.Parameter(weight, weight.requires_grad)
            observed.linear_Q.bias = bias

            bias = other.in_proj_bias
            _start = _end
            _end = _start + other.embed_dim
            weight = other.in_proj_weight[_start:_end, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
            observed.linear_K.weight = torch.nn.Parameter(weight, weight.requires_grad)
            observed.linear_K.bias = bias

            bias = other.in_proj_bias
            _start = _end
            weight = other.in_proj_weight[_start:, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:], bias.requires_grad)
            observed.linear_V.weight = torch.nn.Parameter(weight, weight.requires_grad)
            observed.linear_V.bias = bias
        else:
            observed.linear_Q.weight = nn.Parameter(other.q_proj_weight)
            observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
            observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
            if other.in_proj_bias is None:
                observed.linear_Q.bias = None  # type: ignore[assignment]
                observed.linear_K.bias = None  # type: ignore[assignment]
                observed.linear_V.bias = None  # type: ignore[assignment]
            else:
                observed.linear_Q.bias = nn.Parameter(
                    other.in_proj_bias[0 : other.embed_dim]
                )
                observed.linear_K.bias = nn.Parameter(
                    other.in_proj_bias[other.embed_dim : (other.embed_dim * 2)]
                )
                observed.linear_V.bias = nn.Parameter(
                    other.in_proj_bias[(other.embed_dim * 2) :]
                )
        observed.eval()
        # Explicit prepare
        observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed

    @torch.jit.unused
    def dequantize(self):
        r"""Utility to convert the quantized MHA back to float.

        The motivation for this is that it is not trivial to convert the
        weights from the format used in the quantized version back to the
        float one.
        """
        fp = self._FLOAT_MODULE(
            self.embed_dim,
            self.num_heads,
            self.dropout,
            (self.linear_Q._weight_bias()[1] is not None),
            (self.bias_k is not None),
            self.add_zero_attn,
            self.kdim,
            self.vdim,
            self.batch_first,
        )
        assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim
        if self.bias_k is not None:
            fp.bias_k = nn.Parameter(self.bias_k.dequantize())
        if self.bias_v is not None:
            fp.bias_v = nn.Parameter(self.bias_v.dequantize())

        # Set the linear weights
        # Note: Because the linear layers are quantized, mypy does not know how
        # to deal with them -- might need to ignore the typing checks.
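        # `_weight_bias()` on a quantized Linear returns its (weight, bias)
        # pair; the weights below are dequantized back to float tensors.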
        # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
        w, b = self.out_proj._weight_bias()  # type: ignore[operator, has-type]
        fp.out_proj.weight = nn.Parameter(w.dequantize())
        if b is not None:
            fp.out_proj.bias = nn.Parameter(b)

        wQ, bQ = self.linear_Q._weight_bias()  # type: ignore[operator]
        wQ = wQ.dequantize()
        wK, bK = self.linear_K._weight_bias()  # type: ignore[operator]
        wK = wK.dequantize()
        wV, bV = self.linear_V._weight_bias()  # type: ignore[operator]
        wV = wV.dequantize()
        if fp._qkv_same_embed_dim:
            # Use separate params
            _start = 0
            _end = _start + fp.embed_dim
            fp.in_proj_weight[_start:_end, :] = wQ
            if fp.in_proj_bias is not None:
                assert all(bQ == 0)
                fp.in_proj_bias[_start:_end] = bQ

            _start = _end
            _end = _start + fp.embed_dim
            fp.in_proj_weight[_start:_end, :] = wK
            if fp.in_proj_bias is not None:
                assert all(bK == 0)
                fp.in_proj_bias[_start:_end] = bK

            _start = _end
            fp.in_proj_weight[_start:, :] = wV
            if fp.in_proj_bias is not None:
                assert all(bV == 0)
                fp.in_proj_bias[_start:] = bV
        else:
            fp.q_proj_weight = nn.Parameter(wQ)
            fp.k_proj_weight = nn.Parameter(wK)
            fp.v_proj_weight = nn.Parameter(wV)
            if fp.in_proj_bias is None:
                self.linear_Q.bias = None
                self.linear_K.bias = None
                self.linear_V.bias = None
            else:
                fp.in_proj_bias[0 : fp.embed_dim] = bQ
                fp.in_proj_bias[fp.embed_dim : (fp.embed_dim * 2)] = bK
                fp.in_proj_bias[(fp.embed_dim * 2) :] = bV

        return fp

    @classmethod
    def from_observed(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does float -> observed only
        # See nn.quantized.MultiheadAttention
        raise NotImplementedError(
            "It looks like you are trying to prepare an "
            "MHA module. Please see "
            "the examples on quantizable MHAs."
        )

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        is_causal: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Note::
            Please refer to :func:`~torch.nn.MultiheadAttention.forward` for more
            information.

        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When a binary mask is given, the positions
                with the value ``True`` will be ignored by the attention layer.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast
                across all the batches, while a 3D mask allows specifying a different mask for each entry in the batch.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensures that position i is allowed to attend to the unmasked
              positions. If a BoolTensor is provided, positions with ``True``
              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.
            - is_causal: If specified, applies a causal mask as the attention mask. Mutually exclusive with providing attn_mask.
              Default: ``False``.
            - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
              heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
              effect when ``need_weights=True``. Default: True (i.e. average weights across heads)

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged
              across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length,
              S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(N, num_heads, L, S)`.
        """
        return self._forward_impl(
            query,
            key,
            value,
            key_padding_mask,
            need_weights,
            attn_mask,
            average_attn_weights,
            is_causal,
        )

    def _forward_impl(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        is_causal: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # This version will not deal with the static key/value pairs.
        # Keeping it here for future changes.
        #
        # TODO: This method has some duplicate lines with
        # `torch.nn.functional.multi_head_attention_forward`. Will need to refactor.
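        # Overall flow: project Q, K, and V through the explicit Linear layers,
        # scale Q, dequantize Q/K/V to compute the attention scores and softmax
        # in float, then re-quantize the attention output before the output
        # projection (out_proj) and the returned attention weights.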
        static_k = None
        static_v = None

        if attn_mask is not None and is_causal:
            raise AssertionError("Only allow causal mask or attn_mask")

        if is_causal:
            raise AssertionError("causal mask not supported by AO MHA module")

        if self.batch_first:
            query, key, value = (x.transpose(0, 1) for x in (query, key, value))

        tgt_len, bsz, embed_dim_to_check = query.size()
        assert self.embed_dim == embed_dim_to_check
        # allow MHA to have different sizes for the feature dimension
        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

        head_dim = self.embed_dim // self.num_heads
        assert (
            head_dim * self.num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        scaling = float(head_dim) ** -0.5

        q = self.linear_Q(query)
        k = self.linear_K(key)
        v = self.linear_V(value)

        q = self.q_scaling_product.mul_scalar(q, scaling)

        if attn_mask is not None:
            if attn_mask.dtype == torch.uint8:
                warnings.warn(
                    "Byte tensor for `attn_mask` in `nn.MultiheadAttention` is deprecated. "
                    "Use bool tensor instead.",
                    stacklevel=3,
                )
                attn_mask = attn_mask.to(torch.bool)
            assert (
                attn_mask.is_floating_point() or attn_mask.dtype == torch.bool
            ), f"Only float and bool types are supported for attn_mask, not {attn_mask.dtype}"

            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [
                    bsz * self.num_heads,
                    query.size(0),
                    key.size(0),
                ]:
                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
            else:
                raise RuntimeError(
                    f"attn_mask's dimension {attn_mask.dim()} is not supported"
                )
            # attn_mask's dim is 3 now.

        # convert ByteTensor key_padding_mask to bool
        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
            warnings.warn(
                "Byte tensor for `key_padding_mask` in `nn.MultiheadAttention` is deprecated. "
                "Use bool tensor instead.",
                stacklevel=3,
            )
            key_padding_mask = key_padding_mask.to(torch.bool)
        if self.bias_k is not None and self.bias_v is not None:
            if static_k is None and static_v is None:
                # Explicitly assert that bias_k and bias_v are not None
                # in a way that TorchScript can understand.
                bias_k = self.bias_k
                assert bias_k is not None
                bias_v = self.bias_v
                assert bias_v is not None

                k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
                v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
                if attn_mask is not None:
                    attn_mask = F.pad(attn_mask, (0, 1))
                if key_padding_mask is not None:
                    key_padding_mask = F.pad(key_padding_mask, (0, 1))
            else:
                assert static_k is None, "bias cannot be added to static key."
                assert static_v is None, "bias cannot be added to static value."
        else:
            assert self.bias_k is None
            assert self.bias_v is None

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)

        if static_k is not None:
            assert static_k.size(0) == bsz * self.num_heads
            assert static_k.size(2) == head_dim
            k = static_k

        if static_v is not None:
            assert static_v.size(0) == bsz * self.num_heads
            assert static_v.size(2) == head_dim
            v = static_v

        src_len = k.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
            if k.is_quantized:
                k_zeros = torch.quantize_per_tensor(
                    k_zeros, k.q_scale(), k.q_zero_point(), k.dtype
                )
            k = torch.cat([k, k_zeros], dim=1)
            v_zeros = torch.zeros((v.size(0), 1) + v.size()[2:])
            if v.is_quantized:
                v_zeros = torch.quantize_per_tensor(
                    v_zeros, v.q_scale(), v.q_zero_point(), v.dtype
                )
            v = torch.cat([v, v_zeros], dim=1)

            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = F.pad(key_padding_mask, (0, 1))

        # Leaving the quantized zone here
        q = self.dequant_q(q)
        k = self.dequant_k(k)
        v = self.dequant_v(v)
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [
            bsz * self.num_heads,
            tgt_len,
            src_len,
        ]

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights.masked_fill_(attn_mask, float("-inf"))
            else:
                attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(
                bsz, self.num_heads, tgt_len, src_len
            )
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float("-inf"),
            )
            attn_output_weights = attn_output_weights.view(
                bsz * self.num_heads, tgt_len, src_len
            )

        attn_output_weights = F.softmax(attn_output_weights, dim=-1)
        attn_output_weights = F.dropout(
            attn_output_weights, p=self.dropout, training=self.training
        )

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
        if self.batch_first:
            attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
        else:
            attn_output = (
                attn_output.transpose(0, 1)
                .contiguous()
                .view(tgt_len, bsz, self.embed_dim)
            )

        # Reentering the quantized zone
        attn_output = self.quant_attn_output(attn_output)
        # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
        attn_output = self.out_proj(attn_output)  # type: ignore[has-type]
        attn_output_weights = self.quant_attn_output_weights(attn_output_weights)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(
                bsz, self.num_heads, tgt_len, src_len
            )
            if average_attn_weights:
                attn_output_weights = attn_output_weights.mean(dim=1)
            return attn_output, attn_output_weights
        else:
            return attn_output, None