# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
from torch import nn


class BertSelfAttention(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        attention_probs_dropout_prob,
        position_embedding_type=None,
        max_position_embeddings=None,
    ):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({hidden_size}) is not a multiple of the number of attention "
                f"heads ({num_attention_heads})"
            )

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type
        if self.position_embedding_type is not None:
            assert max_position_embeddings is not None
            self.max_position_embeddings = max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * max_position_embeddings - 1, self.attention_head_size
            )

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        past_key_value=None,
    ):
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)

        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)

        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(q, k.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if self.position_embedding_type is not None:
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            position_ids_r = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1
            )
            positional_embedding = positional_embedding.to(
                dtype=q.dtype
            )  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum(
                    "bhld,lrd->bhlr", q, positional_embedding
                )
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum(
                    "bhld,lrd->bhlr", q, positional_embedding
                )
                relative_position_scores_key = torch.einsum(
                    "bhrd,lrd->bhlr", k, positional_embedding
                )
                attention_scores = (
                    attention_scores
                    + relative_position_scores_query
                    + relative_position_scores_key
                )

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, v)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer
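

if __name__ == "__main__":
    # Minimal usage sketch (illustration only, not part of the original module).
    # The hidden size, head count, batch size, and sequence length below are
    # assumed example values, not values prescribed by this file.
    layer = BertSelfAttention(
        hidden_size=768,
        num_attention_heads=12,
        attention_probs_dropout_prob=0.1,
        position_embedding_type="relative_key",
        max_position_embeddings=512,
    )
    hidden_states = torch.randn(2, 16, 768)  # (batch, seq_len, hidden_size)
    context = layer(hidden_states)
    # The output keeps the input shape: (batch, seq_len, hidden_size).
    print(context.shape)  # torch.Size([2, 16, 768])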