# xref: /aosp_15_r20/external/pytorch/test/onnx/model_defs/op_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: onnx"]
2
3import torch
4import torch.nn as nn
5
6
class DummyNet(nn.Module):
    """LeakyReLU -> BatchNorm2d -> AvgPool2d, with the result flattened to 1-D.

    Args:
        num_classes: Accepted for signature compatibility only; it is not
            used anywhere in this module.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        layers = [
            nn.LeakyReLU(0.02),
            # BatchNorm2d(3) requires 4-D input with 3 channels: (N, 3, H, W).
            nn.BatchNorm2d(3),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False),
        ]
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        pooled = self.features(x)
        # Flatten via view(-1, 1) followed by squeeze(1). The two-op form is
        # kept rather than a single reshape, presumably so both ops show up
        # in the exported graph; confirm against the ONNX tests.
        column = pooled.view(-1, 1)
        return column.squeeze(1)
19
20
class ConcatNet(nn.Module):
    """Concatenate a sequence of tensors along dimension 1."""

    def forward(self, inputs):
        # ``inputs`` is handed to torch.cat unchanged (a sequence of tensors).
        return torch.cat(inputs, dim=1)
24
25
class PermuteNet(nn.Module):
    """Reorder the axes of a 4-D input so that dims (2, 3) come before (0, 1)."""

    def forward(self, input):
        # ``input`` shadows the builtin, but callers may pass it by keyword,
        # so the name is kept.
        order = (2, 3, 0, 1)
        return input.permute(*order)
29
30
class PReluNet(nn.Module):
    """Single PReLU layer with one learnable slope for each of 3 channels."""

    def __init__(self) -> None:
        super().__init__()
        prelu = nn.PReLU(3)
        # Wrapped in Sequential so the parameter is registered as
        # ``features.0.weight``.
        self.features = nn.Sequential(prelu)

    def forward(self, x):
        return self.features(x)
41
42
class FakeQuantNet(nn.Module):
    """Default ``FakeQuantize`` module with its observer disabled.

    Disabling the observer stops forward passes from updating the scale and
    zero point from the data. Inputs are fake-quantized with the module's
    initial quantization parameters.
    """

    def __init__(self) -> None:
        super().__init__()
        quantizer = torch.ao.quantization.FakeQuantize()
        self.fake_quant = quantizer
        quantizer.disable_observer()

    def forward(self, x):
        return self.fake_quant(x)
52