# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


# module with related operator only
class Add(torch.nn.Module):
    """Add two tensors with ``torch.add``."""

    def forward(self, x, y):
        return torch.add(x, y)


class AddConstantFloat(torch.nn.Module):
    """Add the Python float ``10.0`` to the input."""

    def forward(self, x):
        return 10.0 + x


class AddConstantLong(torch.nn.Module):
    """Add the Python int ``10`` to the input."""

    def forward(self, x):
        return 10 + x


class Arange(torch.nn.Module):
    """Add ``torch.arange(x)`` (float32) to the input.

    Args:
        x: End value handed to ``torch.arange``.
    """

    def __init__(self, x):
        super().__init__()
        self.x = x

    def forward(self, y):
        return torch.arange(self.x, dtype=torch.float32) + y


class AvgPoolModule(torch.nn.Module):
    """2x2 average pooling, stride 1, padding 1, padding excluded from mean."""

    def __init__(self):
        super().__init__()
        self.avgPool = torch.nn.AvgPool2d(
            kernel_size=(2, 2),
            padding=(1, 1),
            stride=(1, 1),
            count_include_pad=False,
        )

    def forward(self, x):
        return self.avgPool(x)


class BatchNorm(torch.nn.Module):
    """``BatchNorm2d`` over ``n_features`` channels, put in eval mode."""

    def __init__(self, n_features):
        super().__init__()
        self.native_batchnorm = torch.nn.BatchNorm2d(n_features)
        self.eval()

    def forward(self, x):
        return self.native_batchnorm(x)


class Bmm(torch.nn.Module):
    """Multiply two tensors with ``torch.matmul``."""

    def forward(self, x, y):
        return torch.matmul(x, y)


class Cast(torch.nn.Module):
    """Convert the input with ``Tensor.type(torch.IntTensor)``."""

    def forward(self, x):
        return x.type(torch.IntTensor)


class Cat2(torch.nn.Module):
    """Concatenate ``(x, y)`` along axis 2."""

    def forward(self, x, y):
        return torch.cat((x, y), axis=2)


class Cat3(torch.nn.Module):
    """Concatenate ``(y, y, x)`` along axis 2 using ``torch.concat``."""

    def forward(self, x, y):
        return torch.concat((y, y, x), axis=2)


class Cat4(torch.nn.Module):
    """Concatenate ``(y, y, x, x)`` along axis 2."""

    def forward(self, x, y):
        return torch.cat((y, y, x, x), axis=2)


class Ceil(torch.nn.Module):
    """Apply ``torch.ceil``."""

    def forward(self, x):
        return torch.ceil(x)


class Chunk(torch.nn.Module):
    """Split the input into two chunks along the last dimension."""

    def forward(self, x):
        return torch.chunk(x, chunks=2, dim=-1)


class ChunkAdd(torch.nn.Module):
    """Split the input in two along the last dimension and add the halves."""

    def forward(self, x):
        first, second = torch.chunk(x, chunks=2, dim=-1)
        return torch.add(first, second)


class Clamp(torch.nn.Module):
    """Clamp the input to an upper bound of 0 (no lower bound)."""

    def forward(self, x):
        return torch.clamp(x, max=0)


class CompositeDelegateModule(torch.nn.Module):
    """Chain four individually lowered sub-modules into one model.

    Two ``Conv2dSequential`` instances, an ``Add`` and a ``Relu`` are each
    (optionally quantized,) captured and lowered on their own at construction
    time; ``forward`` then wires the lowered results together.

    NOTE(review): ``self.modules`` is a plain list, so on instances it shadows
    ``torch.nn.Module.modules()`` and its entries are not registered as
    sub-modules. The attribute name is kept because ``get_reference_module``
    (and possibly external callers) read it.

    Args:
        compiler_specs: Passed unchanged to ``partitioner_type``.
        partitioner_type: Called as ``partitioner_type(compiler_specs)`` once
            per sub-module.
        capture_method: Called as ``capture_method(module, sample_input)``;
            the result must expose an ``exported_program`` attribute.
        lowered_method: Called as
            ``lowered_method(exported_program, partitioner)``; its result is
            stored back on ``exported_program``.
        quantize_method: Optional; when truthy it is called as
            ``quantize_method(module, sample_input)`` before capture.
    """

    def __init__(
        self,
        compiler_specs,
        partitioner_type,
        capture_method,
        lowered_method,
        quantize_method=None,
    ) -> None:
        super().__init__()
        self.modules = [
            Conv2dSequential(),
            Conv2dSequential(),
            Add(),
            Relu(),
        ]
        # One sample-input tuple per entry of ``self.modules``, same order.
        self.sample_inputs = [
            (torch.randn([1, 1, 3, 3]),),
            (torch.randn([1, 1, 3, 3]),),
            (torch.randn([1, 2, 3, 3]), torch.randn([1, 2, 3, 3])),
            (torch.randn([1, 2, 3, 3]),),
        ]
        self.lowered_modules = []
        for module, sample_input in zip(self.modules, self.sample_inputs):
            partitioner = partitioner_type(compiler_specs)
            if quantize_method:
                module = quantize_method(module, sample_input)
            edge_prog = capture_method(module, sample_input)
            edge_prog.exported_program = lowered_method(
                edge_prog.exported_program, partitioner
            )
            # ``.get`` yields None when no sub-module of that name exists.
            self.lowered_modules.append(
                edge_prog.exported_program.graph_module._modules.get(
                    "lowered_module_0"
                )
            )

    def forward(self, x, y):
        # Every lowered module's result is indexed with [0] before reuse.
        x1 = self.lowered_modules[0](x)
        x2 = self.lowered_modules[1](y)
        x3 = self.lowered_modules[2](x1[0], x2[0])
        x4 = self.lowered_modules[3](x3[0])
        return x4[0]

    def get_random_input(self):
        """Return a pair of random ``[1, 1, 3, 3]`` tensors for ``forward``."""
        return (torch.randn([1, 1, 3, 3]), torch.randn([1, 1, 3, 3]))

    def get_reference_module(self):
        """Return a module running the same chain on the original modules."""

        class CompositeReferenceModule(torch.nn.Module):
            def __init__(self, modules):
                super().__init__()
                self.modules = modules

            def forward(self, x, y):
                x1 = self.modules[0](x)
                x2 = self.modules[1](y)
                x3 = self.modules[2](x1, x2)
                x4 = self.modules[3](x3)
                return x4

        return CompositeReferenceModule(self.modules)


class ContextBinaryExample(torch.nn.Module):
    """Apply ReLU independently to two inputs and return both results."""

    def forward(self, x, y):
        x = torch.nn.functional.relu(x)
        y = torch.nn.functional.relu(y)
        return x, y

    def example_inputs(self):
        """Return random example tensors keyed by ``forward`` argument name."""
        return {
            "x": torch.randn((1, 3, 3, 3)),
            "y": torch.randn((2, 1, 5, 5)),
        }


class Conv1dSequential(torch.nn.Module):
    """Two stacked ``Conv1d`` layers (1 -> 3 -> 2 channels, kernel 3, pad 1).

    Args:
        bias: Whether both convolutions carry a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()
        self.first = torch.nn.Conv1d(
            in_channels=1,
            out_channels=3,
            kernel_size=3,
            padding=1,
            bias=bias,
        )

        self.second = torch.nn.Conv1d(
            in_channels=3,
            out_channels=2,
            kernel_size=3,
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.second(self.first(x))


# small models
class Conv1dReluLogSoftmax(torch.nn.Module):
    """``Conv1d`` (2 -> 2, kernel 1, pad 1), ReLU, then log-softmax on dim 1."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(
            in_channels=2, out_channels=2, kernel_size=1, stride=1, padding=1
        )
        self.logsoftmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = torch.nn.functional.relu(self.conv(x))
        return self.logsoftmax(x)


class Conv2dAvgPool2d(torch.nn.Module):
    """``Conv2d`` (3 -> 16, kernel 7, stride 2) followed by ``AvgPool2d``."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            3, 16, 7, bias=True, stride=2, padding=3, dilation=1
        )
        self.pool = torch.nn.AvgPool2d(3, stride=2, padding=1)

    def forward(self, x):
        return self.pool(self.conv(x))


class Conv2dBnHardtanhMean(torch.nn.Module):
    """Conv2d -> BatchNorm2d -> Hardtanh(0, 6) -> mean over dim 1 (kept).

    The convolution weight is replaced by a fresh ``torch.randn`` parameter
    and the module is put in eval mode.
    """

    def __init__(self):
        super().__init__()
        groups = 1
        stride = [2, 2]
        padding = [1, 1]
        dilation = [1, 1]
        in_channels = 1
        out_channels = 1

        self.conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
            bias=True,
        )
        self.conv.weight = torch.nn.Parameter(torch.randn(self.conv.weight.size()))
        self.native_batchnorm = torch.nn.BatchNorm2d(out_channels)
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)
        self.eval()

    def forward(self, x):
        x1 = self.conv(x)
        x2 = self.native_batchnorm(x1)
        x3 = self.hardtanh(x2)
        x4 = torch.mean(x3, 1, keepdim=True)
        return x4


class Conv2dCat(torch.nn.Module):
    """Run each input through its own 3x3 ``Conv2d`` and concat on dim 1."""

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 3)
        self.conv2 = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x, y):
        x = self.conv1(x)
        y = self.conv2(y)
        return torch.cat([x, y], dim=1)


class Conv2dMaxPool2d(torch.nn.Module):
    """1x1 ``Conv2d`` (2 -> 2, pad 1) followed by ``MaxPool2d(1, 1)``."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=2,
            out_channels=2,
            kernel_size=(1, 1),
            padding=1,
            bias=True,
        )
        self.pool = torch.nn.MaxPool2d(1, 1)

    def forward(self, x):
        return self.pool(self.conv(x))


class Conv2dSequential(torch.nn.Module):
    """Two stacked 3x3 ``Conv2d`` layers (1 -> 3 -> 2 channels, pad 1).

    Args:
        bias: Whether both convolutions carry a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()
        self.first = torch.nn.Conv2d(
            in_channels=1,
            out_channels=3,
            kernel_size=(3, 3),
            padding=1,
            bias=bias,
        )
        self.second = torch.nn.Conv2d(
            in_channels=3,
            out_channels=2,
            kernel_size=(3, 3),
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.second(self.first(x))


class Conv2dSingle(torch.nn.Module):
    """A single 3x3 ``Conv2d`` (1 -> 3 channels, pad 1).

    Args:
        bias: Whether the convolution carries a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=1,
            out_channels=3,
            kernel_size=(3, 3),
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.conv(x)


class ConvTranspose2dSingle(torch.nn.Module):
    """A single ``ConvTranspose2d`` (1 -> 3 channels, kernel 3, stride 2).

    Args:
        bias: Whether the transposed convolution carries a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()
        self.conv_transpose = torch.nn.ConvTranspose2d(
            in_channels=1,
            out_channels=3,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.conv_transpose(x)


class Conv2dDownUpSample(torch.nn.Module):
    """Stride-2 ``Conv2d`` followed by stride-2 ``ConvTranspose2d`` (16 ch).

    Args:
        bias: Whether both layers carry a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=16,
            out_channels=16,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=bias,
        )
        self.conv_transpose = torch.nn.ConvTranspose2d(
            in_channels=16,
            out_channels=16,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=bias,
        )

    def forward(self, x):
        return self.conv_transpose(self.conv(x))


class Conv2dSumReduceDim(torch.nn.Module):
    """3x3 ``Conv2d`` (1 -> 3) whose output is summed over dims 2 and 3."""

    def __init__(self):
        super().__init__()
        self.first = torch.nn.Conv2d(
            in_channels=1,
            out_channels=3,
            kernel_size=(3, 3),
            padding=1,
            bias=True,
        )

    def forward(self, x):
        return torch.sum(self.first(x), dim=(2, 3), keepdim=False)


class Conv2dTopK(torch.nn.Module):
    """3x3 ``Conv2d`` (3 -> 16) followed by top-5 values along dim 1."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3)

    def forward(self, x):
        x = self.conv(x)
        # Only the values are returned; the indices are discarded.
        topk_values, _ = torch.topk(x, 5, dim=1)
        return topk_values


class Div(torch.nn.Module):
    """Divide ``x`` by ``y`` with ``torch.divide``."""

    def forward(self, x, y):
        return torch.divide(x, y)


class DivConstantFloat(torch.nn.Module):
    """Divide the input by the Python float ``10.0``."""

    def forward(self, x):
        return x / 10.0


class DivConstantLong(torch.nn.Module):
    """Divide the input by the Python int ``10``."""

    def forward(self, x):
        return x / 10


class EinsumBilinear(torch.nn.Module):
    """Evaluate ``torch.einsum("bn,anm,bm->ba", bn, anm, bm)``."""

    def forward(self, bn, anm, bm):
        return torch.einsum("bn,anm,bm->ba", bn, anm, bm)


class EinsumOuterProduct(torch.nn.Module):
    """Evaluate ``torch.einsum("i,j->ij", i, j)``."""

    def forward(self, i, j):
        return torch.einsum("i,j->ij", i, j)


class EinsumOuterProductRelu(torch.nn.Module):
    """Evaluate ``torch.einsum("i,j->ij", i, j)`` and apply ReLU."""

    def forward(self, i, j):
        return torch.relu(torch.einsum("i,j->ij", i, j))


class Embedding(torch.nn.Module):
    """``torch.nn.Embedding`` with 10 entries of dimension 3."""

    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 3)

    def forward(self, x):
        return self.embedding(x)


class ExpandCopy(torch.nn.Module):
    """Expand the input to shape ``(3, 4)``."""

    def forward(self, x):
        return x.expand(3, 4)


class Gelu(torch.nn.Module):
    """Apply ``torch.nn.GELU`` with default settings."""

    def __init__(self):
        super().__init__()
        self.gelu = torch.nn.GELU()

    def forward(self, x):
        return self.gelu(x)


class GroupNorm(torch.nn.Module):
    """3x3 ``Conv2d`` (32 -> 256) followed by ``GroupNorm(32, 256)``.

    ``forward`` returns both the convolution output and its normalized form.

    Args:
        bias: Whether the convolution carries a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            32,
            256,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
        )
        self.norm = torch.nn.GroupNorm(32, 256)

    def forward(self, x):
        y = self.conv(x)
        return y, self.norm(y)


class HardSigmoid(torch.nn.Module):
    """Apply ``torch.nn.Hardsigmoid``."""

    def __init__(self):
        super().__init__()
        self.hardsigmoid = torch.nn.Hardsigmoid()

    def forward(self, x):
        return self.hardsigmoid(x)


class HardSwish(torch.nn.Module):
    """Apply ``torch.nn.Hardswish``."""

    def __init__(self):
        super().__init__()
        self.hardswish = torch.nn.Hardswish()

    def forward(self, x):
        return self.hardswish(x)


class HardTanh(torch.nn.Module):
    """Apply ``torch.nn.Hardtanh`` with range ``[0, 6]``."""

    def __init__(self):
        super().__init__()
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)

    def forward(self, x):
        return self.hardtanh(x)


class Index(torch.nn.Module):
    """Index the input with two fixed int32 index tensors and add the results."""

    def __init__(self):
        super().__init__()
        # Plain tensor attributes (not registered buffers).
        self.idx0 = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int32)
        self.idx1 = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int32)

    def forward(self, x):
        return x[self.idx0] + x[self.idx1]


class IndexPut(torch.nn.Module):
    """Write ``k_val`` into a ``k_cache`` buffer at ``input_pos`` (in place).

    The buffer is float32 zeros of shape ``(1, 1024, 12, 64)`` and is indexed
    with ``[None, input_pos]``.
    """

    def __init__(self):
        super().__init__()
        self.register_buffer(
            "k_cache",
            torch.zeros((1, 1024, 12, 64), dtype=torch.float32),
        )

    def forward(self, input_pos, k_val):
        return torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val)


class LayerNorm(torch.nn.Module):
    """``LayerNorm([768], eps=1e-6)`` followed by ``Linear(768, 196)``."""

    def __init__(self):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6)
        self.linear = torch.nn.Linear(768, 196)

    def forward(self, x):
        return self.linear(self.layer_norm(x))


class LeakyReLUDefault(torch.nn.Module):
    """Apply ``torch.nn.LeakyReLU`` with its default slope."""

    def __init__(self):
        super().__init__()
        self.leaky_relu = torch.nn.LeakyReLU()

    def forward(self, x):
        return self.leaky_relu(x)


class LeakyReLUCustom(torch.nn.Module):
    """Apply ``torch.nn.LeakyReLU`` with a caller-supplied slope.

    Args:
        coeff: First positional argument handed to ``torch.nn.LeakyReLU``.
    """

    def __init__(self, coeff):
        super().__init__()
        self.leaky_relu = torch.nn.LeakyReLU(coeff)

    def forward(self, x):
        return self.leaky_relu(x)


class Linear(torch.nn.Module):
    """``torch.nn.Linear(4, 5)`` put in eval mode.

    Args:
        use_bias: Whether the linear layer carries a bias term.
    """

    def __init__(self, use_bias: bool = True):
        super().__init__()
        self.linear = torch.nn.Linear(4, 5, use_bias).eval()

    def forward(self, x):
        return self.linear(x)


class LogSoftmax(torch.nn.Module):
    """Apply log-softmax over the last dimension."""

    def forward(self, x):
        return torch.nn.functional.log_softmax(x, dim=-1)


class MaxPool2d(torch.nn.Module):
    """3x3 max pooling, stride 1, padding 1, ``ceil_mode=True``."""

    def __init__(self):
        super().__init__()
        self.max_pool2d = torch.nn.MaxPool2d(
            kernel_size=3,
            stride=1,
            padding=1,
            dilation=1,
            ceil_mode=True,
        )

    def forward(self, x):
        return self.max_pool2d(x)


class MeanWKeppDim(torch.nn.Module):
    """Mean over the last two dimensions with ``keepdim=True``."""

    def forward(self, x):
        return torch.mean(x, (-1, -2), keepdim=True)


class MeanWOKeppDim(torch.nn.Module):
    """Mean over the last two dimensions without keeping them."""

    def forward(self, x):
        return torch.mean(x, (-1, -2))


class Mul(torch.nn.Module):
    """Multiply two tensors with ``torch.mul``."""

    def forward(self, x, y):
        return torch.mul(x, y)


class MulConstantFloat(torch.nn.Module):
    """Multiply the input by the Python float ``10.0``."""

    def forward(self, x):
        return 10.0 * x


class MulConstantLong(torch.nn.Module):
    """Multiply the input by the Python int ``10``."""

    def forward(self, x):
        return 10 * x


class MulScalar(torch.nn.Module):
    """Multiply the input by ``3.14`` through ``aten.mul.Scalar``."""

    def __init__(self):
        super().__init__()
        self._scalar = 3.14

    def forward(self, x):
        return torch.ops.aten.mul.Scalar(x, self._scalar)


class MultiheadAttention(torch.nn.Module):
    """Self-attention via ``torch.nn.MultiheadAttention(96, 12)``.

    The input is used as query, key and value; only the attention output is
    returned (``need_weights=False``).
    """

    def __init__(self):
        super().__init__()
        self.multi_head_attention = torch.nn.MultiheadAttention(
            96, 12, dropout=0.0, batch_first=True
        )

    def forward(self, x):
        attn_output, _ = self.multi_head_attention(x, x, x, need_weights=False)
        return attn_output


class Pad(torch.nn.Module):
    """Drop index 0 along dim 1, then constant-pad with zeros.

    The pad spec ``[0, 0, 0, 1, 0, 0]`` adds one trailing element to the
    second-to-last dimension and nothing elsewhere.
    """

    def forward(self, x):
        return torch.nn.functional.pad(
            x[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0, mode="constant"
        )


class PixelShuffle(torch.nn.Module):
    """Apply ``torch.nn.PixelShuffle(scale)``."""

    def __init__(self, scale):
        super().__init__()
        self.pixel_shuffle = torch.nn.PixelShuffle(scale)

    def forward(self, x):
        return self.pixel_shuffle(x)


class PixelUnshuffle(torch.nn.Module):
    """Apply ``torch.nn.PixelUnshuffle(scale)``."""

    def __init__(self, scale):
        super().__init__()
        self.pixel_unshuffle = torch.nn.PixelUnshuffle(scale)

    def forward(self, x):
        return self.pixel_unshuffle(x)


class PixelUnshuffleMathEquivalent(torch.nn.Module):
    """Pixel unshuffle written out as view / permute / reshape.

    Args:
        scale: Downscale factor; the input's last two dimensions are split by
            it with integer division.
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        b, c, hh, hw = x.size()
        out_channel = c * (self.scale**2)
        h = hh // self.scale
        w = hw // self.scale
        # Split each spatial dim into (block index, offset within block) ...
        x_view = x.view(b, c, h, self.scale, w, self.scale)
        # ... then move both offsets next to the channel dim and fold them in.
        return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)


class PowTensorScalar(torch.nn.Module):
    """Square the input with ``torch.pow(x, 2)``."""

    def forward(self, x):
        return torch.pow(x, 2)


class PReLUDefault(torch.nn.Module):
    """Apply ``torch.nn.PReLU`` with default arguments."""

    def __init__(self):
        super().__init__()
        self.prelu = torch.nn.PReLU()

    def forward(self, x):
        return self.prelu(x)


class PReLUPerChannel(torch.nn.Module):
    """Apply ``torch.nn.PReLU(channels)``."""

    def __init__(self, channels):
        super().__init__()
        self.prelu = torch.nn.PReLU(channels)

    def forward(self, x):
        return self.prelu(x)


class Relu(torch.nn.Module):
    """Apply ``torch.nn.ReLU``."""

    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(x)


class Reshape(torch.nn.Module):
    """Reshape the input to ``(1, 12)``."""

    def forward(self, x):
        return x.reshape(1, 12)


class ResidualBlockModule(torch.nn.Module):
    """Conv/BN applied twice with shared layers, Hardtanh, then a skip add.

    The same ``Conv2d`` (32 -> 32, 3x3) and ``BatchNorm2d`` instances are used
    for both passes; the first pass's normalized output is added to the
    Hardtanh result. The module is put in eval mode.
    """

    def __init__(self):
        super().__init__()
        groups = 1
        stride = [1, 1]
        padding = [1, 1]
        dilation = [1, 1]
        in_channels = 32
        out_channels = 32

        self.conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
            bias=True,
        )
        self.native_batchnorm = torch.nn.BatchNorm2d(out_channels)
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6.0)
        self.eval()

    def forward(self, x):
        x1 = self.conv(x)
        x2 = self.native_batchnorm(x1)
        x3 = self.conv(x2)
        x4 = self.native_batchnorm(x3)
        x5 = self.hardtanh(x4)
        x6 = torch.add(x5, x2)
        return x6


class ResizeBilinear2D(torch.nn.Module):
    """Bilinear interpolation to twice the size of the last two dimensions."""

    def forward(self, x):
        output_shape = [dim * 2 for dim in x.shape[-2:]]
        # NOTE(review): the size is read back from a freshly created tensor
        # instead of being passed directly; presumably deliberate for graph
        # capture -- confirm before simplifying to ``size=output_shape``.
        return torch.nn.functional.interpolate(
            x,
            size=list(torch.randn(output_shape).shape),
            mode="bilinear",
            align_corners=False,
        )


class ResizeNearest2D(torch.nn.Module):
    """Nearest interpolation to twice the size of the last two dimensions."""

    def forward(self, x):
        output_shape = [dim * 2 for dim in x.shape[-2:]]
        # NOTE(review): same shape round-trip as in ResizeBilinear2D;
        # presumably deliberate -- confirm before simplifying.
        return torch.nn.functional.interpolate(
            x,
            size=list(torch.randn(output_shape).shape),
            mode="nearest",
        )


class RmsNorm(torch.nn.Module):
    """``torch.nn.RMSNorm`` over a last dimension of size 4, eps ``1e-5``."""

    def __init__(self):
        super().__init__()
        self.eps = 1e-5
        # Reuse the stored epsilon so the attribute and the layer cannot drift.
        self.rms = torch.nn.RMSNorm([4], self.eps)

    def forward(self, x):
        return self.rms(x)


class Rsqrt(torch.nn.Module):
    """Apply ``torch.rsqrt``."""

    def forward(self, x):
        return torch.rsqrt(x)


class ScaledDotProductAttention(torch.nn.Module):
    """Call ``scaled_dot_product_attention`` with an explicit attention mask."""

    def forward(self, query_layer, key_layer, value_layer, attn_mask):
        return torch.nn.functional.scaled_dot_product_attention(
            query_layer, key_layer, value_layer, attn_mask
        )


class SelectCopy(torch.nn.Module):
    """3x3 ``Conv2d`` (3 -> 2) whose output is indexed with ``[0, 1, 1:2]``."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3,
            out_channels=2,
            kernel_size=(3, 3),
            padding=1,
            bias=True,
        )

    def forward(self, x):
        return self.conv(x)[0, 1, 1:2]


class Sigmoid(torch.nn.Module):
    """Apply ``torch.sigmoid``."""

    def forward(self, x):
        return torch.sigmoid(x)


class SimpleModel(torch.nn.Module):
    """Two conv branches that are added, reduced and fed to a linear layer.

    Each input goes through conv -> batch norm -> ReLU -> conv -> ReLU (the
    ``BatchNorm2d`` instance is shared by both branches). The branch outputs
    are added, permuted with ``(0, 3, 2, 1)``, averaged over dims 1 and 2,
    reshaped to ``(8, -1)`` and passed through ``Linear(4, 10)`` and
    ``Hardtanh(0, 6)``. The module is put in eval mode.
    """

    def __init__(self):
        super().__init__()
        channels = 32
        self.conv1 = torch.nn.Conv2d(channels, channels, 3, padding=1, bias=True)
        self.conv2 = torch.nn.Conv2d(channels, channels, 3, padding=1, bias=True)
        self.conv3 = torch.nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.conv4 = torch.nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)
        self.relu = torch.nn.ReLU()
        self.batch_norm = torch.nn.BatchNorm2d(channels)
        self.add = torch.add
        self.mean = torch.mean
        self.reshape = torch.reshape
        self.linear = torch.nn.Linear(4, 10)
        self.permute = torch.permute
        self.eval()

    def forward(self, x, y):
        x1 = self.conv1(x)
        x2 = self.batch_norm(x1)
        x3 = self.relu(x2)
        x4 = self.conv2(x3)
        x5 = self.relu(x4)
        y1 = self.conv3(y)
        y2 = self.batch_norm(y1)
        y3 = self.relu(y2)
        y4 = self.conv4(y3)
        y5 = self.relu(y4)
        z = self.add(x5, y5)
        z1 = self.permute(z, (0, 3, 2, 1))
        z2 = self.mean(z1, [1, 2], True)
        z3 = self.reshape(z2, (8, -1))
        z4 = self.linear(z3)
        z5 = self.hardtanh(z4)
        return z5


class SliceCopy(torch.nn.Module):
    """Add a sliced random ``position_ids`` table to a slice of ``x``.

    Both slices take the first ``y.size()[1]`` entries along dim 1.
    """

    def __init__(self):
        super().__init__()
        self.position_ids = torch.randn([1, 512])

    def forward(self, x, y):
        seq_length = y.size()[1]
        return x[:, :seq_length] + self.position_ids[:, :seq_length]


class SliceCopyWithStep(torch.nn.Module):
    """Like ``SliceCopy`` but both slices use a step of 2 along dim 1."""

    def __init__(self):
        super().__init__()
        self.position_ids = torch.randn([1, 512])
        self.step = 2

    def forward(self, x, y):
        seq_length = y.size()[1]
        return (
            x[:, : seq_length : self.step]
            + self.position_ids[:, : seq_length : self.step]
        )


class Softmax(torch.nn.Module):
    """Apply softmax over the last dimension."""

    def forward(self, x):
        return torch.nn.functional.softmax(x, dim=-1)


class Sqrt(torch.nn.Module):
    """Apply ``torch.sqrt``."""

    def forward(self, x):
        return torch.sqrt(x)


class SqrtConstant(torch.nn.Module):
    """Divide the input by ``sqrt(64.0)`` computed from a constant tensor."""

    def forward(self, x):
        return x / torch.sqrt(torch.tensor([64.0]))


class Squeeze(torch.nn.Module):
    """Remove all size-1 dimensions with ``Tensor.squeeze()``."""

    def forward(self, x):
        return x.squeeze()


class Stack(torch.nn.Module):
    """Stack two tensors along a new leading dimension."""

    def forward(self, x, y):
        return torch.stack((x, y))


class Sub(torch.nn.Module):
    """Subtract ``y`` from ``x`` with ``torch.sub``."""

    def forward(self, x, y):
        return torch.sub(x, y)


class SubConstantFloat(torch.nn.Module):
    """Subtract the input from the Python float ``10.0``."""

    def forward(self, x):
        return 10.0 - x


class SubConstantLong(torch.nn.Module):
    """Subtract the input from the Python int ``10``."""

    def forward(self, x):
        return 10 - x


class SumIntList(torch.nn.Module):
    """Sum over dims 2 and 3 with ``keepdim=True``."""

    def forward(self, x):
        return torch.sum(x, dim=(2, 3), keepdim=True)


class Tanh(torch.nn.Module):
    """Apply ``torch.tanh``."""

    def forward(self, x):
        return torch.tanh(x)


class TopKandIndex(torch.nn.Module):
    """Top-3 of the input, plus rows of a random table picked by the indices."""

    def __init__(self):
        super().__init__()
        self.idx_source = torch.rand(10, 3)

    def forward(self, x):
        values, indices = torch.topk(x, 3)
        return values + self.idx_source[indices]


class Unbind(torch.nn.Module):
    """Split the input along dim 0 with ``torch.unbind``."""

    def forward(self, x):
        return torch.unbind(x)


class Unsqueeze(torch.nn.Module):
    """Insert a new leading dimension with ``unsqueeze(0)``."""

    def forward(self, x):
        return x.unsqueeze(0)


class View(torch.nn.Module):
    """View the last dimension of ``x`` as ``(2, 256)``.

    ``forward`` accepts ``y`` but does not use it.
    """

    def __init__(self):
        super().__init__()
        self.first_size = 2
        self.second_size = 256

    def forward(self, x, y):
        new_shape = x.size()[:-1] + (self.first_size, self.second_size)
        return x.view(new_shape)


class ViewPermuteMatMul(torch.nn.Module):
    """View ``x`` as in ``View``, permute, then matmul with ``y`` transposed.

    The last dimension of ``x`` becomes ``(2, 256)``, dims 1 and 2 are swapped
    via ``permute(0, 2, 1, 3)`` and the result is multiplied by
    ``y.transpose(-1, -2)``.
    """

    def __init__(self):
        super().__init__()
        self.first_size = 2
        self.second_size = 256

    def forward(self, x, y):
        new_shape = x.size()[:-1] + (self.first_size, self.second_size)
        x = x.view(new_shape)
        x = x.permute(0, 2, 1, 3)
        return torch.matmul(x, y.transpose(-1, -2))