# Owner(s): ["module: onnx"]

import torch
import torch.nn as nn


class DummyNet(nn.Module):
    """LeakyReLU -> BatchNorm2d(3) -> AvgPool2d stack flattened to a 1-D result.

    ``num_classes`` is accepted for signature compatibility only; no layer
    uses it.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        layers = [
            nn.LeakyReLU(0.02),
            nn.BatchNorm2d(3),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False),
        ]
        # Kept under the name ``features`` so parameter/buffer keys are
        # ``features.<index>.*``.
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        pooled = self.features(x)
        # Reshape into a single column first, then drop the unit dimension;
        # the two-step form is kept so the exported graph has the same ops.
        column = pooled.view(-1, 1)
        return column.squeeze(1)


class ConcatNet(nn.Module):
    """Concatenate a sequence of tensors along dimension 1."""

    def forward(self, inputs):
        return torch.cat(inputs, dim=1)


class PermuteNet(nn.Module):
    """Reorder the input's four dimensions into the order (2, 3, 0, 1)."""

    def forward(self, input):
        new_order = (2, 3, 0, 1)
        return input.permute(*new_order)


class PReluNet(nn.Module):
    """A single ``nn.PReLU(3)`` layer wrapped in a ``features`` Sequential."""

    def __init__(self) -> None:
        super().__init__()
        self.features = nn.Sequential(nn.PReLU(3))

    def forward(self, x):
        return self.features(x)


class FakeQuantNet(nn.Module):
    """Apply a default-constructed ``FakeQuantize`` with its observer disabled.

    Because the observer is switched off at construction, forward passes do
    not update the quantization parameters stored on ``fake_quant``.
    """

    def __init__(self) -> None:
        super().__init__()
        quantizer = torch.ao.quantization.FakeQuantize()
        quantizer.disable_observer()
        self.fake_quant = quantizer

    def forward(self, x):
        return self.fake_quant(x)