# xref: /aosp_15_r20/external/executorch/examples/models/toy_model/model.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

7import torch
8from executorch.exir.backend.compile_spec_schema import CompileSpec
9
10from ..model_base import EagerModelBase
11
12
class MulModule(torch.nn.Module, EagerModelBase):
    """Toy model that multiplies its two inputs elementwise."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input, other):
        """Return ``input * other``."""
        product = input * other
        return product

    def get_eager_model(self) -> torch.nn.Module:
        """Return this module itself as the eager model."""
        return self

    def get_example_inputs(self):
        """Return two independently sampled random 3x2 tensors."""
        shape = (3, 2)
        # Two separate randn draws, taken in the same order as two literal calls.
        return tuple(torch.randn(*shape) for _ in range(2))
25
26
class LinearModule(torch.nn.Module, EagerModelBase):
    """Toy model wrapping a single ``torch.nn.Linear(3, 3)`` layer."""

    def __init__(self):
        super().__init__()
        in_features = out_features = 3
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, arg):
        """Apply the wrapped linear layer to ``arg``."""
        projected = self.linear(arg)
        return projected

    def get_eager_model(self) -> torch.nn.Module:
        """Return this module itself as the eager model."""
        return self

    def get_example_inputs(self):
        """Return a 1-tuple holding one random 3x3 tensor."""
        sample = torch.randn(3, 3)
        return (sample,)
40
41
class AddModule(torch.nn.Module, EagerModelBase):
    """Toy model that returns the sum of its two inputs."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        """Return ``x + y``."""
        return x + y

    def get_eager_model(self) -> torch.nn.Module:
        """Return this module itself as the eager model."""
        return self

    def get_example_inputs(self):
        """Return two distinct 1-element tensors filled with ones."""
        operands = [torch.ones(1) for _ in range(2)]
        return tuple(operands)
55
56
class AddMulModule(torch.nn.Module, EagerModelBase):
    """Toy model computing ``a @ x + b`` via ``torch.mm`` then ``torch.add``."""

    def __init__(self):
        super().__init__()

    def forward(self, a, x, b):
        """Return the matrix product of ``a`` and ``x``, plus ``b``."""
        return torch.add(torch.mm(a, x), b)

    def get_eager_model(self) -> torch.nn.Module:
        """Return this module itself as the eager model."""
        return self

    def get_example_inputs(self):
        """Return 2x2 tensors filled with 1, 2 and 3 (for ``a``, ``x``, ``b``)."""
        return tuple(scale * torch.ones(2, 2) for scale in (1, 2, 3))

    def get_compile_spec(self):
        """Return a compile spec carrying the first example input's row count.

        The row count is encoded as a single byte under the key
        ``"max_value"``, so it must lie in 0..255; ``bytes`` raises
        ``ValueError`` otherwise.
        """
        first_input = self.get_example_inputs()[0]
        rows = first_input.shape[0]
        encoded = bytes([rows])
        return [CompileSpec("max_value", encoded)]
75
76
class SoftmaxModule(torch.nn.Module, EagerModelBase):
    """Toy model applying softmax along dimension 1 of its input."""

    def __init__(self):
        super().__init__()
        # Pass ``dim`` explicitly. ``torch.nn.Softmax()`` without it falls back
        # on PyTorch's deprecated implicit-dimension heuristic, which warns on
        # every call. For the 2-D example input that heuristic selects dim=1,
        # so fixing dim=1 leaves the computed result unchanged.
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        """Return ``softmax(x)`` computed along dimension 1."""
        z = self.softmax(x)
        return z

    def get_eager_model(self) -> torch.nn.Module:
        """Return this module itself as the eager model."""
        return self

    def get_example_inputs(self):
        """Return a 1-tuple holding a 2x2 tensor of ones."""
        return (torch.ones(2, 2),)
91