# mypy: allow-untyped-defs
import warnings
from typing import Optional, Tuple

import torch
import torch.jit  # this is needed to avoid a circular import
import torch.nn.functional as F
from torch import nn, Tensor


__all__ = ["MultiheadAttention"]


class MultiheadAttention(nn.MultiheadAttention):
    _FLOAT_MODULE = nn.MultiheadAttention

    r"""Quantizable implementation of the MultiheadAttention.

    Note::
        Please refer to :class:`~torch.nn.MultiheadAttention` for more
        information.

    Allows the model to jointly attend to information from different
    representation subspaces.
    See reference: Attention Is All You Need

    The original MHA module is not quantizable.
    This reimplements it by explicitly instantiating the linear layers.

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).

    Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
    to :attr:`embed_dim` such that query, key, and value have the same
    number of features.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)

    Note::
        Please follow the quantization flow to convert the quantizable MHA.
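
        A minimal eager-mode sketch (the qconfig choice and the calibration
        inputs are placeholders, and the last step assumes the quantized
        counterpart in ``torch.ao.nn.quantized``):

        >>> float_mha = nn.MultiheadAttention(embed_dim, num_heads)
        >>> float_mha.qconfig = torch.ao.quantization.default_qconfig
        >>> observed_mha = nnqa.MultiheadAttention.from_float(float_mha)
        >>> _ = observed_mha(query, key, value)  # calibration pass
        >>> quantized_mha = torch.ao.nn.quantized.MultiheadAttention.from_observed(
        ...     observed_mha
        ... )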
    """
    __constants__ = ["batch_first"]

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
        batch_first: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            embed_dim,
            num_heads,
            dropout,
            bias,
            add_bias_kv,
            add_zero_attn,
            kdim,
            vdim,
            batch_first,
            **factory_kwargs,
        )
        self.linear_Q = nn.Linear(
            self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs
        )
        self.linear_K = nn.Linear(
            self.kdim, self.embed_dim, bias=bias, **factory_kwargs
        )
        self.linear_V = nn.Linear(
            self.vdim, self.embed_dim, bias=bias, **factory_kwargs
        )
        # for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)  # type: ignore[assignment]

        # Functionals
        self.q_scaling_product = torch.ao.nn.quantized.FloatFunctional()
        # note: importing torch.ao.nn.quantized at top creates a circular import

        # Quant/Dequant
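        # Note: the attention core (bmm, masking, softmax) is computed in float
        # in the forward pass. dequant_q/k/v mark where q/k/v leave the
        # quantized domain, and quant_attn_output / quant_attn_output_weights
        # mark where the outputs re-enter it.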
        self.quant_attn_output = torch.ao.quantization.QuantStub()
        self.quant_attn_output_weights = torch.ao.quantization.QuantStub()
        self.dequant_q = torch.ao.quantization.DeQuantStub()
        self.dequant_k = torch.ao.quantization.DeQuantStub()
        self.dequant_v = torch.ao.quantization.DeQuantStub()

    def _get_name(self):
        return "QuantizableMultiheadAttention"

    @classmethod
    def from_float(cls, other):
        assert type(other) == cls._FLOAT_MODULE
        assert hasattr(other, "qconfig"), "The float module must have 'qconfig'"
        # Setting the dropout to 0.0!
        observed = cls(
            other.embed_dim,
            other.num_heads,
            other.dropout,
            (other.in_proj_bias is not None),
            (other.bias_k is not None),
            other.add_zero_attn,
            other.kdim,
            other.vdim,
            other.batch_first,
        )
        observed.bias_k = other.bias_k
        observed.bias_v = other.bias_v
        observed.qconfig = other.qconfig

        # Set the linear weights
        # for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969
        observed.out_proj.weight = other.out_proj.weight  # type: ignore[has-type]
        observed.out_proj.bias = other.out_proj.bias  # type: ignore[has-type]
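        # When query/key/value share embed_dim, the float module packs the Q/K/V
        # projections into a single in_proj_weight of shape
        # (3 * embed_dim, embed_dim) (plus a matching in_proj_bias); slice them
        # into the separate linear layers below.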
        if other._qkv_same_embed_dim:
            # Use separate params
            bias = other.in_proj_bias
            _start = 0
            _end = _start + other.embed_dim
            weight = other.in_proj_weight[_start:_end, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
            observed.linear_Q.weight = torch.nn.Parameter(weight, weight.requires_grad)
            observed.linear_Q.bias = bias

            bias = other.in_proj_bias
            _start = _end
            _end = _start + other.embed_dim
            weight = other.in_proj_weight[_start:_end, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
            observed.linear_K.weight = torch.nn.Parameter(weight, weight.requires_grad)
            observed.linear_K.bias = bias

            bias = other.in_proj_bias
            _start = _end
            weight = other.in_proj_weight[_start:, :]
            if bias is not None:
                bias = torch.nn.Parameter(bias[_start:], bias.requires_grad)
            observed.linear_V.weight = torch.nn.Parameter(weight, weight.requires_grad)
            observed.linear_V.bias = bias
        else:
            observed.linear_Q.weight = nn.Parameter(other.q_proj_weight)
            observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
            observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
            if other.in_proj_bias is None:
                observed.linear_Q.bias = None  # type: ignore[assignment]
                observed.linear_K.bias = None  # type: ignore[assignment]
                observed.linear_V.bias = None  # type: ignore[assignment]
            else:
                observed.linear_Q.bias = nn.Parameter(
                    other.in_proj_bias[0 : other.embed_dim]
                )
                observed.linear_K.bias = nn.Parameter(
                    other.in_proj_bias[other.embed_dim : (other.embed_dim * 2)]
                )
                observed.linear_V.bias = nn.Parameter(
                    other.in_proj_bias[(other.embed_dim * 2) :]
                )
        observed.eval()
        # Explicit prepare
        observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed

    @torch.jit.unused
    def dequantize(self):
        r"""Utility to convert the quantized MHA back to float.

        The motivation for this is that it is not trivial to convert the weights
        from the format that is used in the quantized version back to the
        float format.
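
        Example (sketch; ``quantized_mha`` stands for an already converted
        quantized MHA module)::

            >>> float_mha = quantized_mha.dequantize()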
        """
        fp = self._FLOAT_MODULE(
            self.embed_dim,
            self.num_heads,
            self.dropout,
            (self.linear_Q._weight_bias()[1] is not None),
            (self.bias_k is not None),
            self.add_zero_attn,
            self.kdim,
            self.vdim,
            self.batch_first,
        )
        assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim
        if self.bias_k is not None:
            fp.bias_k = nn.Parameter(self.bias_k.dequantize())
        if self.bias_v is not None:
            fp.bias_v = nn.Parameter(self.bias_v.dequantize())

        # Set the linear weights
        # Note: Because the linear layers are quantized, mypy does not know how
        # to deal with them -- might need to ignore the typing checks.
        # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
        w, b = self.out_proj._weight_bias()  # type: ignore[operator, has-type]
        fp.out_proj.weight = nn.Parameter(w.dequantize())
        if b is not None:
            fp.out_proj.bias = nn.Parameter(b)

        wQ, bQ = self.linear_Q._weight_bias()  # type: ignore[operator]
        wQ = wQ.dequantize()
        wK, bK = self.linear_K._weight_bias()  # type: ignore[operator]
        wK = wK.dequantize()
        wV, bV = self.linear_V._weight_bias()  # type: ignore[operator]
        wV = wV.dequantize()
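        # Copy the dequantized Q/K/V weights (and biases) back into the packed
        # in_proj_weight / in_proj_bias of the float module, or into its separate
        # q/k/v projection weights when the embedding dims differ.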
        if fp._qkv_same_embed_dim:
            # Use separate params
            _start = 0
            _end = _start + fp.embed_dim
            fp.in_proj_weight[_start:_end, :] = wQ
            if fp.in_proj_bias is not None:
                assert all(bQ == 0)
                fp.in_proj_bias[_start:_end] = bQ

            _start = _end
            _end = _start + fp.embed_dim
            fp.in_proj_weight[_start:_end, :] = wK
            if fp.in_proj_bias is not None:
                assert all(bK == 0)
                fp.in_proj_bias[_start:_end] = bK

            _start = _end
            fp.in_proj_weight[_start:, :] = wV
            if fp.in_proj_bias is not None:
                assert all(bV == 0)
                fp.in_proj_bias[_start:] = bV
        else:
            fp.q_proj_weight = nn.Parameter(wQ)
            fp.k_proj_weight = nn.Parameter(wK)
            fp.v_proj_weight = nn.Parameter(wV)
            if fp.in_proj_bias is None:
                self.linear_Q.bias = None
                self.linear_K.bias = None
                self.linear_V.bias = None
            else:
                fp.in_proj_bias[0 : fp.embed_dim] = bQ
                fp.in_proj_bias[fp.embed_dim : (fp.embed_dim * 2)] = bK
                fp.in_proj_bias[(fp.embed_dim * 2) :] = bV

        return fp

    @classmethod
    def from_observed(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does float -> observed only
        # See nn.quantized.MultiheadAttention
        raise NotImplementedError(
            "It looks like you are trying to prepare an "
            "MHA module. Please see "
            "the examples on quantizable MHAs."
        )

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        is_causal: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Note::
            Please refer to :func:`~torch.nn.MultiheadAttention.forward` for more
            information.

        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When a binary mask is given, the positions
                with the value ``True`` will be ignored by the attention layer.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast across
                the batch while a 3D mask allows a different mask for each entry in the batch.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensures that position i is allowed to attend to the unmasked
              positions. If a BoolTensor is provided, positions with ``True``
              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.
            - is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask.
              Default: ``False``.
            - average_attn_weights: If ``True``, indicates that the returned ``attn_weights`` should be averaged across
              heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
              effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged
              across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length,
              S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(N, num_heads, L, S)`.
        """
        return self._forward_impl(
            query,
            key,
            value,
            key_padding_mask,
            need_weights,
            attn_mask,
            average_attn_weights,
            is_causal,
        )

    def _forward_impl(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        is_causal: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # This version will not deal with the static key/value pairs.
        # Keeping it here for future changes.
        #
        # TODO: This method has some duplicate lines with
        # `torch.nn.functional.multi_head_attention_forward`. Will need to refactor.
        static_k = None
        static_v = None

        if attn_mask is not None and is_causal:
            raise AssertionError("Only allow causal mask or attn_mask")

        if is_causal:
            raise AssertionError("causal mask not supported by AO MHA module")

        if self.batch_first:
            query, key, value = (x.transpose(0, 1) for x in (query, key, value))

        tgt_len, bsz, embed_dim_to_check = query.size()
        assert self.embed_dim == embed_dim_to_check
        # allow MHA to have different sizes for the feature dimension
        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

        head_dim = self.embed_dim // self.num_heads
        assert (
            head_dim * self.num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        scaling = float(head_dim) ** -0.5

        q = self.linear_Q(query)
        k = self.linear_K(key)
        v = self.linear_V(value)

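        # Scale the query by 1 / sqrt(head_dim). Using FloatFunctional keeps this
        # multiplication observable so it can be quantized during conversion.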
        q = self.q_scaling_product.mul_scalar(q, scaling)

        if attn_mask is not None:
            if attn_mask.dtype == torch.uint8:
                warnings.warn(
                    "Byte tensor for `attn_mask` in `nn.MultiheadAttention` is deprecated. "
                    "Use bool tensor instead.",
                    stacklevel=3,
                )
                attn_mask = attn_mask.to(torch.bool)
            assert (
                attn_mask.is_floating_point() or attn_mask.dtype == torch.bool
            ), f"Only float and bool types are supported for attn_mask, not {attn_mask.dtype}"

            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [
                    bsz * self.num_heads,
                    query.size(0),
                    key.size(0),
                ]:
                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
            else:
                raise RuntimeError(
                    f"attn_mask's dimension {attn_mask.dim()} is not supported"
                )
            # attn_mask's dim is 3 now.

        # convert ByteTensor key_padding_mask to bool
        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
            warnings.warn(
                "Byte tensor for `key_padding_mask` in `nn.MultiheadAttention` is deprecated. "
                "Use bool tensor instead.",
                stacklevel=3,
            )
            key_padding_mask = key_padding_mask.to(torch.bool)
        if self.bias_k is not None and self.bias_v is not None:
            if static_k is None and static_v is None:
                # Explicitly assert that bias_k and bias_v are not None
                # in a way that TorchScript can understand.
                bias_k = self.bias_k
                assert bias_k is not None
                bias_v = self.bias_v
                assert bias_v is not None

                k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
                v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
                if attn_mask is not None:
                    attn_mask = F.pad(attn_mask, (0, 1))
                if key_padding_mask is not None:
                    key_padding_mask = F.pad(key_padding_mask, (0, 1))
            else:
                assert static_k is None, "bias cannot be added to static key."
                assert static_v is None, "bias cannot be added to static value."
        else:
            assert self.bias_k is None
            assert self.bias_v is None

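        # Reshape (L, N, E) -> (N * num_heads, L, head_dim) so the per-head
        # attention can be computed as a single batched matmul.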
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)

        if static_k is not None:
            assert static_k.size(0) == bsz * self.num_heads
            assert static_k.size(2) == head_dim
            k = static_k

        if static_v is not None:
            assert static_v.size(0) == bsz * self.num_heads
            assert static_v.size(2) == head_dim
            v = static_v

        src_len = k.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

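        # add_zero_attn appends an all-zero time step to the key/value sequences
        # (and pads the masks to match). When k/v are quantized, the zeros are
        # quantized with the same scale/zero_point so that torch.cat stays in the
        # quantized domain.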
        if self.add_zero_attn:
            src_len += 1
            k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
            if k.is_quantized:
                k_zeros = torch.quantize_per_tensor(
                    k_zeros, k.q_scale(), k.q_zero_point(), k.dtype
                )
            k = torch.cat([k, k_zeros], dim=1)
            v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
            if v.is_quantized:
                v_zeros = torch.quantize_per_tensor(
                    v_zeros, v.q_scale(), v.q_zero_point(), v.dtype
                )
            v = torch.cat([v, v_zeros], dim=1)

            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = F.pad(key_padding_mask, (0, 1))

        # Leaving the quantized zone here
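        # (the bmm, masking, softmax, and dropout below are computed in float)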
        q = self.dequant_q(q)
        k = self.dequant_k(k)
        v = self.dequant_v(v)
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [
            bsz * self.num_heads,
            tgt_len,
            src_len,
        ]

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights.masked_fill_(attn_mask, float("-inf"))
            else:
                attn_output_weights += attn_mask

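        # Mask out padded source positions across every head before the softmax.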
        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(
                bsz, self.num_heads, tgt_len, src_len
            )
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float("-inf"),
            )
            attn_output_weights = attn_output_weights.view(
                bsz * self.num_heads, tgt_len, src_len
            )

        attn_output_weights = F.softmax(attn_output_weights, dim=-1)
        attn_output_weights = F.dropout(
            attn_output_weights, p=self.dropout, training=self.training
        )

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
        if self.batch_first:
            attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
        else:
            attn_output = (
                attn_output.transpose(0, 1)
                .contiguous()
                .view(tgt_len, bsz, self.embed_dim)
            )

        # Reentering the quantized zone
        attn_output = self.quant_attn_output(attn_output)
        # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
        attn_output = self.out_proj(attn_output)  # type: ignore[has-type]
        attn_output_weights = self.quant_attn_output_weights(attn_output_weights)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(
                bsz, self.num_heads, tgt_len, src_len
            )
            if average_attn_weights:
                attn_output_weights = attn_output_weights.mean(dim=1)
            return attn_output, attn_output_weights
        else:
            return attn_output, None