# This is a copy of rnn_attention from MLPerf, with some common sizes hardcoded
# for benchmarking and some control flow stripped out.
# https://github.com/mlperf/training/blob/master/rnn_translator/pytorch/seq2seq/models/attention.py

import torch

from . import benchmark


class BahdanauAttention(benchmark.Benchmark):
    """Benchmark of the Bahdanau (additive) attention score computation.

    The benchmark is parameterized by ``b``, ``t_q``, ``t_k`` and ``n``, which
    are used as the dimensions of the query tensor ``[b, t_q, n]`` and the key
    tensor ``[b, t_k, n]``.
    """

    def __init__(self, mode, device, dtype, b, t_q, t_k, n):
        """
        Create the benchmark inputs.

        :param mode: forwarded unchanged to ``benchmark.Benchmark.__init__``
        :param device: device on which the input tensors are created
        :param dtype: dtype of the input tensors
        :param b: first dimension of the query and key tensors
        :param t_q: second dimension of the query tensor
        :param t_k: second dimension of the key tensor
        :param n: last dimension of the query/key tensors and the length of
            the ``normalize_bias`` and ``linear_att`` vectors
        """
        super().__init__(mode, device, dtype)
        self.b = b
        self.t_q = t_q
        self.t_k = t_k
        self.n = n
        # NOTE(review): `self.rand` and `self.requires_grad` are not defined in
        # this file; presumably they come from the Benchmark base class.
        self.att_query = self.rand(
            [b, t_q, n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.att_keys = self.rand(
            [b, t_k, n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.normalize_bias = self.rand(
            [n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.linear_att = self.rand(
            [n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        # Order matches the positional parameters of `forward`.
        self.inputs = [
            self.att_query,
            self.att_keys,
            self.normalize_bias,
            self.linear_att,
        ]

    def forward(self, att_query, att_keys, normalize_bias, linear_att):
        """
        Calculate Bahdanau score

        :param att_query: b x t_q x n
        :param att_keys: b x t_k x n
        :param normalize_bias: vector of length n, added to every query/key sum
        :param linear_att: vector of length n, contracted against the last
            dimension of the tanh output

        return b x t_q x t_k scores
        """

        b, t_k, n = att_keys.size()
        t_q = att_query.size(1)

        # Expand both operands to a common b x t_q x t_k x n shape so that
        # every query position is paired with every key position.
        att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
        att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
        sum_qk = att_query + att_keys + normalize_bias
        # matmul with a 1-D tensor reduces the trailing n dimension.
        out = torch.tanh(sum_qk).matmul(linear_att)
        return out

    def reference(self):
        """Return `forward` applied to the stored inputs, passed through `self.numpy`."""
        return self.numpy(self.forward(*self.inputs))

    def config(self):
        """Return the size parameters as ``[b, t_q, t_k, n]``."""
        return [self.b, self.t_q, self.t_k, self.n]

    @staticmethod
    def module():
        """Return the name under which this benchmark is identified."""
        return "attention"

    def memory_workload(self):
        """
        Estimate the memory traffic of one `forward` call, in bytes.

        return dict with two entries:
            "sol": bytes of the four inputs plus the b x t_q x t_k output
            "algorithmic": "sol" plus one write and one read of the
                b x t_q x t_k x n intermediate `sum_qk`
        """

        def memsize(t):
            return t.numel() * t.element_size()

        # Use the tensors' real element width rather than assuming 4 bytes, so
        # the output and intermediate sizes agree with how the inputs are
        # measured for every dtype. The output of `forward` has the same dtype
        # as its inputs, which all share one dtype.
        elem_bytes = self.att_query.element_size()

        input_size = (
            memsize(self.att_query)
            + memsize(self.att_keys)
            + memsize(self.normalize_bias)
            + memsize(self.linear_att)
        )
        output_size = elem_bytes * torch.Size([self.b, self.t_q, self.t_k]).numel()
        io_size = input_size + output_size

        # If matmul is not fused, must write and then read `sum_qk`.
        intermediate_size = (
            2
            * elem_bytes
            * torch.Size([self.b, self.t_q, self.t_k, self.n]).numel()
        )
        return {"sol": io_size, "algorithmic": io_size + intermediate_size}

    @staticmethod
    def default_configs():
        """Return the default ``[b, t_q, t_k, n]`` size configurations."""
        mlperf_inference = [1280, 1, 66, 1024]
        nvidia = [128, 10, 128, 1024]
        return [mlperf_inference, nvidia]


benchmark.register_benchmark_class(BahdanauAttention)