# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

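# Test-only torch.nn.Module definitions exercised by the Qualcomm backend unit
# tests: single-operator wrappers first, then small multi-operator models.
# Expected input shapes are supplied by the accompanying tests.
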
import torch


# Modules that each wrap a single operator
class Add(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.add(x, y)


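# Each wrapper can be sanity-checked eagerly with shape-compatible random
# inputs, e.g. Add()(torch.randn(2, 3), torch.randn(2, 3)); the backend tests
# instead drive these modules through the export-and-lower pipeline.
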
class AddConstantFloat(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 10.0 + x


class AddConstantLong(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 10 + x


class Arange(torch.nn.Module):
    def __init__(self, x):
        super().__init__()
        self.x = x

    def forward(self, y):
        return torch.arange(self.x, dtype=torch.float32) + y


class AvgPoolModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.avgPool = torch.nn.AvgPool2d(
            kernel_size=(2, 2),
            padding=(1, 1),
            stride=(1, 1),
            count_include_pad=False,
        )

    def forward(self, x):
        return self.avgPool(x)


class BatchNorm(torch.nn.Module):
    def __init__(self, n_features):
        super().__init__()
        self.native_batchnorm = torch.nn.BatchNorm2d(n_features)
        self.eval()

    def forward(self, x):
        return self.native_batchnorm(x)


class Bmm(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # For batched 3-D inputs, torch.matmul performs a batch matrix
        # multiply, matching torch.bmm.
        return torch.matmul(x, y)


class Cast(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.type(torch.IntTensor)


class Cat2(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.cat((x, y), axis=2)


class Cat3(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.concat((y, y, x), axis=2)


class Cat4(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.cat((y, y, x, x), axis=2)


class Ceil(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.ceil(x)


class Chunk(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.chunk(x, chunks=2, dim=-1)


class ChunkAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        c1, c2 = torch.chunk(x, chunks=2, dim=-1)
        return torch.add(c1, c2)


class Clamp(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.clamp(x, max=0)


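# Lowers each submodule to its own delegate and chains the delegates in
# forward(); get_reference_module() rebuilds the same dataflow from the eager
# modules for output comparison. Note that assigning a plain list to
# self.modules shadows the nn.Module.modules() method.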
class CompositeDelegateModule(torch.nn.Module):
    def __init__(
        self,
        compiler_specs,
        partitioner_type,
        capture_method,
        lowered_method,
        quantize_method=None,
    ) -> None:
        super().__init__()
        self.modules = [
            Conv2dSequential(),
            Conv2dSequential(),
            Add(),
            Relu(),
        ]
        self.sample_inputs = [
            (torch.randn([1, 1, 3, 3]),),
            (torch.randn([1, 1, 3, 3]),),
            (torch.randn([1, 2, 3, 3]), torch.randn([1, 2, 3, 3])),
            (torch.randn([1, 2, 3, 3]),),
        ]
        self.lowered_modules = []
        for module, sample_input in zip(self.modules, self.sample_inputs):
            partitioner = partitioner_type(compiler_specs)
            if quantize_method:
                module = quantize_method(module, sample_input)
            edge_prog = capture_method(module, sample_input)
            edge_prog.exported_program = lowered_method(
                edge_prog.exported_program, partitioner
            )
            self.lowered_modules.append(
                edge_prog.exported_program.graph_module._modules.get("lowered_module_0")
            )

    def forward(self, x, y):
        x1 = self.lowered_modules[0](x)
        x2 = self.lowered_modules[1](y)
        x3 = self.lowered_modules[2](x1[0], x2[0])
        x4 = self.lowered_modules[3](x3[0])
        return x4[0]

    def get_random_input(self):
        return (torch.randn([1, 1, 3, 3]), torch.randn([1, 1, 3, 3]))

    def get_reference_module(self):
        class CompositeReferenceModule(torch.nn.Module):
            def __init__(self, modules):
                super().__init__()
                self.modules = modules

            def forward(self, x, y):
                x1 = self.modules[0](x)
                x2 = self.modules[1](y)
                x3 = self.modules[2](x1, x2)
                x4 = self.modules[3](x3)
                return x4

        return CompositeReferenceModule(self.modules)


class ContextBinaryExample(torch.nn.Module):
    def forward(self, x, y):
        x = torch.nn.functional.relu(x)
        y = torch.nn.functional.relu(y)
        return x, y

    def example_inputs(self):
        return {
            "x": torch.randn((1, 3, 3, 3)),
            "y": torch.randn((2, 1, 5, 5)),
        }


class Conv1dSequential(torch.nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.first = torch.nn.Conv1d(
            in_channels=1,
            out_channels=3,
            kernel_size=3,
            padding=1,
            bias=bias,
        )

        self.second = torch.nn.Conv1d(
            in_channels=3,
            out_channels=2,
            kernel_size=3,
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.second(self.first(x))


# Small models that compose several operators
class Conv1dReluLogSoftmax(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(
            in_channels=2, out_channels=2, kernel_size=1, stride=1, padding=1
        )
        self.logsoftmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = torch.nn.functional.relu(self.conv(x))
        x = self.logsoftmax(x)
        return x


class Conv2dAvgPool2d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            3, 16, 7, bias=True, stride=2, padding=3, dilation=1
        )
        self.pool = torch.nn.AvgPool2d(3, stride=2, padding=1)

    def forward(self, x):
        return self.pool(self.conv(x))


class Conv2dBnHardtanhMean(torch.nn.Module):
    def __init__(self):
        super().__init__()
        groups = 1
        stride = [2, 2]
        padding = [1, 1]
        dilation = [1, 1]
        in_channels = 1
        out_channels = 1

        self.conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
            bias=True,
        )
        self.conv.weight = torch.nn.Parameter(torch.randn(self.conv.weight.size()))
        self.native_batchnorm = torch.nn.BatchNorm2d(out_channels)
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)
        self.eval()

    def forward(self, x):
        x1 = self.conv(x)
        x2 = self.native_batchnorm(x1)
        x3 = self.hardtanh(x2)
        x4 = torch.mean(x3, dim=1, keepdim=True)
        return x4


class Conv2dCat(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 3)
        self.conv2 = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x, y):
        x = self.conv1(x)
        y = self.conv2(y)
        z = torch.cat([x, y], dim=1)
        return z


class Conv2dMaxPool2d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=2,
            out_channels=2,
            kernel_size=(1, 1),
            padding=1,
            bias=True,
        )
        self.pool = torch.nn.MaxPool2d(1, 1)

    def forward(self, x):
        return self.pool(self.conv(x))


class Conv2dSequential(torch.nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.first = torch.nn.Conv2d(
            in_channels=1,
            out_channels=3,
            kernel_size=(3, 3),
            padding=1,
            bias=bias,
        )
        self.second = torch.nn.Conv2d(
            in_channels=3,
            out_channels=2,
            kernel_size=(3, 3),
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.second(self.first(x))


class Conv2dSingle(torch.nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=1,
            out_channels=3,
            kernel_size=(3, 3),
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.conv(x)


class ConvTranspose2dSingle(torch.nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.conv_transpose = torch.nn.ConvTranspose2d(
            in_channels=1,
            out_channels=3,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.conv_transpose(x)


class Conv2dDownUpSample(torch.nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=16,
            out_channels=16,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=bias,
        )
        self.conv_transpose = torch.nn.ConvTranspose2d(
            in_channels=16,
            out_channels=16,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.conv_transpose(self.conv(x))


class Conv2dSumReduceDim(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.first = torch.nn.Conv2d(
            in_channels=1,
            out_channels=3,
            kernel_size=(3, 3),
            padding=1,
            bias=True,
        )

    def forward(self, x):
        return torch.sum(self.first(x), dim=(2, 3), keepdim=False)


class Conv2dTopK(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3)

    def forward(self, x):
        x = self.conv(x)
        topk_values, topk_indices = torch.topk(x, 5, dim=1)
        return topk_values


class Div(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.divide(x, y)


class DivConstantFloat(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x / 10.0


class DivConstantLong(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x / 10


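# einsum("bn,anm,bm->ba") contracts x[b, n] * A[a, n, m] * y[b, m] over n and
# m, i.e. a bilinear form per (batch, output) pair, which is what
# torch.nn.functional.bilinear computes (minus the bias term).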
class EinsumBilinear(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, bn, anm, bm):
        return torch.einsum("bn,anm,bm->ba", bn, anm, bm)


class EinsumOuterProduct(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, i, j):
        return torch.einsum("i,j->ij", i, j)


class EinsumOuterProductRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, i, j):
        return torch.relu(torch.einsum("i,j->ij", i, j))


class Embedding(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 3)

    def forward(self, x):
        return self.embedding(x)


class ExpandCopy(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.expand(3, 4)


class Gelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gelu = torch.nn.GELU()

    def forward(self, x):
        return self.gelu(x)


class GroupNorm(torch.nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            32,
            256,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
        )
        self.norm = torch.nn.GroupNorm(32, 256)

    def forward(self, x):
        y = self.conv(x)
        return y, self.norm(y)


class HardSigmoid(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.hardsigmoid = torch.nn.Hardsigmoid()

    def forward(self, x):
        return self.hardsigmoid(x)


class HardSwish(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.hardswish = torch.nn.Hardswish()

    def forward(self, x):
        return self.hardswish(x)


class HardTanh(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)

    def forward(self, x):
        return self.hardtanh(x)


class Index(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.idx0 = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int32)
        self.idx1 = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int32)

    def forward(self, x):
        return x[self.idx0] + x[self.idx1]


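# Mimics a transformer KV-cache update: k_val is written in place into the
# persistent k_cache buffer at the sequence positions given by input_pos.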
class IndexPut(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer(
            "k_cache",
            torch.zeros((1, 1024, 12, 64), dtype=torch.float32),
        )

    def forward(self, input_pos, k_val):
        k_out = torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val)
        return k_out


class LayerNorm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6)
        self.linear = torch.nn.Linear(768, 196)

    def forward(self, x):
        return self.linear(self.layer_norm(x))


class LeakyReLUDefault(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.leaky_relu = torch.nn.LeakyReLU()

    def forward(self, x):
        return self.leaky_relu(x)


class LeakyReLUCustom(torch.nn.Module):
    def __init__(self, coeff):
        super().__init__()
        self.leaky_relu = torch.nn.LeakyReLU(coeff)

    def forward(self, x):
        return self.leaky_relu(x)


class Linear(torch.nn.Module):
    def __init__(self, use_bias: bool = True):
        super().__init__()
        self.linear = torch.nn.Linear(4, 5, bias=use_bias).eval()

    def forward(self, x):
        return self.linear(x)


class LogSoftmax(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.nn.functional.log_softmax(x, dim=-1)


class MaxPool2d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.max_pool2d = torch.nn.MaxPool2d(
            kernel_size=3,
            stride=1,
            padding=1,
            dilation=1,
            ceil_mode=True,
        )

    def forward(self, x):
        return self.max_pool2d(x)


class MeanWKeppDim(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.mean(x, (-1, -2), keepdim=True)


class MeanWOKeppDim(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.mean(x, (-1, -2))


class Mul(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.mul(x, y)


class MulConstantFloat(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 10.0 * x


class MulConstantLong(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 10 * x


class MulScalar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._scalar = 3.14

    def forward(self, x):
        out1 = torch.ops.aten.mul.Scalar(x, self._scalar)
        return out1


class MultiheadAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.multi_head_attention = torch.nn.MultiheadAttention(
            96, 12, dropout=0.0, batch_first=True
        )

    def forward(self, x):
        attn_output, _ = self.multi_head_attention(x, x, x, need_weights=False)
        return attn_output


class Pad(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.nn.functional.pad(
            x[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0, mode="constant"
        )


class PixelShuffle(torch.nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.pixel_shuffle = torch.nn.PixelShuffle(scale)

    def forward(self, x):
        return self.pixel_shuffle(x)


class PixelUnshuffle(torch.nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.pixel_unshuffle = torch.nn.PixelUnshuffle(scale)

    def forward(self, x):
        return self.pixel_unshuffle(x)


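# Reproduces PixelUnshuffle(scale) with explicit view/permute/reshape ops:
# each scale x scale spatial block is folded into the channel dimension.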
class PixelUnshuffleMathEquivalent(torch.nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        b, c, hh, hw = x.size()
        out_channel = c * (self.scale**2)
        h = hh // self.scale
        w = hw // self.scale
        x_view = x.view(b, c, h, self.scale, w, self.scale)
        return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)


class PowTensorScalar(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.pow(x, 2)


class PReLUDefault(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.prelu = torch.nn.PReLU()

    def forward(self, x):
        return self.prelu(x)


class PReLUPerChannel(torch.nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.prelu = torch.nn.PReLU(channels)

    def forward(self, x):
        return self.prelu(x)


class Relu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(x)


class Reshape(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.reshape(1, 12)


class ResidualBlockModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        groups = 1
        stride = [1, 1]
        padding = [1, 1]
        dilation = [1, 1]
        in_channels = 32
        out_channels = 32

        self.conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
            bias=True,
        )
        self.native_batchnorm = torch.nn.BatchNorm2d(out_channels)
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6.0)
        self.eval()

    def forward(self, x):
        x1 = self.conv(x)
        x2 = self.native_batchnorm(x1)
        x3 = self.conv(x2)
        x4 = self.native_batchnorm(x3)
        x5 = self.hardtanh(x4)
        x6 = torch.add(x5, x2)
        return x6


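# In the two resize modules below, the doubled spatial size is obtained by
# materializing a tensor and reading back its shape, presumably so the
# exported graph captures a static size list rather than a dynamic expression.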
class ResizeBilinear2D(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        output_shape = [dim * 2 for dim in x.shape[-2:]]
        return torch.nn.functional.interpolate(
            x,
            size=list(torch.randn(output_shape).shape),
            mode="bilinear",
            align_corners=False,
        )


class ResizeNearest2D(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        output_shape = [dim * 2 for dim in x.shape[-2:]]
        return torch.nn.functional.interpolate(
            x,
            size=list(torch.randn(output_shape).shape),
            mode="nearest",
        )


class RmsNorm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.eps = 1e-5
        self.rms = torch.nn.RMSNorm([4], eps=self.eps)

    def forward(self, x):
        return self.rms(x)


class Rsqrt(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.rsqrt(x)


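# scaled_dot_product_attention computes
# softmax(Q @ K.transpose(-2, -1) / sqrt(d_k) + attn_mask) @ V.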
class ScaledDotProductAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, query_layer, key_layer, value_layer, attn_mask):
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer, key_layer, value_layer, attn_mask
        )
        return attn_output


class SelectCopy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3,
            out_channels=2,
            kernel_size=(3, 3),
            padding=1,
            bias=True,
        )

    def forward(self, x):
        return self.conv(x)[0, 1, 1:2]


class Sigmoid(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.sigmoid(x)


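# For a batch of one, z2 below holds 1 * 1 * 1 * 32 = 32 values, so
# reshape(8, -1) yields an (8, 4) matrix matching Linear(4, 10).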
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        kernel_sz = 32
        self.conv1 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=True)
        self.conv2 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=True)
        self.conv3 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=False)
        self.conv4 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=False)
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)
        self.relu = torch.nn.ReLU()
        self.batch_norm = torch.nn.BatchNorm2d(kernel_sz)
        self.add = torch.add
        self.mean = torch.mean
        self.reshape = torch.reshape
        self.linear = torch.nn.Linear(4, 10)
        self.permute = torch.permute
        self.eval()

    def forward(self, x, y):
        x1 = self.conv1(x)
        x2 = self.batch_norm(x1)
        x3 = self.relu(x2)
        x4 = self.conv2(x3)
        x5 = self.relu(x4)
        y1 = self.conv3(y)
        y2 = self.batch_norm(y1)
        y3 = self.relu(y2)
        y4 = self.conv4(y3)
        y5 = self.relu(y4)
        z = self.add(x5, y5)
        z1 = self.permute(z, (0, 3, 2, 1))
        z2 = torch.mean(z1, [1, 2], True)
        z3 = self.reshape(z2, (8, -1))
        z4 = self.linear(z3)
        z5 = self.hardtanh(z4)
        return z5


class SliceCopy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.position_ids = torch.randn([1, 512])

    def forward(self, x, y):
        seq_length = y.size()[1]
        return x[:, :seq_length] + self.position_ids[:, :seq_length]


class SliceCopyWithStep(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.position_ids = torch.randn([1, 512])
        self.step = 2

    def forward(self, x, y):
        seq_length = y.size()[1]
        return (
            x[:, : seq_length : self.step]
            + self.position_ids[:, : seq_length : self.step]
        )


class Softmax(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.nn.functional.softmax(x, dim=-1)


class Sqrt(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.sqrt(x)


class SqrtConstant(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x / torch.sqrt(torch.tensor([64.0]))


class Squeeze(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.squeeze()


class Stack(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.stack((x, y))


class Sub(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.sub(x, y)


class SubConstantFloat(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 10.0 - x


class SubConstantLong(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 10 - x


class SumIntList(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.sum(x, dim=(2, 3), keepdim=True)


class Tanh(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.tanh(x)


class TopKandIndex(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.idx_source = torch.rand(10, 3)

    def forward(self, x):
        a, b = torch.topk(x, 3)
        return a + self.idx_source[b]


class Unbind(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.unbind(x)


class Unsqueeze(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.unsqueeze(0)


class View(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.first_size = 2
        self.second_size = 256

    def forward(self, x, y):
        # Requires x's last dimension to equal first_size * second_size (512).
        new_shape = x.size()[:-1] + (self.first_size, self.second_size)
        return x.view(new_shape)


class ViewPermuteMatMul(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.first_size = 2
        self.second_size = 256

    def forward(self, x, y):
        # Split x's last dimension (must be 512) into two axes, move the new
        # axis ahead of the sequence axis, then batch-multiply against y.
        new_shape = x.size()[:-1] + (self.first_size, self.second_size)
        x = x.view(new_shape)
        x = x.permute(0, 2, 1, 3)
        return torch.matmul(x, y.transpose(-1, -2))

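
# A minimal eager smoke test, illustrative only and not part of the backend
# test flow; the shapes below are assumptions chosen to satisfy each module.
if __name__ == "__main__":
    print(Add()(torch.randn(2, 3), torch.randn(2, 3)).shape)
    print(Conv2dSequential()(torch.randn(1, 1, 3, 3)).shape)
    print(ViewPermuteMatMul()(torch.randn(1, 8, 512), torch.randn(1, 2, 8, 256)).shape)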