# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math

import torch
from functorch.dim import cat, dimlists, dims, softmax
from torch import nn


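# A quick, illustrative sketch of the first-class dimensions API from functorch.dim
# used throughout this file (the tensors and sizes here are made up):
#
#     i, j = dims()            # create dimension objects
#     x = torch.randn(3, 4)
#     xi = x[i, j]             # bind x's positional dims to i and j (i.size == 3, j.size == 4)
#     y = (xi * xi).sum(j)     # reductions refer to the dimension object directly
#     y.order(i)               # convert back to a positional tensor of shape (3,)
#
# The Linear layer below and BertSelfAttention are written in this style.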
class Linear(nn.Linear):
    def forward(self, input):
        # ci: the input-feature dim, co: the output-feature dim,
        # b: a dimlist that captures however many leading batch dims the input has
        ci, co = dims()
        b = dimlists()
        result = (input[b, ci] * self.weight[co, ci]).sum(ci) + self.bias[co]
        return result.order(b, co)

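# For reference, Linear.forward above computes the same thing as the stock
# nn.Linear forward, i.e. (a sketch, not part of the original file):
#
#     torch.nn.functional.linear(input, self.weight, self.bias)
#
# the first-class-dims version just spells out the contraction over the input
# feature dim ci explicitly.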

class BertSelfAttention(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        attention_probs_dropout_prob,
        position_embedding_type=None,
        max_position_embeddings=None,
        linear=Linear,
    ):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({hidden_size}) is not a multiple of the number of attention "
                f"heads ({num_attention_heads})"
            )

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = linear(hidden_size, self.all_head_size)
        self.key = linear(hidden_size, self.all_head_size)
        self.value = linear(hidden_size, self.all_head_size)

        self.dropout_prob = attention_probs_dropout_prob
        self.position_embedding_type = position_embedding_type

        if self.position_embedding_type is not None:
            assert max_position_embeddings is not None
            self.max_position_embeddings = max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * max_position_embeddings - 1, self.attention_head_size
            )

    def forward(
        self,
        hidden_states,
        past_key_value=None,
    ):
        # first run the encoding linear layers for q, k, v normally
        # the meaning of a linear layer is well understood, so no need to use explicit dimensions
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)

        # introduce values that represent each dimension. dimensions are 'first class'
        # because they are actual python values introduced here
        batch, query_sequence, key_sequence, heads, features = dims()
        heads.size = self.num_attention_heads

        # bind the positional dimensions in k, q, and v against
        # our values. the sizes of each dimension are determined by this binding
        # and when a dimension is used twice (e.g. batch), its size against both
        # uses is checked for consistency.
        # The group (heads, features) splits apart a single positional dimension
        # into two dimensions. Since heads.size*features.size == q.size(2)
        # and we specified heads.size, features.size is inferred here.
        q = q[batch, query_sequence, [heads, features]]
        k = k[batch, key_sequence, [heads, features]]
        v = v[batch, key_sequence, [heads, features]]
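        # a concrete (made-up) example of the binding above: if q comes back as a
        # positional tensor of shape (2, 5, 64) and self.num_attention_heads == 8,
        # then batch.size == 2, query_sequence.size == 5, heads.size == 8, and
        # features.size is inferred as 64 // 8 == 8.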

        # this option allows the model to attend not just to the elements of the
        # current sequence but also to previously cached key/value tokens.
        if past_key_value is not None:
            extended_key_sequence = dims()
            key_past = past_key_value[0][batch, heads, key_sequence, features]
            value_past = past_key_value[1][batch, heads, key_sequence, features]
            # cat introduces a new dimension, extended_key_sequence, because the result
            # is twice as long as the original key_sequence
            k = cat([key_past, k], key_sequence, extended_key_sequence)
            v = cat([value_past, v], key_sequence, extended_key_sequence)
            # for the rest of the function, we will just use extended_key_sequence in lieu of
            # key_sequence
            key_sequence = extended_key_sequence
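        # size bookkeeping for the branch above (a sketch): because key_past and k are
        # bound to the same key_sequence dimension, their lengths must agree, so
        # extended_key_sequence.size == 2 * key_sequence.size after the two cats.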

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # The actual outer-product and summation are explicitly represented here,
        # and like einsum, will be pattern matched to an efficient matrix multiply op.
        attention_scores = (q * k).sum(features) / math.sqrt(features.size)
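        # for comparison (a sketch, not code from this file): with ordinary positional
        # tensors q and k of shape (batch, seq, heads, features) this is
        #   torch.einsum("bqhf,bkhf->bqkh", q, k) / math.sqrt(head_size)
        # here the contraction over `features` is just written out directly.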

        # relative positional embeddings give a unique embedding based on the distance between
        # the query and key tokens in the sequence, e.g.
        #  0  1  2  3
        # -1  0  1  2
        # -2 -1  0  1
        # -3 -2 -1  0
        if self.position_embedding_type is not None:
            # the value of a dimension object when used as a tensor is the indices along its dimension
            # so we can directly subtract the two dimensions to get a 2D tensor of (query_sequence x key_sequence)
            # with the distance between them
            distance = query_sequence - key_sequence

            assert key_sequence.size <= self.max_position_embeddings

            # we can then use that as an indirect index into the embedding table values to look up the features for that index
            # this is just a `gather` primitive op. The resulting tensor will
            # have all the dimensions of the indexing expression (query_sequence x key_sequence),
            # plus the dimensions of the embedding weight that were not indirectly accessed (here, `features`).
            # this form of indirect indexing is more straightforward than either advanced indexing or torch.gather, which both
            # have a lot of dependencies on the positions of indexing tensors.

            positional_embedding = self.distance_embedding.weight[
                self.max_position_embeddings - 1 + distance, features
            ]
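            # shape sketch of the indexing above: the index expression
            # (self.max_position_embeddings - 1 + distance) carries the query_sequence
            # and key_sequence dims, and `features` binds the embedding width, so
            # positional_embedding has dims (query_sequence, key_sequence, features).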

            if self.position_embedding_type == "relative_key":
                # these were einsum ops in the original positional-embedding code because they are
                # not easy to fit to existing matmul operators, even though they are degenerate matmuls
                relative_position_scores = (q * positional_embedding).sum(features)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = (q * positional_embedding).sum(
                    features
                )
                relative_position_scores_key = (k * positional_embedding).sum(features)
                attention_scores = (
                    attention_scores
                    + relative_position_scores_query
                    + relative_position_scores_key
                )
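            # for comparison (a sketch): with positional tensors each of these terms is one of
            # the "degenerate matmuls" mentioned above, roughly
            #   torch.einsum("bqhf,qkf->bqkh", q, positional_embedding)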

        # Normalize the attention scores to probabilities.
        attention_probs = softmax(attention_scores, dim=key_sequence)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = torch.nn.functional.dropout(
            attention_probs, p=self.dropout_prob
        )

        # similarly, we can replace the matmul with a direct listing of the outer product, which makes it clear
        # we are weighting the values v across all keys with the attention scores.
        context_layer = (attention_probs * v).sum(key_sequence)
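        # for comparison (a sketch): with positional tensors this weighted sum is
        #   torch.einsum("bqkh,bkhf->bqhf", attention_probs, v)
        # i.e. a per-head matmul of the attention probabilities with the values.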

        # finally, we convert back to a standard tensor by describing the layout of dimensions.
        # working in reverse of the dimension binding at the start of forward, the
        # (heads, features) group flattens the two dimensions back into a single one.
        return context_layer.order(batch, query_sequence, [heads, features])
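

# Minimal smoke-test sketch (not part of the original test; the sizes below are
# arbitrary assumptions, chosen only so hidden_size divides evenly by the head count):
if __name__ == "__main__":
    attn = BertSelfAttention(
        hidden_size=64,
        num_attention_heads=8,
        attention_probs_dropout_prob=0.1,
    )
    hidden_states = torch.randn(2, 5, 64)  # (batch, sequence, hidden)
    out = attn(hidden_states)
    print(out.shape)  # expected: torch.Size([2, 5, 64])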