xref: /aosp_15_r20/external/executorch/exir/tests/models.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import itertools
from typing import Any, List, Optional, Tuple, Union

import executorch.exir as exir

import torch  # noqa: F401
import torch.nn as nn
from executorch.exir import to_edge
from executorch.exir.lowered_backend_module import LoweredBackendModule
from torch import Tensor
from torch.export import export

# TODO: add one more test for data dependent op plus repeat


class TensorItem(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, arg1: torch.Tensor, arg2: torch.Tensor) -> torch.Tensor:
        h = arg1.item()
        w = arg2.item()
        torch._check(h >= 2)
        torch._check(h <= 100)
        torch._check(w >= 2)
        torch._check(w <= 100)
        return torch.ones(int(h), int(w))

    def get_random_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return (torch.tensor(10), torch.tensor(20))


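# Illustrative sketch (not part of the original test suite): TensorItem pulls
# Python ints out of its inputs with .item(), so its output shape is data
# dependent, and the torch._check calls above give the exporter bounds for
# those values.
def _example_tensor_item() -> None:
    model = TensorItem()
    assert model(torch.tensor(3), torch.tensor(4)).shape == (3, 4)
    # Assuming a recent torch.export, the data-dependent shape can be captured.
    export(model, model.get_random_inputs())

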
class Repeat(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, arg1: torch.Tensor, arg2: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = arg2.repeat(arg1.size(0), 1)
        return x * x, arg2 + arg2

    def get_random_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return (torch.rand(4), torch.rand(5))

    def get_dynamic_shape(self) -> Any:  # pyre-ignore[3]
        dim = torch.export.Dim("dim", max=10)
        dim2 = torch.export.Dim("dim2", max=10)
        return ({0: dim}, {0: dim2})


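# Illustrative sketch (not part of the original test suite): Repeat advertises
# per-input dynamic dimensions via get_dynamic_shape(), which plugs into the
# dynamic_shapes argument of torch.export.export.
def _example_export_repeat_dynamic() -> None:
    model = Repeat()
    export(
        model,
        model.get_random_inputs(),
        dynamic_shapes=model.get_dynamic_shape(),
    )

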
class ModelWithUnusedArg(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, arg1: torch.Tensor, arg2: torch.Tensor) -> torch.Tensor:
        return torch.sin(arg1)

    def get_random_inputs(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return (torch.rand(4), torch.rand(5))


class MLP(nn.Module):
    def __init__(self, n_layer: int = 1, output_size: int = 1) -> None:
        super().__init__()
        self.n_layer = n_layer
        self.output_size = output_size
        # input shape: [batch_size, n_layer + output_size]
        # each linear layer reduces the size of dim 1 of the activation by 1.
        self.mlp = torch.nn.Sequential(
            *itertools.chain(
                *(
                    [nn.Linear(i + output_size, i - 1 + output_size)]
                    + ([nn.ReLU()] if i != 1 else [])
                    for i in range(n_layer, 0, -1)
                )
            )
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.mlp(inputs)

    def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.rand(2, self.n_layer + self.output_size),)


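# Illustrative sketch (not part of the original test suite): for n_layer=3 and
# output_size=1 the construction above expands to
# Linear(4, 3) -> ReLU -> Linear(3, 2) -> ReLU -> Linear(2, 1),
# so a (batch, 4) input comes out as (batch, 1).
def _example_mlp_shapes() -> None:
    model = MLP(n_layer=3, output_size=1)
    out = model(*model.get_random_inputs())
    assert out.shape == (2, 1)

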
class Identity(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        return torch.clone(input)


class Reshape(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, x: Tensor, *new_shape: Union[torch.Size, Tuple[int, ...], List[int]]
    ) -> Tensor:
        if len(new_shape) == 1 and (
            isinstance(new_shape[0], tuple) or isinstance(new_shape[0], list)
        ):
            return x.reshape(new_shape[0])
        assert isinstance(new_shape, (torch.Size, tuple, list))
        return x.reshape(new_shape)


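# Illustrative sketch (not part of the original test suite): Reshape accepts
# either a single tuple/list of sizes or the sizes unpacked as separate
# arguments.
def _example_reshape_call_forms() -> None:
    model = Reshape()
    x = torch.randn(2, 6)
    assert model(x, (3, 4)).shape == (3, 4)
    assert model(x, 3, 4).shape == (3, 4)

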
class Transpose(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, dim0: int, dim1: int) -> Tensor:
        return x.transpose(dim0, dim1)


class Mul(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor, other: Tensor) -> Tensor:
        # or return torch.mul(input, other)
        return input * other

    def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
        return (torch.randn(3, 2), torch.randn(3, 2))


class ElementwiseAdd(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        return x + y

    def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
        return (torch.randn(1, 3), torch.randn(1, 3))


class BasicSinMax(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        return torch.sin(x)

    def get_random_inputs(self) -> Tuple[Tensor]:
        return (torch.randn(100),)


class CompositeDelegateModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

        class DelegateAdd(nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x: Tensor, y: Tensor) -> List[Tensor]:
                return [x + y]

            def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
                return (torch.randn(1, 3), torch.randn(1, 3))

        delegated_m = DelegateAdd()
        edge_ir_m = to_edge(
            export(
                delegated_m,
                delegated_m.get_random_inputs(),
            )
        )
        lowered_module = LoweredBackendModule(
            edge_program=edge_ir_m.exported_program(),
            backend_id="backend_demo",
            processed_bytes=bytes("basic_module_add", encoding="utf8"),
            compile_specs=[],
        )
        self.lowered_module: LoweredBackendModule = lowered_module

    def forward(self, a: exir.Value, b: exir.Value, s: Tensor) -> Tensor:
        res = self.lowered_module(a, b)
        res = res[0] * s
        return res

    def get_random_inputs(self) -> Tuple[Tensor, Tensor, Tensor]:
        return (torch.randn(1, 3), torch.randn(1, 3), torch.randn(1, 3))


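# Illustrative sketch (not part of the original test suite): the constructor
# above already lowers the inner DelegateAdd to a LoweredBackendModule for the
# "backend_demo" backend, so the composite module itself can be exported from
# its random inputs like the other models in this file.
def _example_export_composite_delegate() -> None:
    model = CompositeDelegateModule()
    export(model, model.get_random_inputs())

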
class BatchMatrixMultiplication(nn.Module):
    def __init__(self, transposed: bool = False) -> None:
        super().__init__()

        # Whether the last 2 dims (-1, -2) of the second input have already been
        # transposed. If so, transpose them back before feeding it to torch.bmm.
        self.transposed: bool = transposed

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        if self.transposed:
            return torch.bmm(x, y.transpose(-1, -2))
        else:
            return torch.bmm(x, y)

    def extra_repr(self) -> str:
        return f"transposed={self.transposed}"


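# Illustrative sketch (not part of the original test suite): with
# transposed=True the second operand is expected with its last two dims
# already swapped, so both calls below compute the same (B, L, S) product.
def _example_bmm_transposed() -> None:
    x = torch.randn(2, 3, 4)
    y = torch.randn(2, 5, 4)
    out_t = BatchMatrixMultiplication(transposed=True)(x, y)
    out = BatchMatrixMultiplication(transposed=False)(x, y.transpose(-1, -2))
    assert out_t.shape == out.shape == (2, 3, 5)

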
class TensorSplit(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor, sections: int, dim: int = 0) -> List[Tensor]:
        # pyre-fixme[7]: Expected `List[Tensor]` but got `Tuple[Tensor, ...]`.
        return torch.tensor_split(input, sections, dim)


class TensorSplitWithSizes(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor, split_size: int, dim: int = 0) -> List[Tensor]:
        # pyre-fixme[7]: Expected `List[Tensor]` but got `Tuple[Tensor, ...]`.
        return torch.split(input, split_size, dim)


class Cat(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    # def forward(self, tensors, dim=0):
    def forward(self, *args: Tensor, dim: int) -> Tensor:
        tensors = args[:-1]
        return torch.cat(tensors, dim)


class FeedForwardBlock(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.layer_norm = nn.LayerNorm(input_dim)

        self.relu = nn.ReLU()

        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.dropout1 = nn.Dropout()

        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.dropout2 = nn.Dropout()

    def forward(self, x: Tensor) -> Tensor:
        # LayerNorm -> Linear -> Dropout -> ReLU -> Linear -> Dropout
        y = self.layer_norm(x)
        y = self.linear1(y)
        y = self.dropout1(y)
        y = self.relu(y)
        y = self.linear2(y)
        y = self.dropout2(y)
        return y


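# Illustrative sketch (not part of the original test suite): the block is
# shape-preserving, since LayerNorm, the Dropouts and the two Linear layers
# (input_dim -> hidden_dim -> input_dim) all map (..., input_dim) back to
# (..., input_dim).
def _example_feed_forward_block() -> None:
    block = FeedForwardBlock(input_dim=16, hidden_dim=32).eval()
    x = torch.randn(4, 16)
    assert block(x).shape == (4, 16)

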
class NoOp(nn.Module):
    """
    NoOp simply passes the input as the output.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        return input


class MultiLayerPerceptron(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim1: int,
        hidden_dim2: int,
        hidden_dim3: int,
        output_dim: int,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim1 = hidden_dim1
        self.hidden_dim2 = hidden_dim2
        self.hidden_dim3 = hidden_dim3
        self.output_dim = output_dim
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim1),
            nn.ReLU(),
            nn.Linear(hidden_dim1, hidden_dim2),
            nn.ReLU(),
            nn.Linear(hidden_dim2, hidden_dim3),
            nn.ReLU(),
            nn.Linear(hidden_dim3, output_dim),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.layers(x)


class ScaledDotProductAttentionModularized(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout_p: float = 0.5,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout_p = dropout_p
        self.dropout = nn.Dropout(p=dropout_p)

        self.head_dim: int = embed_dim // num_heads
        self.scaling: float = self.head_dim**-0.5

        self.mul = Mul()
        self.reshape = Reshape()
        self.transpose = Transpose()
        self.bmm = BatchMatrixMultiplication(transposed=False)
        self.bmm_t = BatchMatrixMultiplication(transposed=True)
        self.softmax = nn.Softmax(dim=-1)

    def forward(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
    ) -> Tensor:
        # q: (L, B, D) k: (S, B, D) v: (S, B, D)
        # assert k.shape == v.shape
        # assert q.dim() == 3 and k.dim() == 3
        # assert q.size(1) == k.size(1) and q.size(2) == k.size(2)

        L, B, D = q.shape
        S = k.size(0)
        # assert D % self.head_dim == 0

        # FIXME(poweic): scaling layer!?
        # This breaks the modular assumption: the following self.reshape sees a
        # fresh floating-point tensor q, because id(q) is no longer the id of
        # the original input.
        # This is equiv. to `q = q * self.scaling`
        q = self.mul(q, self.scaling)

        # Reshape & transpose q from (L, B, D) to (B*H, L, D/H)
        q = self.reshape(q, (L, B * self.num_heads, self.head_dim))
        q = self.transpose(q, 0, 1)

        # Reshape & transpose k from (S, B, D) to (B*H, S, D/H)
        k = self.reshape(k, (S, B * self.num_heads, self.head_dim))
        k = self.transpose(k, 0, 1)

        # Reshape & transpose v from (S, B, D) to (B*H, S, D/H)
        v = self.reshape(v, (S, B * self.num_heads, self.head_dim))
        v = self.transpose(v, 0, 1)

        # bmm((B*H, L, D/H), (B*H, D/H, S)) -> (B*H, L, S).
        # this is equiv. to `qk = torch.bmm(q, k.transpose(-1, -2))`
        qk = self.bmm_t(q, k)
        # assert qk.shape == (B * self.num_heads, L, S)

        softmax_qk = self.softmax(qk)

        softmax_qk = self.dropout(softmax_qk)

        # bmm((B*H, L, S), (B*H, S, D/H)) -> (B*H, L, D/H).
        # this is equiv. to `attention = torch.bmm(softmax_qk, v)`
        attention = self.bmm(softmax_qk, v)
        # assert attention.shape == (B * self.num_heads, L, self.head_dim)

        # Transpose & reshape attention: (B*H, L, D/H) -> (L, B*H, D/H) -> (L, B, D).
        attention = self.transpose(attention, 0, 1)
        attention = self.reshape(attention, (L, B, self.embed_dim))

        return attention


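# Illustrative sketch (not part of the original test suite): the modularized
# attention keeps the (L, B, D) -> (B*H, L, D/H) bookkeeping in the small
# sub-modules above; in eval() mode the Dropout is a no-op and the output
# keeps the query layout.
def _example_modularized_attention() -> None:
    attn = ScaledDotProductAttentionModularized(embed_dim=16, num_heads=4)
    attn = attn.eval()
    q = torch.randn(6, 2, 16)  # (L, B, D)
    k = torch.randn(5, 2, 16)  # (S, B, D)
    v = torch.randn(5, 2, 16)  # (S, B, D)
    assert attn(q, k, v).shape == (6, 2, 16)

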
# ------------------------------------------------------------------------------
#   Scaled Dot-Product Attention
# ------------------------------------------------------------------------------
class ScaledDotProductAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: Optional[float] = None,
    ) -> None:
        if embed_dim % num_heads:
            raise ValueError(
                "embed_dim ({}) must be divisible by num_heads ({})".format(
                    embed_dim, num_heads
                )
            )

        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if dropout is not None and dropout > 0.0:
            self.dropout: nn.Module = nn.Dropout(p=dropout)
        else:
            self.dropout = NoOp()

        self.head_dim: int = embed_dim // num_heads
        self.scaling: float = self.head_dim**-0.5

    def forward(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        padding_mask: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
    ) -> Tensor:
        # q: (L, B, D) k: (S, B, D) v: (S, B, D)
        # assert k.shape == v.shape
        # assert q.dim() == 3 and k.dim() == 3
        # assert q.size(1) == k.size(1) and q.size(2) == k.size(2)

        L, B, D = q.shape
        S = k.size(0)
        # assert D % self.head_dim == 0

        q = q * self.scaling
        q = q.reshape(L, B * self.num_heads, self.head_dim).transpose(
            0, 1
        )  # (B*H, L, D/H)

        k = k.reshape(S, B * self.num_heads, self.head_dim).transpose(
            0, 1
        )  # (B*H, S, D/H)

        v = v.reshape(S, B * self.num_heads, self.head_dim).transpose(
            0, 1
        )  # (B*H, S, D/H)

        # bmm((B*H, L, D/H), (B*H, D/H, S)) -> (B*H, L, S).
        qk = torch.bmm(q, k.transpose(1, 2))
        # assert qk.shape == (B * self.num_heads, L, S)

        # TODO(cfyeh): figure out if we really need input to be float.
        softmax_qk = nn.functional.softmax(qk.float(), dim=-1)

        # softmax_qk = self.dropout(softmax_qk)

        # bmm((B*H, L, S), (B*H, S, D/H)) -> (B*H, L, D/H).
        attention = torch.bmm(softmax_qk, v)
        # assert attention.shape == (B * self.num_heads, L, self.head_dim)

        # (B*H, L, D/H) -> (L, B*H, D/H) -> (L, B, D).
        attention = attention.transpose(0, 1).reshape(L, B, self.embed_dim)

        return attention


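# Illustrative sketch (not part of the original test suite): same layout as the
# modularized variant above; an (L, B, D) query attended against (S, B, D)
# keys/values comes back as (L, B, D).
def _example_scaled_dot_product_attention() -> None:
    attn = ScaledDotProductAttention(embed_dim=16, num_heads=4)
    q = torch.randn(6, 2, 16)
    k = torch.randn(5, 2, 16)
    v = torch.randn(5, 2, 16)
    assert attn(q, k, v).shape == (6, 2, 16)

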
class Emformer(nn.Module):
    def __init__(
        self,
        l_dim: int = 32,
        m_dim: int = 8,
        c_dim: int = 8,
        r_dim: int = 8,
        input_dim: int = 256,
        ffn_hidden_dim: int = 512,
    ) -> None:
        super().__init__()

        self.l_dim = l_dim
        self.m_dim = m_dim
        self.c_dim = c_dim
        self.r_dim = r_dim

        self.input_dim = input_dim
        self.ffn_hidden_dim = ffn_hidden_dim

        self.split = TensorSplit()
        self.elem_add = ElementwiseAdd()

        self.attn = ScaledDotProductAttention(
            embed_dim=input_dim,
            num_heads=8,
        )

        self.ffn = FeedForwardBlock(input_dim, ffn_hidden_dim)

        self.layer_norm = nn.LayerNorm(input_dim)

        self.linear_k = nn.Linear(self.input_dim, self.input_dim)
        self.linear_v = nn.Linear(self.input_dim, self.input_dim)
        self.linear_q = nn.Linear(self.input_dim, self.input_dim)

    def get_random_inputs(self) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        inputs = (
            torch.randn(self.m_dim, 1, self.input_dim),
            torch.randn(self.c_dim, 1, self.input_dim),
            torch.randn(self.r_dim, 1, self.input_dim),
            torch.randn(self.l_dim, 1, self.input_dim),
            torch.randn(self.l_dim, 1, self.input_dim),
        )
        return inputs

    def forward(
        self, M: Tensor, C: Tensor, R: Tensor, K_L: Tensor, V_L: Tensor
    ) -> Tensor:
        """
        The Emformer block takes [M_i^n, C_i^n, R_i^n] and [K_{L,i}^n, V_{L,i}^n]
        as inputs and outputs [C_i^{n+1}, R_i^{n+1}].
        See Fig. 1(b) Emformer and equations 6, 7, 8 - 13 in the original paper
        https://arxiv.org/pdf/2010.10759.pdf

        Ex:
         - self.input_dim = 512
         - L.shape = 30 x 1 x 512
         - M.shape =  2 x 1 x 512
         - C.shape =  5 x 1 x 512
         - R.shape =  1 x 1 x 512
        """
        # Equation 8
        CR = torch.cat([C, R], 0)
        CR_normed = self.layer_norm(CR)
        # C_normed = self.layer_norm(C)
        # R_normed = self.layer_norm(R)

        # Equation 9 and 10
        if True:
            MCR = torch.cat([M, C, R], 0)
            K_MCR = self.linear_k(MCR)
            V_MCR = self.linear_v(MCR)

            K_M, K_C, K_R = self.split(K_MCR, 3)
            V_M, V_C, V_R = self.split(V_MCR, 3)
        else:
            K_M, K_C, K_R = self.linear_k(M), self.linear_k(C), self.linear_k(R)
            V_M, V_C, V_R = self.linear_v(M), self.linear_v(C), self.linear_v(R)

        K = torch.cat([K_M, K_L, K_C, K_R], 0)
        V = torch.cat([V_M, V_L, V_C, V_R], 0)

        # Equation 11 and 12
        Q_CR = self.linear_q(CR_normed)
        Z_CR = self.attn(Q_CR, K, V)
        Z_CR = self.elem_add(Z_CR, CR)
        # Q_C = self.linear_q(C_normed)
        # Q_R = self.linear_q(R_normed)
        # Z_C = self.attn(Q_C, K, V)
        # Z_R = self.attn(Q_R, K, V)
        # Z_C = self.elem_add(Z_C, C)
        # Z_R = self.elem_add(Z_R, R)

        # Equation 6
        Z_CR_normed = self.layer_norm(Z_CR)
        ffn_out = self.ffn(Z_CR_normed)

        # Equation 7
        output = self.layer_norm(self.elem_add(ffn_out, Z_CR))

        # m = self.attn(

        return output


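# Illustrative sketch (not part of the original test suite): with the default
# dimensions the Emformer block consumes the five tensors from
# get_random_inputs() and returns the updated [C; R] block, i.e. a tensor of
# shape (c_dim + r_dim, 1, input_dim).
def _example_emformer_forward() -> None:
    model = Emformer().eval()
    out = model(*model.get_random_inputs())
    assert out.shape == (model.c_dim + model.r_dim, 1, model.input_dim)

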
# List of models that we want to export
# TODO(angelayi): enable ControlFlowWhile test once we enable functionalization
MODELS = [
    ["basic_sin_max", BasicSinMax()],
    ["composite_delegate", CompositeDelegateModule()],
]
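

# Illustrative sketch (not part of the original test suite): every entry in
# MODELS pairs a name with a module that provides its own random inputs, so
# the whole list can be exported uniformly.
def _example_export_models_list() -> None:
    for _name, model in MODELS:
        export(model, model.get_random_inputs())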
580