# This is a copy of rnn_attention from MLPerf, with some common sizes hardcoded
# for benchmarking and some control flow stripped out.
# https://github.com/mlperf/training/blob/master/rnn_translator/pytorch/seq2seq/models/attention.py

import torch

from . import benchmark


class BahdanauAttention(benchmark.Benchmark):
    def __init__(self, mode, device, dtype, b, t_q, t_k, n):
        super().__init__(mode, device, dtype)
        self.b = b
        self.t_q = t_q
        self.t_k = t_k
        self.n = n
        self.att_query = self.rand(
            [b, t_q, n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.att_keys = self.rand(
            [b, t_k, n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.normalize_bias = self.rand(
            [n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.linear_att = self.rand(
            [n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [
            self.att_query,
            self.att_keys,
            self.normalize_bias,
            self.linear_att,
        ]

    def forward(self, att_query, att_keys, normalize_bias, linear_att):
        """
        Calculate the Bahdanau attention score.

        :param att_query: b x t_q x n
        :param att_keys: b x t_k x n

        :returns: b x t_q x t_k scores
        """

        b, t_k, n = att_keys.size()
        t_q = att_query.size(1)

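        # Additive (Bahdanau) attention: broadcast the query against the keys to
        # [b, t_q, t_k, n], add the bias, apply tanh, and contract the last
        # dimension with `linear_att` to produce [b, t_q, t_k] scores.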
        att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
        att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
        sum_qk = att_query + att_keys + normalize_bias
        out = torch.tanh(sum_qk).matmul(linear_att)
        return out

    def reference(self):
        return self.numpy(self.forward(*self.inputs))

    def config(self):
        return [self.b, self.t_q, self.t_k, self.n]

    @staticmethod
    def module():
        return "attention"

    def memory_workload(self):
        def memsize(t):
            return t.numel() * t.element_size()

        input_size = (
            memsize(self.att_query)
            + memsize(self.att_keys)
            + memsize(self.normalize_bias)
            + memsize(self.linear_att)
        )
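        # The output is [b, t_q, t_k]; the hardcoded 4 assumes 4-byte (float32)
        # elements.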
        output_size = 4 * torch.Size([self.b, self.t_q, self.t_k]).numel()
        io_size = input_size + output_size

        # If the matmul is not fused with the pointwise ops, `sum_qk` must be
        # written out and read back: 2 accesses of 4 bytes per element of the
        # [b, t_q, t_k, n] intermediate.
        intermediate_size = (
            2 * 4 * torch.Size([self.b, self.t_q, self.t_k, self.n]).numel()
        )
        return {"sol": io_size, "algorithmic": io_size + intermediate_size}

    @staticmethod
    def default_configs():
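        # Each config is [b, t_q, t_k, n]: batch size, query length, key length,
        # hidden size.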
        mlperf_inference = [1280, 1, 66, 1024]
        nvidia = [128, 10, 128, 1024]
        return [mlperf_inference, nvidia]


benchmark.register_benchmark_class(BahdanauAttention)
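

# A minimal standalone sketch, not part of the original benchmark: it replays the
# Bahdanau score math from `forward` with small, made-up sizes to make the shapes
# concrete. The `__main__` guard keeps it from running when the benchmark harness
# imports this module.
if __name__ == "__main__":
    b, t_q, t_k, n = 2, 3, 5, 8
    query = torch.randn(b, t_q, n)
    keys = torch.randn(b, t_k, n)
    bias = torch.randn(n)
    linear = torch.randn(n)
    # Broadcast to [b, t_q, t_k, n], apply tanh, then contract the last
    # dimension with the attention vector to get [b, t_q, t_k] scores.
    scores = torch.tanh(query.unsqueeze(2) + keys.unsqueeze(1) + bias).matmul(linear)
    assert scores.shape == (b, t_q, t_k)
    print("scores shape:", tuple(scores.shape))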