# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


# Modules that each exercise a single operator.
class Add(torch.nn.Module):
    """Element-wise sum of two tensors, expressed with ``torch.add``."""

    def forward(self, x, y):
        return torch.add(x, y)


class AddConstantFloat(torch.nn.Module):
    """Add the float scalar ``10.0`` to the input tensor."""

    def forward(self, x):
        # Scalar on the left, so Python dispatches to Tensor.__radd__.
        return 10.0 + x


class AddConstantLong(torch.nn.Module):
    """Add the integer scalar ``10`` to the input tensor."""

    def forward(self, x):
        # Scalar on the left, so Python dispatches to Tensor.__radd__.
        return 10 + x
class Arange(torch.nn.Module):
    """Add a float32 ``torch.arange(x)`` ramp to the forward input.

    Args:
        x: the ``end`` argument given to ``torch.arange`` (stored as ``self.x``).
    """

    def __init__(self, x):
        super().__init__()
        self.x = x

    def forward(self, y):
        ramp = torch.arange(self.x, dtype=torch.float32)
        return ramp + y


class AvgPoolModule(torch.nn.Module):
    """2x2 average pooling with stride 1 and padding 1.

    ``count_include_pad=False`` excludes padded cells from each average.
    """

    def __init__(self):
        super().__init__()
        self.avgPool = torch.nn.AvgPool2d(
            kernel_size=(2, 2),
            stride=(1, 1),
            padding=(1, 1),
            count_include_pad=False,
        )

    def forward(self, x):
        return self.avgPool(x)


class BatchNorm(torch.nn.Module):
    """``BatchNorm2d`` over ``n_features`` channels, constructed in eval mode."""

    def __init__(self, n_features):
        super().__init__()
        self.native_batchnorm = torch.nn.BatchNorm2d(n_features)
        # Eval mode: normalize with running statistics rather than batch stats.
        self.eval()

    def forward(self, x):
        return self.native_batchnorm(x)


class Bmm(torch.nn.Module):
    """Matrix product of ``x`` and ``y``.

    NOTE(review): named Bmm but calls ``torch.matmul``; for 3-D inputs that is
    a batched matrix multiply.
    """

    def forward(self, x, y):
        return torch.matmul(x, y)


class Cast(torch.nn.Module):
    """Convert the input to ``torch.IntTensor`` (int32)."""

    def forward(self, x):
        return x.type(torch.IntTensor)


class Cat2(torch.nn.Module):
    """Concatenate ``x`` then ``y`` along dim 2."""

    def forward(self, x, y):
        return torch.cat([x, y], dim=2)


class Cat3(torch.nn.Module):
    """Concatenate ``y, y, x`` along dim 2 using the ``torch.concat`` alias."""

    def forward(self, x, y):
        return torch.concat([y, y, x], dim=2)
class Cat4(torch.nn.Module):
    """Concatenate ``y, y, x, x`` along dim 2."""

    def forward(self, x, y):
        return torch.cat([y, y, x, x], dim=2)


class Ceil(torch.nn.Module):
    """Element-wise ceiling."""

    def forward(self, x):
        return torch.ceil(x)


class Chunk(torch.nn.Module):
    """Split the input into two chunks along the last dimension."""

    def forward(self, x):
        return torch.chunk(x, chunks=2, dim=-1)


class ChunkAdd(torch.nn.Module):
    """Split the last dimension into two chunks and add them together."""

    def forward(self, x):
        first_half, second_half = torch.chunk(x, chunks=2, dim=-1)
        return torch.add(first_half, second_half)
class Clamp(torch.nn.Module):
    """Clamp the input from above at 0 (no lower bound)."""

    def forward(self, x):
        return torch.clamp(x, max=0)


class CompositeDelegateModule(torch.nn.Module):
    """Chain four separately lowered modules: conv, conv, add, relu.

    Each eager sub-module is optionally quantized, captured and lowered on its
    own with a fresh partitioner. ``forward`` then wires the resulting lowered
    modules together.

    Args:
        compiler_specs: passed to ``partitioner_type`` to build each partitioner.
        partitioner_type: callable ``(compiler_specs) -> partitioner``.
        capture_method: callable ``(module, sample_input) -> edge program``
            exposing an ``exported_program`` attribute.
        lowered_method: callable ``(exported_program, partitioner) ->
            exported_program``.
        quantize_method: optional callable ``(module, sample_input) -> module``
            applied before capture.

    Raises:
        RuntimeError: if lowering a sub-module does not produce a
            ``lowered_module_0`` submodule.
    """

    def __init__(
        self,
        compiler_specs,
        partitioner_type,
        capture_method,
        lowered_method,
        quantize_method=None,
    ) -> None:
        super().__init__()
        # NOTE(review): plain lists, so these modules are not registered as
        # submodules, and ``self.modules`` shadows ``nn.Module.modules()``.
        # Registering them (e.g. via nn.ModuleList) would presumably expose
        # their parameters when this module is exported; confirm before changing.
        self.modules = [
            Conv2dSequential(),
            Conv2dSequential(),
            Add(),
            Relu(),
        ]
        self.sample_inputs = [
            (torch.randn([1, 1, 3, 3]),),
            (torch.randn([1, 1, 3, 3]),),
            (torch.randn([1, 2, 3, 3]), torch.randn([1, 2, 3, 3])),
            (torch.randn([1, 2, 3, 3]),),
        ]
        self.lowered_modules = []
        for module, sample_input in zip(self.modules, self.sample_inputs):
            module_name = type(module).__name__
            partitioner = partitioner_type(compiler_specs)
            if quantize_method:
                module = quantize_method(module, sample_input)
            edge_prog = capture_method(module, sample_input)
            edge_prog.exported_program = lowered_method(
                edge_prog.exported_program, partitioner
            )
            lowered = edge_prog.exported_program.graph_module._modules.get(
                "lowered_module_0"
            )
            # Fail here with a clear message instead of storing None and
            # crashing later in forward() with "'NoneType' is not callable".
            if lowered is None:
                raise RuntimeError(
                    f"Lowering {module_name} produced no 'lowered_module_0'; "
                    "the module was not delegated."
                )
            self.lowered_modules.append(lowered)

    def forward(self, x, y):
        # Each lowered module's result is indexed with [0], i.e. only its
        # first output is passed on.
        x1 = self.lowered_modules[0](x)
        x2 = self.lowered_modules[1](y)
        x3 = self.lowered_modules[2](x1[0], x2[0])
        x4 = self.lowered_modules[3](x3[0])
        return x4[0]

    def get_random_input(self):
        """Return a fresh ``(x, y)`` pair of ``[1, 1, 3, 3]`` random tensors."""
        return (torch.randn([1, 1, 3, 3]), torch.randn([1, 1, 3, 3]))

    def get_reference_module(self):
        """Return an eager module wiring the un-lowered modules like ``forward``."""

        class CompositeReferenceModule(torch.nn.Module):
            def __init__(self, modules):
                super().__init__()
                # ModuleList registers the stages as submodules so that
                # .eval()/.to()/.parameters() reach them. The attribute is not
                # called ``modules``, which would shadow nn.Module.modules().
                self.stages = torch.nn.ModuleList(modules)

            def forward(self, x, y):
                x1 = self.stages[0](x)
                x2 = self.stages[1](y)
                x3 = self.stages[2](x1, x2)
                return self.stages[3](x3)

        return CompositeReferenceModule(self.modules)


class ContextBinaryExample(torch.nn.Module):
    """Apply ReLU independently to two inputs and return both results."""

    def forward(self, x, y):
        x = torch.nn.functional.relu(x)
        y = torch.nn.functional.relu(y)
        return x, y

    def example_inputs(self):
        """Return keyword inputs ``x`` (1, 3, 3, 3) and ``y`` (2, 1, 5, 5)."""
        return {
            "x": torch.randn((1, 3, 3, 3)),
            "y": torch.randn((2, 1, 5, 5)),
        }


class Conv1dSequential(torch.nn.Module):
    """Two stacked 1-D convolutions (1 -> 3 -> 2 channels, kernel 3, padding 1)."""

    def __init__(self, bias=True):
        super().__init__()
        self.first = torch.nn.Conv1d(
            in_channels=1,
            out_channels=3,
            kernel_size=3,
            padding=1,
            bias=bias,
        )

        self.second = torch.nn.Conv1d(
            in_channels=3,
            out_channels=2,
            kernel_size=3,
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.second(self.first(x))


# small models
class Conv1dReluLogSoftmax(torch.nn.Module):
    """1-D conv followed by ReLU and LogSoftmax over dim 1."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(
            in_channels=2, out_channels=2, kernel_size=1, stride=1, padding=1
        )
        self.logsoftmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = torch.nn.functional.relu(self.conv(x))
        x = self.logsoftmax(x)
        return x
class Conv2dAvgPool2d(torch.nn.Module):
    """Strided 7x7 convolution (3 -> 16 channels) followed by 3x3 average pooling."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            3, 16, 7, bias=True, stride=2, padding=3, dilation=1
        )
        self.pool = torch.nn.AvgPool2d(3, stride=2, padding=1)

    def forward(self, x):
        return self.pool(self.conv(x))


class Conv2dBnHardtanhMean(torch.nn.Module):
    """Conv -> BatchNorm -> Hardtanh(0, 6) -> mean over dim 1, in eval mode."""

    def __init__(self):
        super().__init__()
        groups = 1
        stride = [2, 2]
        padding = [1, 1]
        dilation = [1, 1]
        in_channels = 1
        out_channels = 1

        self.conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
            bias=True,
        )
        # Replace the default-initialised weight with standard-normal values.
        self.conv.weight = torch.nn.Parameter(torch.randn(self.conv.weight.size()))
        self.native_batchnorm = torch.nn.BatchNorm2d(out_channels)
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)
        # Eval mode: BatchNorm uses running statistics.
        self.eval()

    def forward(self, x):
        x1 = self.conv(x)
        x2 = self.native_batchnorm(x1)
        x3 = self.hardtanh(x2)
        # Reduce over dim 1 (the channel dim of the conv output), keeping it.
        x4 = torch.mean(x3, dim=1, keepdim=True)
        return x4


class Conv2dCat(torch.nn.Module):
    """Convolve two inputs separately, then concatenate along dim 1."""

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 3)
        self.conv2 = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x, y):
        x = self.conv1(x)
        y = self.conv2(y)
        z = torch.cat([x, y], dim=1)
        return z


class Conv2dMaxPool2d(torch.nn.Module):
    """Padded 1x1 convolution followed by a 1x1, stride-1 max pool."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=2,
            out_channels=2,
            kernel_size=(1, 1),
            padding=1,
            bias=True,
        )
        self.pool = torch.nn.MaxPool2d(1, 1)

    def forward(self, x):
        return self.pool(self.conv(x))


class Conv2dSequential(torch.nn.Module):
    """Two stacked 3x3 convolutions (1 -> 3 -> 2 channels, padding 1)."""

    def __init__(self, bias=True):
        super().__init__()
        self.first = torch.nn.Conv2d(
            in_channels=1,
            out_channels=3,
            kernel_size=(3, 3),
            padding=1,
            bias=bias,
        )
        self.second = torch.nn.Conv2d(
            in_channels=3,
            out_channels=2,
            kernel_size=(3, 3),
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.second(self.first(x))


class Conv2dSingle(torch.nn.Module):
    """A single 3x3 convolution (1 -> 3 channels, padding 1)."""

    def __init__(self, bias=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=1,
            out_channels=3,
            kernel_size=(3, 3),
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.conv(x)


class ConvTranspose2dSingle(torch.nn.Module):
    """A single 3x3, stride-2 transposed convolution (1 -> 3 channels)."""

    def __init__(self, bias=True):
        super().__init__()
        self.conv_transpose = torch.nn.ConvTranspose2d(
            in_channels=1,
            out_channels=3,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.conv_transpose(x)


class Conv2dDownUpSample(torch.nn.Module):
    """Stride-2 conv followed by a stride-2 transposed conv, 16 channels each."""

    def __init__(self, bias=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=16,
            out_channels=16,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=bias,
        )
        self.conv_transpose = torch.nn.ConvTranspose2d(
            in_channels=16,
            out_channels=16,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.conv_transpose(self.conv(x))


class Conv2dSumReduceDim(torch.nn.Module):
    """3x3 convolution, then sum over dims 2 and 3 without keeping them."""

    def __init__(self):
        super().__init__()
        self.first = torch.nn.Conv2d(
            in_channels=1,
            out_channels=3,
            kernel_size=(3, 3),
            padding=1,
            bias=True,
        )

    def forward(self, x):
        return torch.sum(self.first(x), dim=(2, 3), keepdim=False)


class Conv2dTopK(torch.nn.Module):
    """3x3 convolution (3 -> 16 channels), then the top-5 values along dim 1."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3)

    def forward(self, x):
        x = self.conv(x)
        # Only the values are returned; the indices output is discarded.
        topk_values, _ = torch.topk(x, 5, dim=1)
        return topk_values


class Div(torch.nn.Module):
    """Element-wise division ``x / y`` via ``torch.divide``."""

    def forward(self, x, y):
        return torch.divide(x, y)


class DivConstantFloat(torch.nn.Module):
    """Divide the input by the float scalar ``10.0``."""

    def forward(self, x):
        return x / 10.0


class DivConstantLong(torch.nn.Module):
    """Divide the input by the integer scalar ``10`` (true division)."""

    def forward(self, x):
        return x / 10


class EinsumBilinear(torch.nn.Module):
    """Bilinear form ``out[b, a] = sum_{n,m} bn[b,n] * anm[a,n,m] * bm[b,m]``."""

    def forward(self, bn, anm, bm):
        return torch.einsum("bn,anm,bm->ba", bn, anm, bm)


class EinsumOuterProduct(torch.nn.Module):
    """Outer product of two 1-D tensors via einsum."""

    def forward(self, i, j):
        return torch.einsum("i,j->ij", i, j)


class EinsumOuterProductRelu(torch.nn.Module):
    """Outer product of two 1-D tensors via einsum, followed by ReLU."""

    def forward(self, i, j):
        return torch.relu(torch.einsum("i,j->ij", i, j))


class Embedding(torch.nn.Module):
    """Lookup in a 10-entry, 3-dimensional embedding table."""

    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 3)

    def forward(self, x):
        return self.embedding(x)


class ExpandCopy(torch.nn.Module):
    """Broadcast the input to shape ``(3, 4)`` with ``Tensor.expand``."""

    def forward(self, x):
        return x.expand(3, 4)


class Gelu(torch.nn.Module):
    """GELU activation with default settings."""

    def __init__(self):
        super().__init__()
        self.gelu = torch.nn.GELU()

    def forward(self, x):
        return self.gelu(x)


class GroupNorm(torch.nn.Module):
    """3x3 conv (32 -> 256 channels) followed by GroupNorm with 32 groups.

    Returns both the conv output and its normalised version.
    """

    def __init__(self, bias=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            32,
            256,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
        )
        self.norm = torch.nn.GroupNorm(32, 256)

    def forward(self, x):
        y = self.conv(x)
        return y, self.norm(y)


class HardSigmoid(torch.nn.Module):
    """Hardsigmoid activation."""

    def __init__(self):
        super().__init__()
        self.hardsigmoid = torch.nn.Hardsigmoid()

    def forward(self, x):
        return self.hardsigmoid(x)


class HardSwish(torch.nn.Module):
    """Hardswish activation."""

    def __init__(self):
        super().__init__()
        self.hardswish = torch.nn.Hardswish()

    def forward(self, x):
        return self.hardswish(x)
class HardTanh(torch.nn.Module):
    """Clamp the input to ``[0, 6]`` via ``torch.nn.Hardtanh``."""

    def __init__(self):
        super().__init__()
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)

    def forward(self, x):
        return self.hardtanh(x)


class Index(torch.nn.Module):
    """Gather from ``x`` with two fixed int32 index tables and add the results.

    The index tables are plain tensor attributes (not registered buffers).
    The largest index is 6, so ``x`` needs at least 7 entries along dim 0.
    """

    def __init__(self):
        super().__init__()
        self.idx0 = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int32)
        self.idx1 = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int32)

    def forward(self, x):
        first = x[self.idx0]
        second = x[self.idx1]
        return first + second


class IndexPut(torch.nn.Module):
    """Write ``k_val`` into a persistent ``k_cache`` buffer at ``input_pos``.

    ``k_cache`` is a zero-initialised float32 buffer of shape
    ``(1, 1024, 12, 64)``. The update uses the in-place ``index_put_``, so
    written values persist across calls; the updated buffer is returned.
    """

    def __init__(self):
        super().__init__()
        self.register_buffer(
            "k_cache",
            torch.zeros((1, 1024, 12, 64), dtype=torch.float32),
        )

    def forward(self, input_pos, k_val):
        # ``None`` leaves dim 0 unindexed; ``input_pos`` selects along dim 1.
        return torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val)


class LayerNorm(torch.nn.Module):
    """LayerNorm over a trailing dim of 768 (eps=1e-6), then Linear 768 -> 196."""

    def __init__(self):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6)
        self.linear = torch.nn.Linear(in_features=768, out_features=196)

    def forward(self, x):
        normalized = self.layer_norm(x)
        return self.linear(normalized)


class LeakyReLUDefault(torch.nn.Module):
    """``torch.nn.LeakyReLU`` with its default negative slope."""

    def __init__(self):
        super().__init__()
        self.leaky_relu = torch.nn.LeakyReLU()

    def forward(self, x):
        return self.leaky_relu(x)


class LeakyReLUCustom(torch.nn.Module):
    """``torch.nn.LeakyReLU`` with a caller-supplied negative slope.

    Args:
        coeff: negative slope applied to negative inputs.
    """

    def __init__(self, coeff):
        super().__init__()
        self.leaky_relu = torch.nn.LeakyReLU(negative_slope=coeff)

    def forward(self, x):
        return self.leaky_relu(x)
class Linear(torch.nn.Module):
    """A 4 -> 5 ``torch.nn.Linear`` layer, put in eval mode at construction.

    Args:
        use_bias: whether the linear layer has a bias term.
    """

    def __init__(self, use_bias: bool = True):
        super().__init__()
        layer = torch.nn.Linear(in_features=4, out_features=5, bias=use_bias)
        # Only the inner layer is switched to eval mode.
        self.linear = layer.eval()

    def forward(self, x):
        return self.linear(x)


class LogSoftmax(torch.nn.Module):
    """Log-softmax over the last dimension."""

    def forward(self, x):
        return torch.nn.functional.log_softmax(x, dim=-1)


class MaxPool2d(torch.nn.Module):
    """3x3 max pooling with stride 1, padding 1 and ``ceil_mode=True``."""

    def __init__(self):
        super().__init__()
        self.max_pool2d = torch.nn.MaxPool2d(
            kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=True
        )

    def forward(self, x):
        return self.max_pool2d(x)
class MeanWKeppDim(torch.nn.Module):
    """Mean over the last two dimensions, keeping them as size-1 dims."""

    def forward(self, x):
        return torch.mean(x, (-1, -2), keepdim=True)


class MeanWOKeppDim(torch.nn.Module):
    """Mean over the last two dimensions, dropping them from the result."""

    def forward(self, x):
        return torch.mean(x, (-1, -2))


class Mul(torch.nn.Module):
    """Elementwise product of two tensors via ``torch.mul``."""

    def forward(self, x, y):
        return torch.mul(x, y)


class MulConstantFloat(torch.nn.Module):
    """Multiply the input by the Python float constant ``10.0``."""

    def forward(self, x):
        return 10.0 * x
class MulConstantLong(torch.nn.Module):
    """Multiply the input by the Python int constant ``10``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 10 * x


class MulScalar(torch.nn.Module):
    """Multiply the input by ``3.14`` through ``aten.mul.Scalar`` directly."""

    def __init__(self):
        super().__init__()
        self._scalar = 3.14

    def forward(self, x):
        return torch.ops.aten.mul.Scalar(x, self._scalar)


class MultiheadAttention(torch.nn.Module):
    """Self-attention via ``torch.nn.MultiheadAttention`` (embed_dim 96, 12 heads).

    ``batch_first=True`` and dropout 0.0. The input is used as query, key and
    value, and only the attention output is returned (weights are not
    requested).
    """

    def __init__(self):
        super().__init__()
        self.multi_head_attention = torch.nn.MultiheadAttention(
            96, 12, dropout=0.0, batch_first=True
        )

    def forward(self, x):
        attn_output, _ = self.multi_head_attention(x, x, x, need_weights=False)
        return attn_output


class Pad(torch.nn.Module):
    """Slice off the first element of dim 1, then constant-pad with zeros.

    ``F.pad`` reads the pad list from the last dimension backwards, so
    ``[0, 0, 0, 1, 0, 0]`` appends one zero at the end of the second-to-last
    dimension and leaves the last and third-to-last dimensions unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.nn.functional.pad(
            x[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0, mode="constant"
        )


class PixelShuffle(torch.nn.Module):
    """``torch.nn.PixelShuffle`` with upscale factor ``scale``."""

    def __init__(self, scale):
        super().__init__()
        self.pixel_shuffle = torch.nn.PixelShuffle(scale)

    def forward(self, x):
        return self.pixel_shuffle(x)


class PixelUnshuffle(torch.nn.Module):
    """``torch.nn.PixelUnshuffle`` with downscale factor ``scale``."""

    def __init__(self, scale):
        super().__init__()
        self.pixel_unshuffle = torch.nn.PixelUnshuffle(scale)

    def forward(self, x):
        return self.pixel_unshuffle(x)


class PixelUnshuffleMathEquivalent(torch.nn.Module):
    """Pixel unshuffle written out as view -> permute -> reshape.

    Expects a 4-D input ``(b, c, H, W)`` whose spatial dims are divisible by
    ``scale`` (otherwise the ``view`` fails). It returns
    ``(b, c * scale**2, H // scale, W // scale)``, laid out like
    ``torch.nn.functional.pixel_unshuffle``.
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        b, c, hh, hw = x.size()
        out_channel = c * (self.scale**2)
        h = hh // self.scale
        w = hw // self.scale
        # Split each spatial dim into (coarse, fine), then move the two fine
        # offsets next to the channel dim before merging them into it.
        x_view = x.view(b, c, h, self.scale, w, self.scale)
        return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)


class PowTensorScalar(torch.nn.Module):
    """Square the input with ``torch.pow(x, 2)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.pow(x, 2)


class PReLUDefault(torch.nn.Module):
    """``torch.nn.PReLU`` with its default (single shared) learnable slope."""

    def __init__(self):
        super().__init__()
        self.prelu = torch.nn.PReLU()

    def forward(self, x):
        return self.prelu(x)


class PReLUPerChannel(torch.nn.Module):
    """``torch.nn.PReLU`` with one learnable slope per channel.

    Args:
        channels: number of slope parameters (``num_parameters``).
    """

    def __init__(self, channels):
        super().__init__()
        self.prelu = torch.nn.PReLU(channels)

    def forward(self, x):
        return self.prelu(x)


class Relu(torch.nn.Module):
    """Apply ``torch.nn.ReLU``."""

    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(x)


class Reshape(torch.nn.Module):
    """Reshape a 12-element input to ``(1, 12)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.reshape(1, 12)


class ResidualBlockModule(torch.nn.Module):
    """conv -> BN -> conv -> BN -> hardtanh(0, 6), plus a skip connection.

    The same 3x3 conv (32 -> 32 channels) and the same batch-norm module are
    each applied twice, so the two stages share weights. The skip connection
    adds the first batch-norm output to the final activation. The module is
    put in eval mode at construction.
    """

    def __init__(self):
        super().__init__()
        groups = 1
        stride = [1, 1]
        padding = [1, 1]
        dilation = [1, 1]
        in_channels = 32
        out_channels = 32

        self.conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
            bias=True,
        )
        self.native_batchnorm = torch.nn.BatchNorm2d(out_channels)
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6.0)
        self.eval()

    def forward(self, x):
        x1 = self.conv(x)
        x2 = self.native_batchnorm(x1)
        x3 = self.conv(x2)
        x4 = self.native_batchnorm(x3)
        x5 = self.hardtanh(x4)
        x6 = torch.add(x5, x2)
        return x6


class ResizeBilinear2D(torch.nn.Module):
    """Bilinearly upsample the last two dimensions by 2x (align_corners=False)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Pass the doubled size directly. Materializing a random tensor of
        # this shape just to read its ``.shape`` back would allocate memory
        # and advance the global RNG for no effect on the result.
        output_shape = [dim * 2 for dim in x.shape[-2:]]
        return torch.nn.functional.interpolate(
            x,
            size=output_shape,
            mode="bilinear",
            align_corners=False,
        )


class ResizeNearest2D(torch.nn.Module):
    """Nearest-neighbour upsample of the last two dimensions by 2x."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Pass the doubled size directly; no throwaway tensor is needed.
        output_shape = [dim * 2 for dim in x.shape[-2:]]
        return torch.nn.functional.interpolate(
            x,
            size=output_shape,
            mode="nearest",
        )


class RmsNorm(torch.nn.Module):
    """``torch.nn.RMSNorm`` over a trailing dimension of size 4."""

    def __init__(self):
        super().__init__()
        self.eps = 1e-5
        # Use the stored epsilon so the attribute and the layer cannot diverge.
        self.rms = torch.nn.RMSNorm([4], eps=self.eps)

    def forward(self, x):
        return self.rms(x)
class Rsqrt(torch.nn.Module):
    """Elementwise reciprocal square root."""

    def forward(self, x):
        return torch.rsqrt(x)


class ScaledDotProductAttention(torch.nn.Module):
    """Call ``scaled_dot_product_attention`` with an explicit attention mask."""

    def forward(self, query_layer, key_layer, value_layer, attn_mask):
        return torch.nn.functional.scaled_dot_product_attention(
            query_layer, key_layer, value_layer, attn_mask=attn_mask
        )


class SelectCopy(torch.nn.Module):
    """3x3 conv (3 -> 2 channels, padding 1), then index ``[0, 1, 1:2]``.

    For a batched ``(N, C, H, W)`` input, this keeps batch 0, channel 1 and
    the row slice ``1:2``.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3,
            out_channels=2,
            kernel_size=(3, 3),
            padding=1,
            bias=True,
        )

    def forward(self, x):
        features = self.conv(x)
        return features[0, 1, 1:2]


class Sigmoid(torch.nn.Module):
    """Elementwise logistic sigmoid."""

    def forward(self, x):
        return torch.sigmoid(x)


class SimpleModel(torch.nn.Module):
    """Two conv branches merged by addition, then pool, reshape and linear.

    Each branch is conv -> batch_norm -> relu -> conv -> relu, and both
    branches share the same ``batch_norm`` module. The merged activation is
    permuted, averaged over dims 1 and 2 (keepdim), reshaped to ``(8, -1)``,
    passed through a 4 -> 10 linear layer and clamped to ``[0, 6]``.
    The module is put in eval mode at construction.
    """

    def __init__(self):
        super().__init__()
        channels = 32

        def make_conv(bias):
            return torch.nn.Conv2d(channels, channels, 3, padding=1, bias=bias)

        # Submodules are created in a fixed order: reordering would change
        # the random parameter initialization and the state_dict key order.
        self.conv1 = make_conv(True)
        self.conv2 = make_conv(True)
        self.conv3 = make_conv(False)
        self.conv4 = make_conv(False)
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)
        self.relu = torch.nn.ReLU()
        self.batch_norm = torch.nn.BatchNorm2d(channels)
        self.add = torch.add
        self.mean = torch.mean
        self.reshape = torch.reshape
        self.linear = torch.nn.Linear(4, 10)
        self.permute = torch.permute
        self.eval()

    def _branch(self, t, first_conv, second_conv):
        # conv -> shared batch norm -> relu -> conv -> relu
        hidden = self.relu(self.batch_norm(first_conv(t)))
        return self.relu(second_conv(hidden))

    def forward(self, x, y):
        left = self._branch(x, self.conv1, self.conv2)
        right = self._branch(y, self.conv3, self.conv4)
        merged = self.add(left, right)
        pooled = self.mean(self.permute(merged, (0, 3, 2, 1)), [1, 2], True)
        flattened = self.reshape(pooled, (8, -1))
        return self.hardtanh(self.linear(flattened))


class SliceCopy(torch.nn.Module):
    """Add a fixed random ``(1, 512)`` table to ``x``, both cut to ``y.shape[1]``.

    ``position_ids`` is a plain tensor attribute (not a registered buffer).
    """

    def __init__(self):
        super().__init__()
        self.position_ids = torch.randn([1, 512])

    def forward(self, x, y):
        seq_length = y.shape[1]
        return x[:, :seq_length] + self.position_ids[:, :seq_length]


class SliceCopyWithStep(torch.nn.Module):
    """Like ``SliceCopy`` but takes every ``step``-th (2nd) column."""

    def __init__(self):
        super().__init__()
        self.position_ids = torch.randn([1, 512])
        self.step = 2

    def forward(self, x, y):
        window = slice(None, y.shape[1], self.step)
        return x[:, window] + self.position_ids[:, window]


class Softmax(torch.nn.Module):
    """Softmax over the last dimension."""

    def forward(self, x):
        return torch.nn.functional.softmax(x, dim=-1)


class Sqrt(torch.nn.Module):
    """Elementwise square root."""

    def forward(self, x):
        return torch.sqrt(x)
class SqrtConstant(torch.nn.Module):
    """Divide the input by ``sqrt(64.0)`` computed from a constant tensor."""

    def forward(self, x):
        divisor = torch.sqrt(torch.tensor([64.0]))
        return x / divisor


class Squeeze(torch.nn.Module):
    """Remove every size-1 dimension."""

    def forward(self, x):
        return x.squeeze()


class Stack(torch.nn.Module):
    """Stack two tensors along a new leading dimension."""

    def forward(self, x, y):
        return torch.stack((x, y))


class Sub(torch.nn.Module):
    """Elementwise difference ``x - y`` via ``torch.sub``."""

    def forward(self, x, y):
        return torch.sub(x, y)


class SubConstantFloat(torch.nn.Module):
    """Compute ``10.0 - x`` with a Python float constant on the left."""

    def forward(self, x):
        return 10.0 - x


class SubConstantLong(torch.nn.Module):
    """Compute ``10 - x`` with a Python int constant on the left."""

    def forward(self, x):
        return 10 - x


class SumIntList(torch.nn.Module):
    """Sum over dims 2 and 3, keeping them as size-1 dims."""

    def forward(self, x):
        return torch.sum(x, dim=(2, 3), keepdim=True)


class Tanh(torch.nn.Module):
    """Elementwise hyperbolic tangent."""

    def forward(self, x):
        return torch.tanh(x)


class TopKandIndex(torch.nn.Module):
    """Take the top-3 values of ``x`` and add rows of a random table at their indices.

    ``idx_source`` is a plain ``(10, 3)`` uniform-random attribute indexed by
    the top-k indices, so those indices must be below 10.
    """

    def __init__(self):
        super().__init__()
        self.idx_source = torch.rand(10, 3)

    def forward(self, x):
        values, indices = torch.topk(x, 3)
        return values + self.idx_source[indices]


class Unbind(torch.nn.Module):
    """Split the input into a tuple of slices along dim 0."""

    def forward(self, x):
        return torch.unbind(x)


class Unsqueeze(torch.nn.Module):
    """Insert a new leading dimension of size 1."""

    def forward(self, x):
        return x.unsqueeze(0)


class View(torch.nn.Module):
    """Split the last dimension of ``x`` into ``(2, 256)`` with ``Tensor.view``.

    ``y`` is accepted but unused.
    """

    def __init__(self):
        super().__init__()
        self.first_size = 2
        self.second_size = 256

    def forward(self, x, y):
        return x.view(*x.shape[:-1], self.first_size, self.second_size)


class ViewPermuteMatMul(torch.nn.Module):
    """Split the last dim of ``x`` into ``(2, 256)``, swap dims 1 and 2, matmul with ``y``ᵀ.

    ``y``'s last two dimensions are transposed before the matmul.
    """

    def __init__(self):
        super().__init__()
        self.first_size = 2
        self.second_size = 256

    def forward(self, x, y):
        split = x.view(*x.shape[:-1], self.first_size, self.second_size)
        swapped = split.permute(0, 2, 1, 3)
        return torch.matmul(swapped, y.transpose(-1, -2))