import functools
import os
import shutil
import sys
from io import BytesIO

import torch
from torch.jit.mobile import _export_operator_list, _load_for_lite_interpreter

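# Module-level registries: _FILENAMES is filled at decoration time below, while
# _OPERATORS and _MODELS are filled when the save wrappers run during "setup".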
_OPERATORS = set()
_FILENAMES = []
_MODELS = []


def save_model(cls):
    """Script a model, save it for the lite interpreter, and record its ops.

    The decorated class is replaced by a zero-argument wrapper; instantiating
    the wrapper (see the "setup" branch below) performs the scripting, the
    save, and the operator collection.
    """
    # Record the output path at decoration time so the "shutdown" branch can
    # find and remove the saved files without re-running the wrappers.
    path = f"./{cls.__name__}.ptl"
    _FILENAMES.append(path)

    @functools.wraps(cls)
    def wrapper_save():
        _MODELS.append(cls)
        model = cls()
        scripted = torch.jit.script(model)
        buffer = BytesIO(scripted._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)
        ops = _export_operator_list(mobile_module)
        _OPERATORS.update(ops)
        scripted._save_for_lite_interpreter(path)

    return wrapper_save


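# Exercises aten::ones with explicit dtype, layout, device, and pin_memory
# keyword arguments.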
@save_model
class ModelWithDTypeDeviceLayoutPinMemory(torch.nn.Module):
    def forward(self, x: int):
        a = torch.ones(
            size=[3, x],
            dtype=torch.int64,
            layout=torch.strided,
            device="cpu",
            pin_memory=False,
        )
        return a


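# Exercises aten::zeros and integer Tensor indexing on both the write
# (a[0][1] = 1) and read (a[index]) paths.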
@save_model
class ModelWithTensorOptional(torch.nn.Module):
    def forward(self, index):
        a = torch.zeros(2, 2)
        a[0][1] = 1
        a[1][0] = 2
        a[1][1] = 3
        return a[index]


# gradient.scalarrayint(Tensor self, *, Scalar[] spacing, int? dim=None, int edge_order=1) -> Tensor[]
@save_model
class ModelWithScalarList(torch.nn.Module):
    def forward(self, a: int):
        values = torch.tensor([4.0, 1.0, 1.0, 16.0])
        if a == 0:
            return torch.gradient(
                values, spacing=torch.scalar_tensor(2.0, dtype=torch.float64)
            )
        elif a == 1:
            return torch.gradient(values, spacing=[torch.tensor(1.0).item()])


# upsample_linear1d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
@save_model
class ModelWithFloatList(torch.nn.Upsample):
    def __init__(self) -> None:
        super().__init__(
            scale_factor=(2.0,),
            mode="linear",
            align_corners=False,
            recompute_scale_factor=True,
        )


# index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
@save_model
class ModelWithListOfOptionalTensors(torch.nn.Module):
    def forward(self, index):
        values = torch.tensor([[4.0, 1.0, 1.0, 16.0]])
        return values[torch.tensor(0), index]


# conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1,
# int groups=1) -> Tensor
@save_model
class ModelWithArrayOfInt(torch.nn.Conv2d):
    def __init__(self) -> None:
        super().__init__(1, 2, (2, 2), stride=(1, 1), padding=(1, 1))


# add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
# ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None,
# MemoryFormat? memory_format=None) -> Tensor
@save_model
class ModelWithTensors(torch.nn.Module):
    def forward(self, a):
        b = torch.ones_like(a)
        return a + b


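# div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
# div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor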
@save_model
class ModelWithStringOptional(torch.nn.Module):
    def forward(self, b):
        a = torch.tensor(3, dtype=torch.int64)
        out = torch.empty(size=[1], dtype=torch.float)
        torch.div(b, a, out=out)
        return [torch.div(b, a, rounding_mode="trunc"), out]


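# Exercises several ops in one model: in-place indexed assignment (x[1] = -2)
# followed by relu and flatten via nn.Sequential.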
@save_model
class ModelWithMultipleOps(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.ops = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Flatten(),
        )

    def forward(self, x):
        x[1] = -2
        return self.ops(x)


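# Usage:
#   python tests_setup.py setup <ops_yaml>     saves each model as <ClassName>.ptl
#                                              and appends its root ops to <ops_yaml>
#   python tests_setup.py shutdown <ops_yaml>  removes the .ptl files and restores
#                                              <ops_yaml> from the .bak backup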
if __name__ == "__main__":
    command = sys.argv[1]
    ops_yaml = sys.argv[2]
    backup = ops_yaml + ".bak"
    if command == "setup":
        # Instantiating each decorated class calls its wrapper_save, which
        # scripts the model, writes the .ptl file, and collects its root ops.
        tests = [
            ModelWithDTypeDeviceLayoutPinMemory(),
            ModelWithTensorOptional(),
            ModelWithScalarList(),
            ModelWithFloatList(),
            ModelWithListOfOptionalTensors(),
            ModelWithArrayOfInt(),
            ModelWithTensors(),
            ModelWithStringOptional(),
            ModelWithMultipleOps(),
        ]
        shutil.copyfile(ops_yaml, backup)
        with open(ops_yaml, "a") as f:
            for op in _OPERATORS:
                f.write(f"- {op}\n")  # one YAML list entry per root operator
    elif command == "shutdown":
        # Iterate _FILENAMES (not _MODELS, which holds classes rather than
        # paths); it is populated at decoration time, so the paths are known
        # here even though no save wrapper ran in this invocation.
        for file in _FILENAMES:
            if os.path.isfile(file):
                os.remove(file)
        shutil.move(backup, ops_yaml)
160