# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math

import torch
from functorch.dim import cat, dimlists, dims, softmax
from torch import nn


class Linear(nn.Linear):
    def forward(self, input):
        ci, co = dims()
        b = dimlists()
        result = (input[b, ci] * self.weight[co, ci]).sum(ci) + self.bias[co]
        return result.order(b, co)


class BertSelfAttention(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        attention_probs_dropout_prob,
        position_embedding_type=None,
        max_position_embeddings=None,
        linear=Linear,
    ):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({hidden_size}) is not a multiple of the number of attention "
                f"heads ({num_attention_heads})"
            )

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = linear(hidden_size, self.all_head_size)
        self.key = linear(hidden_size, self.all_head_size)
        self.value = linear(hidden_size, self.all_head_size)

        self.dropout_prob = attention_probs_dropout_prob
        self.position_embedding_type = position_embedding_type

        if self.position_embedding_type is not None:
            assert max_position_embeddings is not None
            self.max_position_embeddings = max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * max_position_embeddings - 1, self.attention_head_size
            )

    def forward(
        self,
        hidden_states,
        past_key_value=None,
    ):
        # first run the encoding linear layers for q, k, v normally
        # the meaning of a linear layer is well understood, so there is no need to use explicit dimensions
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)

        # introduce values that represent each dimension. dimensions are 'first class'
        # because they are actual python values introduced here
        batch, query_sequence, key_sequence, heads, features = dims()
        heads.size = self.num_attention_heads

        # bind the positional dimensions in q, k, and v against
        # our values. the sizes of each dimension are determined by this binding,
        # and when a dimension is used twice (e.g. batch), its size across both
        # uses is checked for consistency.
        # The group (heads, features) splits apart a single positional dimension
        # into two dimensions. Since heads.size*features.size == q.size(2)
        # and we specified heads.size, features.size is inferred here.
        q = q[batch, query_sequence, [heads, features]]
        k = k[batch, key_sequence, [heads, features]]
        v = v[batch, key_sequence, [heads, features]]

        # this option allows the model to attend not just to the elements of the current
        # sequence but also to previously cached key/value tokens passed in as past_key_value.
        if past_key_value is not None:
            extended_key_sequence = dims()
            key_past = past_key_value[0][batch, heads, key_sequence, features]
            value_past = past_key_value[1][batch, heads, key_sequence, features]
            # cat introduces a new dimension, extended_key_sequence, because the
            # concatenated result is longer than the original key_sequence
            k = cat([key_past, k], key_sequence, extended_key_sequence)
            v = cat([value_past, v], key_sequence, extended_key_sequence)
            # for the rest of the function, we will just use extended_key_sequence in lieu of
            # key_sequence
            key_sequence = extended_key_sequence

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # The actual outer-product and summation are explicitly represented here,
        # and like einsum, will be pattern matched to an efficient matrix multiply op.
        attention_scores = (q * k).sum(features) / math.sqrt(features.size)

        # relative positional embeddings give a unique embedding based on the distance between
        # query and key tokens in the sequence, e.g.
        #  0  1  2  3
        # -1  0  1  2
        # -2 -1  0  1
        # -3 -2 -1  0
        if self.position_embedding_type is not None:
            # the value of a dimension object when used as a tensor is the indices along its dimension,
            # so we can directly subtract the two dimensions to get a 2D tensor of (query_sequence x key_sequence)
            # with the distance between them
            distance = query_sequence - key_sequence

            assert key_sequence.size <= self.max_position_embeddings

            # we can then use that as an indirect index into the embedding table to look up the features for each distance.
            # this is just a `gather` primitive op. The resulting tensor will
            # have all the dimensions of the index (query_sequence x key_sequence),
            # plus all the dimensions of the embedding table that were not indirectly accessed (features).
            # this form of indirect indexing is more straightforward than either advanced indexing or torch.gather, which both
            # have a lot of dependencies on the positions of indexing tensors.
            positional_embedding = self.distance_embedding.weight[
                self.max_position_embeddings - 1 + distance, features
            ]

            if self.position_embedding_type == "relative_key":
                # these were einsum ops in the original positional code because they are not easy to fit to existing matmul operators,
                # even though they are degenerate matmuls
                relative_position_scores = (q * positional_embedding).sum(features)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = (q * positional_embedding).sum(
                    features
                )
                relative_position_scores_key = (k * positional_embedding).sum(features)
                attention_scores = (
                    attention_scores
                    + relative_position_scores_query
                    + relative_position_scores_key
                )

        # Normalize the attention scores to probabilities.
        attention_probs = softmax(attention_scores, dim=key_sequence)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = torch.nn.functional.dropout(
            attention_probs, p=self.dropout_prob
        )

        # similarly, we can replace the matmul with a direct listing of the outer product, which makes it clear
        # we are weighting the values v across all keys with the attention scores.
        context_layer = (attention_probs * v).sum(key_sequence)

        # finally, we convert back to a standard tensor by describing the layout of dimensions.
        # working in reverse to the binding above, the (heads, features) group flattens those
        # dimensions back into a single one.
        return context_layer.order(batch, query_sequence, [heads, features])
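

# A minimal usage sketch, assuming made-up hyperparameters and random inputs
# (hidden_size=64, 4 heads, a sequence of 8 tokens, and a cache of 4 past tokens),
# and assuming the functorch.dim package imported above is available.
# It only demonstrates how to invoke BertSelfAttention and the shape of its output.
if __name__ == "__main__":
    torch.manual_seed(0)
    hidden_size, num_heads, seq_len, batch_size = 64, 4, 8, 2

    attention = BertSelfAttention(
        hidden_size=hidden_size,
        num_attention_heads=num_heads,
        attention_probs_dropout_prob=0.1,
        position_embedding_type="relative_key",
        max_position_embeddings=32,
    )

    hidden_states = torch.randn(batch_size, seq_len, hidden_size)
    output = attention(hidden_states)
    print(output.shape)  # torch.Size([2, 8, 64])

    # past_key_value holds cached keys and values laid out as
    # (batch, heads, past_sequence, head_features), matching the binding in forward()
    head_size = hidden_size // num_heads
    past = (
        torch.randn(batch_size, num_heads, 4, head_size),
        torch.randn(batch_size, num_heads, 4, head_size),
    )
    output = attention(hidden_states, past_key_value=past)
    print(output.shape)  # torch.Size([2, 8, 64])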