xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 
2 /**
3  * @generated
4  * This is an auto-generated file. Please do not modify it by hand.
5  * To re-generate, please run:
6  * cd ~/pytorch && python
7  * torchgen/shape_functions/gen_jit_shape_functions.py
8  */
9 #include <torch/csrc/jit/jit_log.h>
10 #include <torch/csrc/jit/passes/inliner.h>
11 #include <torch/csrc/jit/runtime/operator.h>
12 #include <torch/csrc/jit/runtime/serialized_shape_function_registry.h>
13 
14 // clang-format off
15 
16 namespace torch::jit {
17 
18 
19 std::string shape_funcs = ""
20 + std::string(R"=====(
21 def unary(self: List[int]) -> List[int]:
22   out = annotate(List[int], [])
23   for _0 in range(torch.len(self)):
24     elem = self[_0]
25     _1 = torch.append(out, elem)
26   return out
27 
28 def adaptive_avg_pool2d(self: List[int],
29     out: List[int]) -> List[int]:
30   if torch.eq(torch.len(out), 2):
31     pass
32   else:
33     ops.prim.RaiseException("AssertionError: ")
34   if torch.eq(torch.len(self), 3):
35     _0 = True
36   else:
37     _0 = torch.eq(torch.len(self), 4)
38   if _0:
39     pass
40   else:
41     ops.prim.RaiseException("AssertionError: ")
42   _1 = torch.__range_length(1, torch.len(self), 1)
43   for _2 in range(_1):
44     i = torch.__derive_index(_2, 1, 1)
45     if torch.ne(self[i], 0):
46       pass
47     else:
48       ops.prim.RaiseException("AssertionError: ")
49   shape = annotate(List[int], [])
50   _3 = torch.__range_length(0, torch.sub(torch.len(self), 2), 1)
51   for _4 in range(_3):
52     i0 = torch.__derive_index(_4, 0, 1)
53     _5 = torch.append(shape, self[i0])
54   for _6 in range(torch.len(out)):
55     elem = out[_6]
56     _7 = torch.append(shape, elem)
57   return shape
58 
59 def zero_dim_tensor(input: Any) -> List[int]:
60   return annotate(List[int], [])
61 
62 def arange_end(end: Union[float, int],
63     inp0: Any,
64     inp1: Any,
65     inp2: Any,
66     inp3: Any) -> List[int]:
67   if torch.ge(end, 0):
68     pass
69   else:
70     ops.prim.RaiseException("AssertionError: ")
71   return [int(torch.ceil(end))]
72 
73 def arange_start(start: Union[float, int],
74     end: Union[float, int],
75     inp0: Any,
76     inp1: Any,
77     inp2: Any,
78     inp3: Any) -> List[int]:
79   if torch.ge(end, 0):
80     pass
81   else:
82     ops.prim.RaiseException("AssertionError: ")
83   if torch.ge(end, start):
84     pass
85   else:
86     ops.prim.RaiseException("AssertionError: ")
87   _0 = int(torch.ceil(torch.sub(end, start)))
88   return [_0]
89 
90 )=====")
91 + std::string(R"=====(def arange_start_step(start: Union[float, int],
92     end: Union[float, int],
93     step: Union[float, int],
94     inp0: Any,
95     inp1: Any,
96     inp2: Any,
97     inp3: Any) -> List[int]:
98   if torch.ne(step, 0):
99     pass
100   else:
101     ops.prim.RaiseException("AssertionError: ")
102   if torch.lt(step, 0):
103     if torch.ge(start, end):
104       pass
105     else:
106       ops.prim.RaiseException("AssertionError: ")
107   else:
108     if torch.ge(end, start):
109       pass
110     else:
111       ops.prim.RaiseException("AssertionError: ")
112   _0 = torch.div(torch.sub(end, start), step)
113   return [torch.ceil(_0)]
114 
115 def squeeze_nodim(li: List[int]) -> List[int]:
116   out = annotate(List[int], [])
117   for i in range(torch.len(li)):
118     if torch.ne(li[i], 1):
119       _0 = torch.append(out, li[i])
120     else:
121       pass
122   return out
123 
124 def squeeze(li: List[int],
125     dim: int) -> List[int]:
126   out = annotate(List[int], [])
127   _0 = torch.len(li)
128   if torch.le(_0, 0):
129     dim_post_expr = 1
130   else:
131     dim_post_expr = _0
132   min = torch.neg(dim_post_expr)
133   max = torch.sub(dim_post_expr, 1)
134   if torch.lt(dim, min):
135     _1 = True
136   else:
137     _1 = torch.gt(dim, max)
138   if torch.__not__(_1):
139     pass
140   else:
141     ops.prim.RaiseException("AssertionError: ")
142   if torch.lt(dim, 0):
143     wrapped_dim = torch.add(dim, dim_post_expr)
144   else:
145     wrapped_dim = dim
146   for i in range(torch.len(li)):
147     if torch.eq(i, wrapped_dim):
148       if torch.ne(li[i], 1):
149         _2 = torch.append(out, li[i])
150       else:
151         pass
152     else:
153       _3 = torch.append(out, li[i])
154   return out
155 
156 )=====")
157 + std::string(R"=====(def squeeze_dims(li: List[int],
158     dims: List[int]) -> List[int]:
159   if torch.eq(torch.len(dims), 0):
160     _0 = li
161   else:
162     wrapped_dims = annotate(List[int], [])
163     for _1 in range(torch.len(dims)):
164       elem = dims[_1]
165       _2 = torch.append(wrapped_dims, elem)
166     for i in range(torch.len(dims)):
167       _3 = wrapped_dims[i]
168       _4 = torch.len(li)
169       if torch.le(_4, 0):
170         dim_post_expr = 1
171       else:
172         dim_post_expr = _4
173       min = torch.neg(dim_post_expr)
174       max = torch.sub(dim_post_expr, 1)
175       if torch.lt(_3, min):
176         _5 = True
177       else:
178         _5 = torch.gt(_3, max)
179       if torch.__not__(_5):
180         pass
181       else:
182         ops.prim.RaiseException("AssertionError: ")
183       if torch.lt(_3, 0):
184         dim = torch.add(_3, dim_post_expr)
185       else:
186         dim = _3
187       _6 = torch._set_item(wrapped_dims, i, dim)
188     result = annotate(List[int], [])
189     for i0 in range(torch.len(li)):
190       if torch.eq(li[i0], 1):
191         _7 = torch.__contains__(wrapped_dims, i0)
192         if torch.__not__(_7):
193           _8 = torch.append(result, li[i0])
194         else:
195           pass
196       else:
197         _9 = torch.append(result, li[i0])
198     _0 = result
199   return _0
200 
201 def unsqueeze(li: List[int],
202     dim: int) -> List[int]:
203   _0 = torch.add(torch.len(li), 1)
204   if torch.le(_0, 0):
205     dim_post_expr = 1
206   else:
207     dim_post_expr = _0
208   min = torch.neg(dim_post_expr)
209   max = torch.sub(dim_post_expr, 1)
210   if torch.lt(dim, min):
211     _1 = True
212   else:
213     _1 = torch.gt(dim, max)
214   if torch.__not__(_1):
215     pass
216   else:
217     ops.prim.RaiseException("AssertionError: ")
218   if torch.lt(dim, 0):
219     dim0 = torch.add(dim, dim_post_expr)
220   else:
221     dim0 = dim
222   out = annotate(List[int], [])
223   for _2 in range(torch.len(li)):
224     elem = li[_2]
225     _3 = torch.append(out, elem)
226   torch.insert(out, dim0, 1)
227   return out
228 
229 )=====")
230 + std::string(R"=====(def slice(self: List[int],
231     dim: int,
232     start: Optional[int],
233     end: Optional[int],
234     step: int) -> List[int]:
235   ndim = torch.len(self)
236   if torch.ne(ndim, 0):
237     pass
238   else:
239     ops.prim.RaiseException("AssertionError: ")
240   if torch.le(ndim, 0):
241     dim_post_expr = 1
242   else:
243     dim_post_expr = ndim
244   min = torch.neg(dim_post_expr)
245   max = torch.sub(dim_post_expr, 1)
246   if torch.lt(dim, min):
247     _0 = True
248   else:
249     _0 = torch.gt(dim, max)
250   if torch.__not__(_0):
251     pass
252   else:
253     ops.prim.RaiseException("AssertionError: ")
254   if torch.lt(dim, 0):
255     dim0 = torch.add(dim, dim_post_expr)
256   else:
257     dim0 = dim
258   if torch.__isnot__(start, None):
259     start_val = unchecked_cast(int, start)
260   else:
261     start_val = 0
262   if torch.__isnot__(end, None):
263     end_val = unchecked_cast(int, end)
264   else:
265     end_val = 9223372036854775807
266   if torch.gt(step, 0):
267     pass
268   else:
269     ops.prim.RaiseException("AssertionError: ")
270   _1 = torch.eq(start_val, 9223372036854775807)
271   if _1:
272     start_val0 = 0
273   else:
274     start_val0 = start_val
275   if torch.lt(start_val0, 0):
276     start_val1 = torch.add(start_val0, self[dim0])
277   else:
278     start_val1 = start_val0
279   if torch.lt(end_val, 0):
280     end_val0 = torch.add(end_val, self[dim0])
281   else:
282     end_val0 = end_val
283   if torch.lt(start_val1, 0):
284     start_val2 = 0
285   else:
286     if torch.gt(start_val1, self[dim0]):
287       start_val3 = self[dim0]
288     else:
289       start_val3 = start_val1
290     start_val2 = start_val3
291   if torch.lt(end_val0, start_val2):
292     end_val1 = start_val2
293   else:
294     if torch.ge(end_val0, self[dim0]):
295       end_val2 = self[dim0]
296     else:
297       end_val2 = end_val0
298     end_val1 = end_val2
299   slice_len = torch.sub(end_val1, start_val2)
300   out = annotate(List[int], [])
301   for _2 in range(torch.len(self)):
302     elem = self[_2]
303     _3 = torch.append(out, elem)
304   _4 = torch.sub(torch.add(slice_len, step), 1)
305   _5 = torch._set_item(out, dim0, torch.floordiv(_4, step))
306   return out
307 
308 )=====")
309 + std::string(R"=====(def select(self: List[int],
310     dim: int,
311     index: int) -> List[int]:
312   ndim = torch.len(self)
313   if torch.ne(ndim, 0):
314     pass
315   else:
316     ops.prim.RaiseException("AssertionError: ")
317   if torch.le(ndim, 0):
318     dim_post_expr = 1
319   else:
320     dim_post_expr = ndim
321   min = torch.neg(dim_post_expr)
322   max = torch.sub(dim_post_expr, 1)
323   if torch.lt(dim, min):
324     _0 = True
325   else:
326     _0 = torch.gt(dim, max)
327   if torch.__not__(_0):
328     pass
329   else:
330     ops.prim.RaiseException("AssertionError: ")
331   if torch.lt(dim, 0):
332     dim0 = torch.add(dim, dim_post_expr)
333   else:
334     dim0 = dim
335   size = self[dim0]
336   if torch.lt(index, torch.neg(size)):
337     _1 = True
338   else:
339     _1 = torch.ge(index, size)
340   if torch.__not__(_1):
341     pass
342   else:
343     ops.prim.RaiseException("AssertionError: ")
344   out = annotate(List[int], [])
345   for i in range(ndim):
346     if torch.ne(i, dim0):
347       _2 = torch.append(out, self[i])
348     else:
349       pass
350   return out
351 
352 )=====")
353 + std::string(R"=====(def index_select(self: List[int],
354     dim: int,
355     index: List[int]) -> List[int]:
356   _0 = torch.len(self)
357   if torch.le(_0, 0):
358     dim_post_expr = 1
359   else:
360     dim_post_expr = _0
361   min = torch.neg(dim_post_expr)
362   max = torch.sub(dim_post_expr, 1)
363   if torch.lt(dim, min):
364     _1 = True
365   else:
366     _1 = torch.gt(dim, max)
367   if torch.__not__(_1):
368     pass
369   else:
370     ops.prim.RaiseException("AssertionError: ")
371   if torch.lt(dim, 0):
372     dim0 = torch.add(dim, dim_post_expr)
373   else:
374     dim0 = dim
375   numel = 1
376   for _2 in range(torch.len(index)):
377     elem = index[_2]
378     numel = torch.mul(numel, elem)
379   if torch.le(torch.len(index), 1):
380     pass
381   else:
382     ops.prim.RaiseException("AssertionError: ")
383   if torch.eq(dim0, 0):
384     _3 = True
385   else:
386     _3 = torch.lt(dim0, torch.len(self))
387   if _3:
388     pass
389   else:
390     ops.prim.RaiseException("AssertionError: ")
391   result_size = annotate(List[int], [])
392   for i in range(torch.len(self)):
393     if torch.eq(dim0, i):
394       _4 = torch.append(result_size, numel)
395     else:
396       _5 = torch.append(result_size, self[i])
397   return result_size
398 
399 )=====")
400 + std::string(R"=====(def embedding(weight: List[int],
401     indices: List[int],
402     padding_idx: int=-1,
403     scale_grad_by_freq: bool=False,
404     sparse: bool=False) -> List[int]:
405   if torch.eq(torch.len(weight), 2):
406     pass
407   else:
408     ops.prim.RaiseException("AssertionError: ")
409   if torch.eq(torch.len(indices), 1):
410     _1 = torch.len(weight)
411     if torch.le(_1, 0):
412       dim_post_expr = 1
413     else:
414       dim_post_expr = _1
415     min = torch.neg(dim_post_expr)
416     max = torch.sub(dim_post_expr, 1)
417     if torch.lt(0, min):
418       _2 = True
419     else:
420       _2 = torch.gt(0, max)
421     if torch.__not__(_2):
422       pass
423     else:
424       ops.prim.RaiseException("AssertionError: ")
425     numel = 1
426     for _3 in range(torch.len(indices)):
427       elem = indices[_3]
428       numel = torch.mul(numel, elem)
429     if torch.le(torch.len(indices), 1):
430       pass
431     else:
432       ops.prim.RaiseException("AssertionError: ")
433     result_size = annotate(List[int], [])
434     for i in range(torch.len(weight)):
435       if torch.eq(0, i):
436         _4 = torch.append(result_size, numel)
437       else:
438         _5 = torch.append(result_size, weight[i])
439     _0 = result_size
440   else:
441     size = annotate(List[int], [])
442     for _6 in range(torch.len(indices)):
443       elem0 = indices[_6]
444       _7 = torch.append(size, elem0)
445     _8 = torch.append(size, weight[1])
446     _0 = size
447   return _0
448 
449 def mm(self: List[int],
450     mat2: List[int]) -> List[int]:
451   _0 = "AssertionError: self must be a matrix"
452   _1 = "AssertionError: mat2 must be a matrix"
453   if torch.eq(torch.len(self), 2):
454     pass
455   else:
456     ops.prim.RaiseException(_0)
457   if torch.eq(torch.len(mat2), 2):
458     pass
459   else:
460     ops.prim.RaiseException(_1)
461   if torch.eq(self[1], mat2[0]):
462     pass
463   else:
464     ops.prim.RaiseException("AssertionError: ")
465   return [self[0], mat2[1]]
466 
467 )=====")
468 + std::string(R"=====(def dot(self: List[int],
469     tensor: List[int]) -> List[int]:
470   if torch.eq(torch.len(self), 1):
471     _0 = torch.eq(torch.len(tensor), 1)
472   else:
473     _0 = False
474   if _0:
475     pass
476   else:
477     ops.prim.RaiseException("AssertionError: ")
478   if torch.eq(self[0], tensor[0]):
479     pass
480   else:
481     ops.prim.RaiseException("AssertionError: ")
482   return annotate(List[int], [])
483 
484 def mv(self: List[int],
485     vec: List[int]) -> List[int]:
486   if torch.eq(torch.len(self), 2):
487     _0 = torch.eq(torch.len(vec), 1)
488   else:
489     _0 = False
490   if _0:
491     pass
492   else:
493     ops.prim.RaiseException("AssertionError: ")
494   if torch.eq(self[1], vec[0]):
495     pass
496   else:
497     ops.prim.RaiseException("AssertionError: ")
498   return [self[0]]
499 
500 )=====")
501 + std::string(R"=====(def matmul(tensor1: List[int],
502     tensor2: List[int]) -> List[int]:
503   _0 = "AssertionError: self must be a matrix"
504   _1 = "AssertionError: mat2 must be a matrix"
505   _2 = "AssertionError: self must be a matrix"
506   _3 = "AssertionError: mat2 must be a matrix"
507   _4 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
508   _5 = "AssertionError: both  arguments to matmul need to be at least 1D"
509   _6 = uninitialized(List[int])
510   dim_tensor1 = torch.len(tensor1)
511   dim_tensor2 = torch.len(tensor2)
512   if torch.eq(dim_tensor1, 1):
513     _7 = torch.eq(dim_tensor2, 1)
514   else:
515     _7 = False
516   if _7:
517     if torch.eq(torch.len(tensor1), 1):
518       _9 = torch.eq(torch.len(tensor2), 1)
519     else:
520       _9 = False
521     if _9:
522       pass
523     else:
524       ops.prim.RaiseException("AssertionError: ")
525     if torch.eq(tensor1[0], tensor2[0]):
526       pass
527     else:
528       ops.prim.RaiseException("AssertionError: ")
529     _8 = annotate(List[int], [])
530   else:
531     if torch.eq(dim_tensor1, 2):
532       _10 = torch.eq(dim_tensor2, 1)
533     else:
534       _10 = False
535     if _10:
536       if torch.eq(torch.len(tensor1), 2):
537         _12 = torch.eq(torch.len(tensor2), 1)
538       else:
539         _12 = False
540       if _12:
541         pass
542       else:
543         ops.prim.RaiseException("AssertionError: ")
544       if torch.eq(tensor1[1], tensor2[0]):
545         pass
546       else:
547         ops.prim.RaiseException("AssertionError: ")
548       _11 = [tensor1[0]]
549     else:
550       if torch.eq(dim_tensor1, 1):
551         _13 = torch.eq(dim_tensor2, 2)
552       else:
553         _13 = False
554       if _13:
555         _15 = torch.add(torch.len(tensor1), 1)
556         if torch.le(_15, 0):
557           dim_post_expr = 1
558         else:
559           dim_post_expr = _15
560         min = torch.neg(dim_post_expr)
561         max = torch.sub(dim_post_expr, 1)
562         if torch.lt(0, min):
563           _16 = True
564         else:
565           _16 = torch.gt(0, max)
566         if torch.__not__(_16):
567           pass
568         else:
569           ops.prim.RaiseException("AssertionError: ")
570         out = annotate(List[int], [])
571         for _17 in range(torch.len(tensor1)):
572           elem = tensor1[_17]
573           _18 = torch.append(out, elem)
574         torch.insert(out, 0, 1)
575         if torch.eq(torch.len(out), 2):
576           pass
577         else:
578           ops.prim.RaiseException(_0)
579         if torch.eq(torch.len(tensor2), 2):
580           pass
581         else:
582           ops.prim.RaiseException(_1)
583         if torch.eq(out[1], tensor2[0]):
584           pass
585         else:
586           ops.prim.RaiseException("AssertionError: ")
587         _19 = [out[0], tensor2[1]]
588         out0 = annotate(List[int], [])
589         for i in range(2):
590           if torch.eq(i, 0):
591             if torch.ne(_19[i], 1):
592               _20 = torch.append(out0, _19[i])
593             else:
594               pass
595           else:
596             _21 = torch.append(out0, _19[i])
597         _14 = out0
598       else:
599         if torch.eq(dim_tensor1, 2):
600           _22 = torch.eq(dim_tensor2, 2)
601         else:
602           _22 = False
603         if _22:
604           _24 = torch.eq(torch.len(tensor1), 2)
605           if _24:
606             pass
607           else:
608             ops.prim.RaiseException(_2)
609           _25 = torch.eq(torch.len(tensor2), 2)
610           if _25:
611             pass
612           else:
613             ops.prim.RaiseException(_3)
614           _26 = torch.eq(tensor1[1], tensor2[0])
615           if _26:
616             pass
617           else:
618             ops.prim.RaiseException("AssertionError: ")
619           _23 = [tensor1[0], tensor2[1]]
620         else:
621           if torch.ge(dim_tensor1, 1):
622             _27 = torch.ge(dim_tensor2, 1)
623           else:
624             _27 = False
625           if _27:
626             if torch.gt(dim_tensor1, 1):
627               n = tensor1[-2]
628             else:
629               n = 1
630             batch_tensor1 = annotate(List[int], [])
631             for i0 in range(torch.sub(dim_tensor1, 2)):
632               _29 = torch.append(batch_tensor1, tensor1[i0])
633             p = tensor2[-1]
634             batch_tensor2 = annotate(List[int], [])
635             for i1 in range(torch.sub(dim_tensor2, 2)):
636               _30 = torch.append(batch_tensor2, tensor2[i1])
637             dimsA = torch.len(batch_tensor1)
638             dimsB = torch.len(batch_tensor2)
639             ndim = ops.prim.max(dimsA, dimsB)
640             expand_batch_portion = annotate(List[int], [])
641             for i2 in range(ndim):
642               offset = torch.sub(torch.sub(ndim, 1), i2)
643               dimA = torch.sub(torch.sub(dimsA, 1), offset)
644               dimB = torch.sub(torch.sub(dimsB, 1), offset)
645               if torch.ge(dimA, 0):
646                 sizeA = batch_tensor1[dimA]
647               else:
648                 sizeA = 1
649               if torch.ge(dimB, 0):
650                 sizeB = batch_tensor2[dimB]
651               else:
652                 sizeB = 1
653               if torch.ne(sizeA, sizeB):
654                 _31 = torch.ne(sizeA, 1)
655               else:
656                 _31 = False
657               if _31:
658                 _32 = torch.ne(sizeB, 1)
659               else:
660                 _32 = False
661               if _32:
662                 _33 = torch.format(_4, sizeA, sizeB, i2)
663                 _34 = torch.add("AssertionError: ", _33)
664                 ops.prim.RaiseException(_34)
665               else:
666                 pass
667               if torch.eq(sizeA, 1):
668                 _35 = sizeB
669               else:
670                 _35 = sizeA
671               _36 = torch.append(expand_batch_portion, _35)
672             if torch.gt(dim_tensor1, 1):
673               _37 = torch.append(expand_batch_portion, n)
674             else:
675               pass
676             if torch.gt(dim_tensor2, 1):
677               _38 = torch.append(expand_batch_portion, p)
678             else:
679               pass
680             _28 = expand_batch_portion
681           else:
682             ops.prim.RaiseException(_5)
683             _28 = _6
684           _23 = _28
685         _14 = _23
686       _11 = _14
687     _8 = _11
688   return _8
689 
690 )=====")
691 + std::string(R"=====(def linear(input: List[int],
692     weight: List[int],
693     bias: Optional[List[int]]) -> List[int]:
694   _0 = "AssertionError: self must be a matrix"
695   _1 = "AssertionError: mat2 must be a matrix"
696   _2 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
697   _3 = "AssertionError: both  arguments to matmul need to be at least 1D"
698   _4 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
699   if torch.le(torch.len(weight), 2):
700     pass
701   else:
702     ops.prim.RaiseException("AssertionError: ")
703   self_len = torch.len(weight)
704   if torch.eq(self_len, 0):
705     _5 = annotate(List[int], [])
706   else:
707     if torch.eq(self_len, 1):
708       _6 = [weight[0]]
709     else:
710       _6 = [weight[1], weight[0]]
711     _5 = _6
712   _7 = uninitialized(List[int])
713   dim_tensor1 = torch.len(input)
714   dim_tensor2 = torch.len(_5)
715   if torch.eq(dim_tensor1, 1):
716     _8 = torch.eq(dim_tensor2, 1)
717   else:
718     _8 = False
719   if _8:
720     if torch.eq(torch.len(input), 1):
721       _9 = torch.eq(torch.len(_5), 1)
722     else:
723       _9 = False
724     if _9:
725       pass
726     else:
727       ops.prim.RaiseException("AssertionError: ")
728     if torch.eq(input[0], _5[0]):
729       pass
730     else:
731       ops.prim.RaiseException("AssertionError: ")
732     out = annotate(List[int], [])
733   else:
734     if torch.eq(dim_tensor1, 2):
735       _10 = torch.eq(dim_tensor2, 1)
736     else:
737       _10 = False
738     if _10:
739       if torch.eq(torch.len(input), 2):
740         _12 = torch.eq(torch.len(_5), 1)
741       else:
742         _12 = False
743       if _12:
744         pass
745       else:
746         ops.prim.RaiseException("AssertionError: ")
747       if torch.eq(input[1], _5[0]):
748         pass
749       else:
750         ops.prim.RaiseException("AssertionError: ")
751       _11 = [input[0]]
752     else:
753       if torch.eq(dim_tensor1, 1):
754         _13 = torch.eq(dim_tensor2, 2)
755       else:
756         _13 = False
757       if _13:
758         _15 = torch.add(torch.len(input), 1)
759         if torch.le(_15, 0):
760           dim_post_expr = 1
761         else:
762           dim_post_expr = _15
763         min = torch.neg(dim_post_expr)
764         max = torch.sub(dim_post_expr, 1)
765         if torch.lt(0, min):
766           _16 = True
767         else:
768           _16 = torch.gt(0, max)
769         if torch.__not__(_16):
770           pass
771         else:
772           ops.prim.RaiseException("AssertionError: ")
773         out0 = annotate(List[int], [])
774         for _17 in range(torch.len(input)):
775           elem = input[_17]
776           _18 = torch.append(out0, elem)
777         torch.insert(out0, 0, 1)
778         if torch.eq(torch.len(out0), 2):
779           pass
780         else:
781           ops.prim.RaiseException(_0)
782         if torch.eq(torch.len(_5), 2):
783           pass
784         else:
785           ops.prim.RaiseException(_1)
786         if torch.eq(out0[1], _5[0]):
787           pass
788         else:
789           ops.prim.RaiseException("AssertionError: ")
790         _19 = [out0[0], _5[1]]
791         out1 = annotate(List[int], [])
792         for i in range(2):
793           if torch.eq(i, 0):
794             if torch.ne(_19[i], 1):
795               _20 = torch.append(out1, _19[i])
796             else:
797               pass
798           else:
799             _21 = torch.append(out1, _19[i])
800         _14 = out1
801       else:
802         if torch.eq(dim_tensor1, 2):
803           _22 = torch.eq(dim_tensor2, 2)
804         else:
805           _22 = False
806         if _22:
807           if torch.eq(torch.len(input), 2):
808             pass
809           else:
810             ops.prim.RaiseException(_0)
811           if torch.eq(torch.len(_5), 2):
812             pass
813           else:
814             ops.prim.RaiseException(_1)
815           if torch.eq(input[1], _5[0]):
816             pass
817           else:
818             ops.prim.RaiseException("AssertionError: ")
819           _23 = [input[0], _5[1]]
820         else:
821           if torch.ge(dim_tensor1, 1):
822             _24 = torch.ge(dim_tensor2, 1)
823           else:
824             _24 = False
825           if _24:
826             if torch.gt(dim_tensor1, 1):
827               n = input[-2]
828             else:
829               n = 1
830             batch_tensor1 = annotate(List[int], [])
831             for i0 in range(torch.sub(dim_tensor1, 2)):
832               _26 = torch.append(batch_tensor1, input[i0])
833             p = _5[-1]
834             batch_tensor2 = annotate(List[int], [])
835             for i1 in range(torch.sub(dim_tensor2, 2)):
836               _27 = torch.append(batch_tensor2, _5[i1])
837             dimsA = torch.len(batch_tensor1)
838             dimsB = torch.len(batch_tensor2)
839             ndim = ops.prim.max(dimsA, dimsB)
840             expand_batch_portion = annotate(List[int], [])
841             for i2 in range(ndim):
842               offset = torch.sub(torch.sub(ndim, 1), i2)
843               dimA = torch.sub(torch.sub(dimsA, 1), offset)
844               dimB = torch.sub(torch.sub(dimsB, 1), offset)
845               if torch.ge(dimA, 0):
846                 sizeA = batch_tensor1[dimA]
847               else:
848                 sizeA = 1
849               if torch.ge(dimB, 0):
850                 sizeB = batch_tensor2[dimB]
851               else:
852                 sizeB = 1
853               if torch.ne(sizeA, sizeB):
854                 _28 = torch.ne(sizeA, 1)
855               else:
856                 _28 = False
857               if _28:
858                 _29 = torch.ne(sizeB, 1)
859               else:
860                 _29 = False
861               if _29:
862                 _30 = torch.format(_2, sizeA, sizeB, i2)
863                 _31 = torch.add("AssertionError: ", _30)
864                 ops.prim.RaiseException(_31)
865               else:
866                 pass
867               if torch.eq(sizeA, 1):
868                 _32 = sizeB
869               else:
870                 _32 = sizeA
871               _33 = torch.append(expand_batch_portion, _32)
872             if torch.gt(dim_tensor1, 1):
873               _34 = torch.append(expand_batch_portion, n)
874             else:
875               pass
876             if torch.gt(dim_tensor2, 1):
877               _35 = torch.append(expand_batch_portion, p)
878             else:
879               pass
880             _25 = expand_batch_portion
881           else:
882             ops.prim.RaiseException(_3)
883             _25 = _7
884           _23 = _25
885         _14 = _23
886       _11 = _14
887     out = _11
888   if torch.__isnot__(bias, None):
889     bias0 = unchecked_cast(List[int], bias)
890     dimsA0 = torch.len(bias0)
891     dimsB0 = torch.len(out)
892     ndim0 = ops.prim.max(dimsA0, dimsB0)
893     expandedSizes = annotate(List[int], [])
894     for i3 in range(ndim0):
895       offset0 = torch.sub(torch.sub(ndim0, 1), i3)
896       dimA0 = torch.sub(torch.sub(dimsA0, 1), offset0)
897       dimB0 = torch.sub(torch.sub(dimsB0, 1), offset0)
898       if torch.ge(dimA0, 0):
899         sizeA0 = bias0[dimA0]
900       else:
901         sizeA0 = 1
902       if torch.ge(dimB0, 0):
903         sizeB0 = out[dimB0]
904       else:
905         sizeB0 = 1
906       if torch.ne(sizeA0, sizeB0):
907         _36 = torch.ne(sizeA0, 1)
908       else:
909         _36 = False
910       if _36:
911         _37 = torch.ne(sizeB0, 1)
912       else:
913         _37 = False
914       if _37:
915         _38 = torch.format(_4, sizeA0, sizeB0, i3)
916         _39 = torch.add("AssertionError: ", _38)
917         ops.prim.RaiseException(_39)
918       else:
919         pass
920       if torch.eq(sizeA0, 1):
921         _40 = sizeB0
922       else:
923         _40 = sizeA0
924       _41 = torch.append(expandedSizes, _40)
925     if torch.eq(expandedSizes, out):
926       pass
927     else:
928       ops.prim.RaiseException("AssertionError: ")
929   else:
930     pass
931   return out
932 
933 )=====")
934 + std::string(R"=====(def max_pool2d(input: List[int],
935     kernel_size: List[int],
936     stride: List[int],
937     padding: List[int],
938     dilation: List[int],
939     ceil_mode: bool) -> List[int]:
940   _0 = "AssertionError: max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
941   _1 = "AssertionError: max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
942   _2 = "AssertionError: max_pool2d: padding must either be a single int, or a tuple of two ints"
943   _3 = "AssertionError: max_pool2d: dilation must be either a single int, or a tuple of two ints"
944   _4 = "AssertionError: stride should not be zeero"
945   _5 = "AssertionError: stride should not be zeero"
946   if torch.eq(torch.len(kernel_size), 1):
947     _6 = True
948   else:
949     _6 = torch.eq(torch.len(kernel_size), 2)
950   if _6:
951     pass
952   else:
953     ops.prim.RaiseException(_0)
954   kH = kernel_size[0]
955   if torch.eq(torch.len(kernel_size), 1):
956     kW = kH
957   else:
958     kW = kernel_size[1]
959   if torch.eq(torch.len(stride), 0):
960     _7 = True
961   else:
962     _7 = torch.eq(torch.len(stride), 1)
963   if _7:
964     _8 = True
965   else:
966     _8 = torch.eq(torch.len(stride), 2)
967   if _8:
968     pass
969   else:
970     ops.prim.RaiseException(_1)
971   if torch.eq(torch.len(stride), 0):
972     dH = kH
973   else:
974     dH = stride[0]
975   if torch.eq(torch.len(stride), 0):
976     dW = kW
977   else:
978     if torch.eq(torch.len(stride), 1):
979       dW0 = dH
980     else:
981       dW0 = stride[1]
982     dW = dW0
983   if torch.eq(torch.len(padding), 1):
984     _9 = True
985   else:
986     _9 = torch.eq(torch.len(padding), 2)
987   if _9:
988     pass
989   else:
990     ops.prim.RaiseException(_2)
991   padH = padding[0]
992   if torch.eq(torch.len(padding), 1):
993     padW = padH
994   else:
995     padW = padding[1]
996   if torch.eq(torch.len(dilation), 1):
997     _10 = True
998   else:
999     _10 = torch.eq(torch.len(dilation), 2)
1000   if _10:
1001     pass
1002   else:
1003     ops.prim.RaiseException(_3)
1004   dilationH = dilation[0]
1005   if torch.eq(torch.len(dilation), 1):
1006     dilationW = dilationH
1007   else:
1008     dilationW = dilation[1]
1009   if torch.eq(torch.len(input), 3):
1010     _11 = True
1011   else:
1012     _11 = torch.eq(torch.len(input), 4)
1013   if _11:
1014     pass
1015   else:
1016     ops.prim.RaiseException("AssertionError: ")
1017   if torch.eq(torch.len(input), 4):
1018     nbatch = input[-4]
1019   else:
1020     nbatch = 1
1021   nInputPlane = input[-3]
1022   inputHeight = input[-2]
1023   inputWidth = input[-1]
1024   if torch.ne(dH, 0):
1025     pass
1026   else:
1027     ops.prim.RaiseException(_4)
1028   _12 = torch.add(torch.add(inputHeight, padH), padH)
1029   _13 = torch.mul(dilationH, torch.sub(kH, 1))
1030   _14 = torch.sub(torch.sub(_12, _13), 1)
1031   if ceil_mode:
1032     _15 = torch.sub(dH, 1)
1033   else:
1034     _15 = 0
1035   _16 = torch.floordiv(torch.add(_14, _15), dH)
1036   outputSize = torch.add(_16, 1)
1037   if ceil_mode:
1038     _17 = torch.ge(torch.mul(_16, dH), torch.add(inputHeight, padH))
1039     if _17:
1040       outputSize0 = _16
1041     else:
1042       outputSize0 = outputSize
1043     outputHeight = outputSize0
1044   else:
1045     outputHeight = outputSize
1046   if torch.ne(dW, 0):
1047     pass
1048   else:
1049     ops.prim.RaiseException(_5)
1050   _18 = torch.add(torch.add(inputWidth, padW), padW)
1051   _19 = torch.mul(dilationW, torch.sub(kW, 1))
1052   _20 = torch.sub(torch.sub(_18, _19), 1)
1053   if ceil_mode:
1054     _21 = torch.sub(dW, 1)
1055   else:
1056     _21 = 0
1057   _22 = torch.floordiv(torch.add(_20, _21), dW)
1058   outputSize1 = torch.add(_22, 1)
1059   if ceil_mode:
1060     _23 = torch.ge(torch.mul(_22, dW), torch.add(inputWidth, padW))
1061     if _23:
1062       outputSize2 = _22
1063     else:
1064       outputSize2 = outputSize1
1065     outputWidth = outputSize2
1066   else:
1067     outputWidth = outputSize1
1068   ndim = torch.len(input)
1069   if torch.gt(kW, 0):
1070     _24 = torch.gt(kH, 0)
1071   else:
1072     _24 = False
1073   if _24:
1074     pass
1075   else:
1076     ops.prim.RaiseException("AssertionError: ")
1077   if torch.gt(dW, 0):
1078     _25 = torch.gt(dH, 0)
1079   else:
1080     _25 = False
1081   if _25:
1082     pass
1083   else:
1084     ops.prim.RaiseException("AssertionError: ")
1085   if torch.gt(dilationH, 0):
1086     _26 = torch.gt(dilationW, 0)
1087   else:
1088     _26 = False
1089   if _26:
1090     pass
1091   else:
1092     ops.prim.RaiseException("AssertionError: ")
1093   if torch.ne(input[1], 0):
1094     valid_dims = torch.ne(input[2], 0)
1095   else:
1096     valid_dims = False
1097   if torch.eq(ndim, 3):
1098     _27 = torch.ne(input[0], 0)
1099   else:
1100     _27 = False
1101   if _27:
1102     _28 = valid_dims
1103   else:
1104     _28 = False
1105   if _28:
1106     _29 = True
1107   else:
1108     if torch.eq(ndim, 4):
1109       _30 = valid_dims
1110     else:
1111       _30 = False
1112     if _30:
1113       _31 = torch.ne(input[3], 0)
1114     else:
1115       _31 = False
1116     _29 = _31
1117   if _29:
1118     pass
1119   else:
1120     ops.prim.RaiseException("AssertionError: ")
1121   if torch.ge(torch.floordiv(kW, 2), padW):
1122     _33 = torch.ge(torch.floordiv(kH, 2), padH)
1123     _32 = _33
1124   else:
1125     _32 = False
1126   if _32:
1127     pass
1128   else:
1129     ops.prim.RaiseException("AssertionError: ")
1130   if torch.ge(outputWidth, 1):
1131     _34 = torch.ge(outputHeight, 1)
1132   else:
1133     _34 = False
1134   if _34:
1135     pass
1136   else:
1137     ops.prim.RaiseException("AssertionError: ")
1138   if torch.eq(torch.len(input), 3):
1139     _36 = [nInputPlane, outputHeight, outputWidth]
1140     _35 = _36
1141   else:
1142     _37 = [nbatch, nInputPlane, outputHeight, outputWidth]
1143     _35 = _37
1144   return _35
1145 
1146 )=====")
1147 + std::string(R"=====(def max_pool2d_with_indices(input: List[int],
1148     kernel_size: List[int],
1149     stride: List[int],
1150     padding: List[int],
1151     dilation: List[int],
1152     ceil_mode: bool) -> Tuple[List[int], List[int]]:
1153   _0 = "AssertionError: max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
1154   _1 = "AssertionError: max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
1155   _2 = "AssertionError: max_pool2d: padding must either be a single int, or a tuple of two ints"
1156   _3 = "AssertionError: max_pool2d: dilation must be either a single int, or a tuple of two ints"
1157   _4 = "AssertionError: stride should not be zeero"
1158   if torch.eq(torch.len(kernel_size), 1):
1159     _5 = True
1160   else:
1161     _5 = torch.eq(torch.len(kernel_size), 2)
1162   if _5:
1163     pass
1164   else:
1165     ops.prim.RaiseException(_0)
1166   kH = kernel_size[0]
1167   if torch.eq(torch.len(kernel_size), 1):
1168     kW = kH
1169   else:
1170     kW = kernel_size[1]
1171   if torch.eq(torch.len(stride), 0):
1172     _6 = True
1173   else:
1174     _6 = torch.eq(torch.len(stride), 1)
1175   if _6:
1176     _7 = True
1177   else:
1178     _7 = torch.eq(torch.len(stride), 2)
1179   if _7:
1180     pass
1181   else:
1182     ops.prim.RaiseException(_1)
1183   if torch.eq(torch.len(stride), 0):
1184     dH = kH
1185   else:
1186     dH = stride[0]
1187   if torch.eq(torch.len(stride), 0):
1188     dW = kW
1189   else:
1190     if torch.eq(torch.len(stride), 1):
1191       dW0 = dH
1192     else:
1193       dW0 = stride[1]
1194     dW = dW0
1195   if torch.eq(torch.len(padding), 1):
1196     _8 = True
1197   else:
1198     _8 = torch.eq(torch.len(padding), 2)
1199   if _8:
1200     pass
1201   else:
1202     ops.prim.RaiseException(_2)
1203   padH = padding[0]
1204   if torch.eq(torch.len(padding), 1):
1205     padW = padH
1206   else:
1207     padW = padding[1]
1208   if torch.eq(torch.len(dilation), 1):
1209     _9 = True
1210   else:
1211     _9 = torch.eq(torch.len(dilation), 2)
1212   if _9:
1213     pass
1214   else:
1215     ops.prim.RaiseException(_3)
1216   dilationH = dilation[0]
1217   if torch.eq(torch.len(dilation), 1):
1218     dilationW = dilationH
1219   else:
1220     dilationW = dilation[1]
1221   if torch.eq(torch.len(input), 3):
1222     _10 = True
1223   else:
1224     _10 = torch.eq(torch.len(input), 4)
1225   if _10:
1226     pass
1227   else:
1228     ops.prim.RaiseException("AssertionError: ")
1229   if torch.eq(torch.len(input), 4):
1230     nbatch = input[-4]
1231   else:
1232     nbatch = 1
1233   nInputPlane = input[-3]
1234   inputHeight = input[-2]
1235   inputWidth = input[-1]
1236   if torch.ne(dH, 0):
1237     pass
1238   else:
1239     ops.prim.RaiseException(_4)
1240   _11 = torch.add(torch.add(inputHeight, padH), padH)
1241   _12 = torch.mul(dilationH, torch.sub(kH, 1))
1242   _13 = torch.sub(torch.sub(_11, _12), 1)
1243   if ceil_mode:
1244     _14 = torch.sub(dH, 1)
1245   else:
1246     _14 = 0
1247   _15 = torch.floordiv(torch.add(_13, _14), dH)
1248   outputSize = torch.add(_15, 1)
1249   if ceil_mode:
1250     _16 = torch.ge(torch.mul(_15, dH), torch.add(inputHeight, padH))
1251     if _16:
1252       outputSize0 = _15
1253     else:
1254       outputSize0 = outputSize
1255     outputHeight = outputSize0
1256   else:
1257     outputHeight = outputSize
1258   if torch.ne(dW, 0):
1259     pass
1260   else:
1261     ops.prim.RaiseException(_4)
1262   _17 = torch.add(torch.add(inputWidth, padW), padW)
1263   _18 = torch.mul(dilationW, torch.sub(kW, 1))
1264   _19 = torch.sub(torch.sub(_17, _18), 1)
1265   if ceil_mode:
1266     _20 = torch.sub(dW, 1)
1267   else:
1268     _20 = 0
1269   _21 = torch.floordiv(torch.add(_19, _20), dW)
1270   outputSize1 = torch.add(_21, 1)
1271   if ceil_mode:
1272     _22 = torch.ge(torch.mul(_21, dW), torch.add(inputWidth, padW))
1273     if _22:
1274       outputSize2 = _21
1275     else:
1276       outputSize2 = outputSize1
1277     outputWidth = outputSize2
1278   else:
1279     outputWidth = outputSize1
1280   ndim = torch.len(input)
1281   if torch.gt(kW, 0):
1282     _23 = torch.gt(kH, 0)
1283   else:
1284     _23 = False
1285   if _23:
1286     pass
1287   else:
1288     ops.prim.RaiseException("AssertionError: ")
1289   if torch.gt(dW, 0):
1290     _24 = torch.gt(dH, 0)
1291   else:
1292     _24 = False
1293   if _24:
1294     pass
1295   else:
1296     ops.prim.RaiseException("AssertionError: ")
1297   if torch.gt(dilationH, 0):
1298     _25 = torch.gt(dilationW, 0)
1299   else:
1300     _25 = False
1301   if _25:
1302     pass
1303   else:
1304     ops.prim.RaiseException("AssertionError: ")
1305   if torch.ne(input[1], 0):
1306     valid_dims = torch.ne(input[2], 0)
1307   else:
1308     valid_dims = False
1309   if torch.eq(ndim, 3):
1310     _26 = torch.ne(input[0], 0)
1311   else:
1312     _26 = False
1313   if _26:
1314     _27 = valid_dims
1315   else:
1316     _27 = False
1317   if _27:
1318     _28 = True
1319   else:
1320     if torch.eq(ndim, 4):
1321       _29 = valid_dims
1322     else:
1323       _29 = False
1324     if _29:
1325       _30 = torch.ne(input[3], 0)
1326     else:
1327       _30 = False
1328     _28 = _30
1329   if _28:
1330     pass
1331   else:
1332     ops.prim.RaiseException("AssertionError: ")
1333   if torch.ge(torch.floordiv(kW, 2), padW):
1334     _32 = torch.ge(torch.floordiv(kH, 2), padH)
1335     _31 = _32
1336   else:
1337     _31 = False
1338   if _31:
1339     pass
1340   else:
1341     ops.prim.RaiseException("AssertionError: ")
1342   if torch.ge(outputWidth, 1):
1343     _33 = torch.ge(outputHeight, 1)
1344   else:
1345     _33 = False
1346   if _33:
1347     pass
1348   else:
1349     ops.prim.RaiseException("AssertionError: ")
1350   if torch.eq(torch.len(input), 3):
1351     _34 = [nInputPlane, outputHeight, outputWidth]
1352     out = _34
1353   else:
1354     _35 = [nbatch, nInputPlane, outputHeight, outputWidth]
1355     out = _35
1356   return (out, out)
1357 
1358 )=====")
1359 + std::string(R"=====(def t(self: List[int]) -> List[int]:
1360   if torch.le(torch.len(self), 2):
1361     pass
1362   else:
1363     ops.prim.RaiseException("AssertionError: ")
1364   self_len = torch.len(self)
1365   if torch.eq(self_len, 0):
1366     _0 = annotate(List[int], [])
1367   else:
1368     if torch.eq(self_len, 1):
1369       _1 = [self[0]]
1370     else:
1371       _1 = [self[1], self[0]]
1372     _0 = _1
1373   return _0
1374 
1375 def transpose(self: List[int],
1376     dim0: int,
1377     dim1: int) -> List[int]:
1378   ndims = torch.len(self)
1379   if torch.le(ndims, 0):
1380     dim_post_expr = 1
1381   else:
1382     dim_post_expr = ndims
1383   min = torch.neg(dim_post_expr)
1384   max = torch.sub(dim_post_expr, 1)
1385   if torch.lt(dim0, min):
1386     _0 = True
1387   else:
1388     _0 = torch.gt(dim0, max)
1389   if torch.__not__(_0):
1390     pass
1391   else:
1392     ops.prim.RaiseException("AssertionError: ")
1393   if torch.lt(dim0, 0):
1394     dim00 = torch.add(dim0, dim_post_expr)
1395   else:
1396     dim00 = dim0
1397   if torch.le(ndims, 0):
1398     dim_post_expr0 = 1
1399   else:
1400     dim_post_expr0 = ndims
1401   min0 = torch.neg(dim_post_expr0)
1402   max0 = torch.sub(dim_post_expr0, 1)
1403   if torch.lt(dim1, min0):
1404     _1 = True
1405   else:
1406     _1 = torch.gt(dim1, max0)
1407   if torch.__not__(_1):
1408     pass
1409   else:
1410     ops.prim.RaiseException("AssertionError: ")
1411   if torch.lt(dim1, 0):
1412     dim10 = torch.add(dim1, dim_post_expr0)
1413   else:
1414     dim10 = dim1
1415   if torch.eq(dim00, dim10):
1416     out = annotate(List[int], [])
1417     for _3 in range(torch.len(self)):
1418       elem = self[_3]
1419       _4 = torch.append(out, elem)
1420     _2 = out
1421   else:
1422     out0 = annotate(List[int], [])
1423     for i in range(ndims):
1424       if torch.eq(i, dim00):
1425         _5 = torch.append(out0, self[dim10])
1426       else:
1427         if torch.eq(i, dim10):
1428           _6 = torch.append(out0, self[dim00])
1429         else:
1430           _7 = torch.append(out0, self[i])
1431     _2 = out0
1432   return _2
1433 
1434 )=====")
1435 + std::string(R"=====(def conv1d(input: List[int],
1436     weight: List[int],
1437     bias: Optional[List[int]],
1438     stride: List[int],
1439     padding: List[int],
1440     dilation: List[int],
1441     groups: int) -> List[int]:
1442   if torch.eq(torch.len(weight), 3):
1443     pass
1444   else:
1445     ops.prim.RaiseException("AssertionError: ")
1446   if torch.eq(torch.len(input), 3):
1447     pass
1448   else:
1449     ops.prim.RaiseException("AssertionError: ")
1450   k = torch.len(input)
1451   weight_dim = torch.len(weight)
1452   non_negative = False
1453   for _0 in range(torch.len(padding)):
1454     val = padding[_0]
1455     if torch.lt(val, 0):
1456       non_negative0 = True
1457     else:
1458       non_negative0 = non_negative
1459     non_negative = non_negative0
1460   if torch.__not__(non_negative):
1461     pass
1462   else:
1463     ops.prim.RaiseException("AssertionError: ")
1464   non_negative1 = False
1465   for _1 in range(torch.len(stride)):
1466     val0 = stride[_1]
1467     if torch.lt(val0, 0):
1468       non_negative2 = True
1469     else:
1470       non_negative2 = non_negative1
1471     non_negative1 = non_negative2
1472   if torch.__not__(non_negative1):
1473     pass
1474   else:
1475     ops.prim.RaiseException("AssertionError: ")
1476   if torch.eq(weight_dim, k):
1477     pass
1478   else:
1479     ops.prim.RaiseException("AssertionError: ")
1480   if torch.ge(weight[0], groups):
1481     pass
1482   else:
1483     ops.prim.RaiseException("AssertionError: ")
1484   _2 = torch.eq(torch.remainder(weight[0], groups), 0)
1485   if _2:
1486     pass
1487   else:
1488     ops.prim.RaiseException("AssertionError: ")
1489   _3 = torch.eq(input[1], torch.mul(weight[1], groups))
1490   if _3:
1491     pass
1492   else:
1493     ops.prim.RaiseException("AssertionError: ")
1494   if torch.__is__(bias, None):
1495     _4 = True
1496   else:
1497     bias0 = unchecked_cast(List[int], bias)
1498     if torch.eq(torch.len(bias0), 1):
1499       _5 = torch.eq(bias0[0], weight[0])
1500     else:
1501       _5 = False
1502     _4 = _5
1503   if _4:
1504     pass
1505   else:
1506     ops.prim.RaiseException("AssertionError: ")
1507   for _6 in range(torch.__range_length(2, k, 1)):
1508     i = torch.__derive_index(_6, 2, 1)
1509     _7 = input[i]
1510     _8 = torch.mul(padding[torch.sub(i, 2)], 2)
1511     _9 = torch.add(_7, _8)
1512     _10 = torch.mul(dilation[torch.sub(i, 2)], torch.sub(weight[i], 1))
1513     if torch.ge(_9, torch.add(_10, 1)):
1514       pass
1515     else:
1516       ops.prim.RaiseException("AssertionError: ")
1517   has_dilation = torch.gt(torch.len(dilation), 0)
1518   dim = torch.len(input)
1519   output_size = annotate(List[int], [])
1520   _11 = torch.append(output_size, input[0])
1521   _12 = torch.append(output_size, weight[0])
1522   for _13 in range(torch.__range_length(2, dim, 1)):
1523     d = torch.__derive_index(_13, 2, 1)
1524     if has_dilation:
1525       dilation_ = dilation[torch.sub(d, 2)]
1526     else:
1527       dilation_ = 1
1528     _14 = torch.mul(dilation_, torch.sub(weight[d], 1))
1529     kernel = torch.add(_14, 1)
1530     _15 = input[d]
1531     _16 = torch.mul(padding[torch.sub(d, 2)], 2)
1532     _17 = torch.sub(torch.add(_15, _16), kernel)
1533     _18 = torch.floordiv(_17, stride[torch.sub(d, 2)])
1534     _19 = torch.append(output_size, torch.add(_18, 1))
1535   return output_size
1536 
1537 )=====")
1538 + std::string(R"=====(def conv2d(input: List[int],
1539     weight: List[int],
1540     bias: Optional[List[int]],
1541     stride: List[int],
1542     padding: List[int],
1543     dilation: List[int],
1544     groups: int) -> List[int]:
1545   if torch.eq(torch.len(weight), 4):
1546     pass
1547   else:
1548     ops.prim.RaiseException("AssertionError: ")
1549   if torch.eq(torch.len(input), 4):
1550     pass
1551   else:
1552     ops.prim.RaiseException("AssertionError: ")
1553   k = torch.len(input)
1554   weight_dim = torch.len(weight)
1555   non_negative = False
1556   for _0 in range(torch.len(padding)):
1557     val = padding[_0]
1558     if torch.lt(val, 0):
1559       non_negative0 = True
1560     else:
1561       non_negative0 = non_negative
1562     non_negative = non_negative0
1563   if torch.__not__(non_negative):
1564     pass
1565   else:
1566     ops.prim.RaiseException("AssertionError: ")
1567   non_negative1 = False
1568   for _1 in range(torch.len(stride)):
1569     val0 = stride[_1]
1570     if torch.lt(val0, 0):
1571       non_negative2 = True
1572     else:
1573       non_negative2 = non_negative1
1574     non_negative1 = non_negative2
1575   if torch.__not__(non_negative1):
1576     pass
1577   else:
1578     ops.prim.RaiseException("AssertionError: ")
1579   if torch.eq(weight_dim, k):
1580     pass
1581   else:
1582     ops.prim.RaiseException("AssertionError: ")
1583   if torch.ge(weight[0], groups):
1584     pass
1585   else:
1586     ops.prim.RaiseException("AssertionError: ")
1587   _2 = torch.eq(torch.remainder(weight[0], groups), 0)
1588   if _2:
1589     pass
1590   else:
1591     ops.prim.RaiseException("AssertionError: ")
1592   _3 = torch.eq(input[1], torch.mul(weight[1], groups))
1593   if _3:
1594     pass
1595   else:
1596     ops.prim.RaiseException("AssertionError: ")
1597   if torch.__is__(bias, None):
1598     _4 = True
1599   else:
1600     bias0 = unchecked_cast(List[int], bias)
1601     if torch.eq(torch.len(bias0), 1):
1602       _5 = torch.eq(bias0[0], weight[0])
1603     else:
1604       _5 = False
1605     _4 = _5
1606   if _4:
1607     pass
1608   else:
1609     ops.prim.RaiseException("AssertionError: ")
1610   for _6 in range(torch.__range_length(2, k, 1)):
1611     i = torch.__derive_index(_6, 2, 1)
1612     _7 = input[i]
1613     _8 = torch.mul(padding[torch.sub(i, 2)], 2)
1614     _9 = torch.add(_7, _8)
1615     _10 = torch.mul(dilation[torch.sub(i, 2)], torch.sub(weight[i], 1))
1616     if torch.ge(_9, torch.add(_10, 1)):
1617       pass
1618     else:
1619       ops.prim.RaiseException("AssertionError: ")
1620   has_dilation = torch.gt(torch.len(dilation), 0)
1621   dim = torch.len(input)
1622   output_size = annotate(List[int], [])
1623   _11 = torch.append(output_size, input[0])
1624   _12 = torch.append(output_size, weight[0])
1625   for _13 in range(torch.__range_length(2, dim, 1)):
1626     d = torch.__derive_index(_13, 2, 1)
1627     if has_dilation:
1628       dilation_ = dilation[torch.sub(d, 2)]
1629     else:
1630       dilation_ = 1
1631     _14 = torch.mul(dilation_, torch.sub(weight[d], 1))
1632     kernel = torch.add(_14, 1)
1633     _15 = input[d]
1634     _16 = torch.mul(padding[torch.sub(d, 2)], 2)
1635     _17 = torch.sub(torch.add(_15, _16), kernel)
1636     _18 = torch.floordiv(_17, stride[torch.sub(d, 2)])
1637     _19 = torch.append(output_size, torch.add(_18, 1))
1638   return output_size
1639 
1640 )=====")
1641 + std::string(R"=====(def batch_norm(input: List[int],
1642     weight: Optional[List[int]],
1643     bias: Optional[List[int]],
1644     running_mean: Optional[List[int]],
1645     running_var: Optional[List[int]],
1646     training: bool,
1647     momentum: float,
1648     eps: float,
1649     cudnn_enabled: bool) -> List[int]:
1650   out = annotate(List[int], [])
1651   for _0 in range(torch.len(input)):
1652     elem = input[_0]
1653     _1 = torch.append(out, elem)
1654   return out
1655 
1656 )=====")
1657 + std::string(R"=====(def conv3d(input: List[int],
1658     weight: List[int],
1659     bias: Optional[List[int]],
1660     stride: List[int],
1661     padding: List[int],
1662     dilation: List[int],
1663     groups: int) -> List[int]:
1664   if torch.eq(torch.len(weight), 5):
1665     pass
1666   else:
1667     ops.prim.RaiseException("AssertionError: ")
1668   if torch.eq(torch.len(input), 5):
1669     pass
1670   else:
1671     ops.prim.RaiseException("AssertionError: ")
1672   k = torch.len(input)
1673   weight_dim = torch.len(weight)
1674   non_negative = False
1675   for _0 in range(torch.len(padding)):
1676     val = padding[_0]
1677     if torch.lt(val, 0):
1678       non_negative0 = True
1679     else:
1680       non_negative0 = non_negative
1681     non_negative = non_negative0
1682   if torch.__not__(non_negative):
1683     pass
1684   else:
1685     ops.prim.RaiseException("AssertionError: ")
1686   non_negative1 = False
1687   for _1 in range(torch.len(stride)):
1688     val0 = stride[_1]
1689     if torch.lt(val0, 0):
1690       non_negative2 = True
1691     else:
1692       non_negative2 = non_negative1
1693     non_negative1 = non_negative2
1694   if torch.__not__(non_negative1):
1695     pass
1696   else:
1697     ops.prim.RaiseException("AssertionError: ")
1698   if torch.eq(weight_dim, k):
1699     pass
1700   else:
1701     ops.prim.RaiseException("AssertionError: ")
1702   if torch.ge(weight[0], groups):
1703     pass
1704   else:
1705     ops.prim.RaiseException("AssertionError: ")
1706   _2 = torch.eq(torch.remainder(weight[0], groups), 0)
1707   if _2:
1708     pass
1709   else:
1710     ops.prim.RaiseException("AssertionError: ")
1711   _3 = torch.eq(input[1], torch.mul(weight[1], groups))
1712   if _3:
1713     pass
1714   else:
1715     ops.prim.RaiseException("AssertionError: ")
1716   if torch.__is__(bias, None):
1717     _4 = True
1718   else:
1719     bias0 = unchecked_cast(List[int], bias)
1720     if torch.eq(torch.len(bias0), 1):
1721       _5 = torch.eq(bias0[0], weight[0])
1722     else:
1723       _5 = False
1724     _4 = _5
1725   if _4:
1726     pass
1727   else:
1728     ops.prim.RaiseException("AssertionError: ")
1729   for _6 in range(torch.__range_length(2, k, 1)):
1730     i = torch.__derive_index(_6, 2, 1)
1731     _7 = input[i]
1732     _8 = torch.mul(padding[torch.sub(i, 2)], 2)
1733     _9 = torch.add(_7, _8)
1734     _10 = torch.mul(dilation[torch.sub(i, 2)], torch.sub(weight[i], 1))
1735     if torch.ge(_9, torch.add(_10, 1)):
1736       pass
1737     else:
1738       ops.prim.RaiseException("AssertionError: ")
1739   has_dilation = torch.gt(torch.len(dilation), 0)
1740   dim = torch.len(input)
1741   output_size = annotate(List[int], [])
1742   _11 = torch.append(output_size, input[0])
1743   _12 = torch.append(output_size, weight[0])
1744   for _13 in range(torch.__range_length(2, dim, 1)):
1745     d = torch.__derive_index(_13, 2, 1)
1746     if has_dilation:
1747       dilation_ = dilation[torch.sub(d, 2)]
1748     else:
1749       dilation_ = 1
1750     _14 = torch.mul(dilation_, torch.sub(weight[d], 1))
1751     kernel = torch.add(_14, 1)
1752     _15 = input[d]
1753     _16 = torch.mul(padding[torch.sub(d, 2)], 2)
1754     _17 = torch.sub(torch.add(_15, _16), kernel)
1755     _18 = torch.floordiv(_17, stride[torch.sub(d, 2)])
1756     _19 = torch.append(output_size, torch.add(_18, 1))
1757   return output_size
1758 
1759 )=====")
1760 + std::string(R"=====(def conv_backwards(grad_output: List[int],
1761     input: List[int],
1762     weight: List[int],
1763     biases: Optional[List[int]]) -> Tuple[List[int], List[int], List[int]]:
1764   out = annotate(List[int], [])
1765   for _0 in range(torch.len(input)):
1766     elem = input[_0]
1767     _1 = torch.append(out, elem)
1768   out0 = annotate(List[int], [])
1769   for _2 in range(torch.len(weight)):
1770     elem0 = weight[_2]
1771     _3 = torch.append(out0, elem0)
1772   return (out, out0, [grad_output[1]])
1773 
1774 )=====")
1775 + std::string(R"=====(def conv_forwards(input: List[int],
1776     weight: List[int],
1777     bias: Optional[List[int]],
1778     stride: List[int],
1779     padding: List[int],
1780     dilation: List[int],
1781     transposed: bool,
1782     output_padding: List[int],
1783     groups: int) -> List[int]:
1784   has_dilation = torch.gt(torch.len(dilation), 0)
1785   has_output_padding = torch.gt(torch.len(output_padding), 0)
1786   dim = torch.len(input)
1787   output_size = annotate(List[int], [])
1788   if transposed:
1789     weight_output_channels_dim = 1
1790   else:
1791     weight_output_channels_dim = 0
1792   _0 = torch.append(output_size, input[0])
1793   if transposed:
1794     _1 = torch.mul(weight[weight_output_channels_dim], groups)
1795     _2 = torch.append(output_size, _1)
1796   else:
1797     _3 = torch.append(output_size, weight[weight_output_channels_dim])
1798   for _4 in range(torch.__range_length(2, dim, 1)):
1799     d = torch.__derive_index(_4, 2, 1)
1800     if has_dilation:
1801       dilation_ = dilation[torch.sub(d, 2)]
1802     else:
1803       dilation_ = 1
1804     if has_output_padding:
1805       output_padding_ = output_padding[torch.sub(d, 2)]
1806     else:
1807       output_padding_ = 0
1808     if transposed:
1809       kernel = torch.mul(dilation_, torch.sub(weight[d], 1))
1810       _5 = torch.mul(torch.sub(input[d], 1), stride[torch.sub(d, 2)])
1811       _6 = torch.mul(padding[torch.sub(d, 2)], 2)
1812       _7 = torch.add(torch.sub(_5, _6), kernel)
1813       _8 = torch.add(torch.add(_7, output_padding_), 1)
1814       _9 = torch.append(output_size, _8)
1815     else:
1816       _10 = torch.mul(dilation_, torch.sub(weight[d], 1))
1817       kernel0 = torch.add(_10, 1)
1818       _11 = input[d]
1819       _12 = torch.mul(padding[torch.sub(d, 2)], 2)
1820       _13 = torch.sub(torch.add(_11, _12), kernel0)
1821       _14 = torch.floordiv(_13, stride[torch.sub(d, 2)])
1822       _15 = torch.append(output_size, torch.add(_14, 1))
1823   return output_size
1824 
1825 )=====")
1826 + std::string(R"=====(def _conv_forwards(input: List[int],
1827     weight: List[int],
1828     bias: Optional[List[int]],
1829     stride: List[int],
1830     padding: List[int],
1831     dilation: List[int],
1832     transposed: bool,
1833     output_padding: List[int],
1834     groups: int,
1835     benchmark: bool,
1836     deterministic: bool,
1837     cudnn_enabled: bool,
1838     allow_tf32: bool) -> List[int]:
1839   has_dilation = torch.gt(torch.len(dilation), 0)
1840   has_output_padding = torch.gt(torch.len(output_padding), 0)
1841   dim = torch.len(input)
1842   output_size = annotate(List[int], [])
1843   if transposed:
1844     weight_output_channels_dim = 1
1845   else:
1846     weight_output_channels_dim = 0
1847   _0 = torch.append(output_size, input[0])
1848   if transposed:
1849     _1 = torch.mul(weight[weight_output_channels_dim], groups)
1850     _2 = torch.append(output_size, _1)
1851   else:
1852     _3 = torch.append(output_size, weight[weight_output_channels_dim])
1853   for _4 in range(torch.__range_length(2, dim, 1)):
1854     d = torch.__derive_index(_4, 2, 1)
1855     if has_dilation:
1856       dilation_ = dilation[torch.sub(d, 2)]
1857     else:
1858       dilation_ = 1
1859     if has_output_padding:
1860       output_padding_ = output_padding[torch.sub(d, 2)]
1861     else:
1862       output_padding_ = 0
1863     if transposed:
1864       kernel = torch.mul(dilation_, torch.sub(weight[d], 1))
1865       _5 = torch.mul(torch.sub(input[d], 1), stride[torch.sub(d, 2)])
1866       _6 = torch.mul(padding[torch.sub(d, 2)], 2)
1867       _7 = torch.add(torch.sub(_5, _6), kernel)
1868       _8 = torch.add(torch.add(_7, output_padding_), 1)
1869       _9 = torch.append(output_size, _8)
1870     else:
1871       _10 = torch.mul(dilation_, torch.sub(weight[d], 1))
1872       kernel0 = torch.add(_10, 1)
1873       _11 = input[d]
1874       _12 = torch.mul(padding[torch.sub(d, 2)], 2)
1875       _13 = torch.sub(torch.add(_11, _12), kernel0)
1876       _14 = torch.floordiv(_13, stride[torch.sub(d, 2)])
1877       _15 = torch.append(output_size, torch.add(_14, 1))
1878   return output_size
1879 
1880 )=====")
1881 + std::string(R"=====(def conv_transpose2d_input(input: List[int],
1882     weight: List[int],
1883     bias: Optional[List[int]]=None,
1884     stride: Optional[List[int]]=None,
1885     padding: Optional[List[int]]=None,
1886     output_padding: Optional[List[int]]=None,
1887     groups: int=1,
1888     dilation: Optional[List[int]]=None) -> List[int]:
1889   if torch.__is__(stride, None):
1890     stride0 = [1, 1]
1891   else:
1892     stride0 = unchecked_cast(List[int], stride)
1893   if torch.__is__(padding, None):
1894     padding0 = [0, 0]
1895   else:
1896     padding0 = unchecked_cast(List[int], padding)
1897   if torch.__is__(output_padding, None):
1898     output_padding0 = [0, 0]
1899   else:
1900     output_padding1 = unchecked_cast(List[int], output_padding)
1901     output_padding0 = output_padding1
1902   if torch.__is__(dilation, None):
1903     dilation0 = [1, 1]
1904   else:
1905     dilation0 = unchecked_cast(List[int], dilation)
1906   has_dilation = torch.gt(torch.len(dilation0), 0)
1907   dim = torch.len(input)
1908   output_size = annotate(List[int], [])
1909   _0 = torch.append(output_size, input[0])
1910   _1 = torch.append(output_size, torch.mul(weight[1], groups))
1911   for _2 in range(torch.__range_length(2, dim, 1)):
1912     d = torch.__derive_index(_2, 2, 1)
1913     if has_dilation:
1914       dilation_ = dilation0[torch.sub(d, 2)]
1915     else:
1916       dilation_ = 1
1917     kernel = torch.mul(dilation_, torch.sub(weight[d], 1))
1918     _3 = torch.mul(torch.sub(input[d], 1), stride0[torch.sub(d, 2)])
1919     _4 = torch.mul(padding0[torch.sub(d, 2)], 2)
1920     _5 = torch.add(torch.sub(_3, _4), kernel)
1921     _6 = torch.add(_5, output_padding0[torch.sub(d, 2)])
1922     _7 = torch.append(output_size, torch.add(_6, 1))
1923   return output_size
1924 
1925 )=====")
1926 + std::string(R"=====(def flatten(input: List[int],
1927     start_dim: int,
1928     end_dim: int) -> List[int]:
1929   _0 = torch.len(input)
1930   if torch.le(_0, 0):
1931     dim_post_expr = 1
1932   else:
1933     dim_post_expr = _0
1934   min = torch.neg(dim_post_expr)
1935   max = torch.sub(dim_post_expr, 1)
1936   if torch.lt(start_dim, min):
1937     _1 = True
1938   else:
1939     _1 = torch.gt(start_dim, max)
1940   if torch.__not__(_1):
1941     pass
1942   else:
1943     ops.prim.RaiseException("AssertionError: ")
1944   if torch.lt(start_dim, 0):
1945     start_dim0 = torch.add(start_dim, dim_post_expr)
1946   else:
1947     start_dim0 = start_dim
1948   _2 = torch.len(input)
1949   if torch.le(_2, 0):
1950     dim_post_expr0 = 1
1951   else:
1952     dim_post_expr0 = _2
1953   min0 = torch.neg(dim_post_expr0)
1954   max0 = torch.sub(dim_post_expr0, 1)
1955   if torch.lt(end_dim, min0):
1956     _3 = True
1957   else:
1958     _3 = torch.gt(end_dim, max0)
1959   if torch.__not__(_3):
1960     pass
1961   else:
1962     ops.prim.RaiseException("AssertionError: ")
1963   if torch.lt(end_dim, 0):
1964     end_dim0 = torch.add(end_dim, dim_post_expr0)
1965   else:
1966     end_dim0 = end_dim
1967   if torch.le(start_dim0, end_dim0):
1968     pass
1969   else:
1970     ops.prim.RaiseException("AssertionError: ")
1971   if torch.eq(torch.len(input), 0):
1972     _4 = [1]
1973   else:
1974     if torch.eq(start_dim0, end_dim0):
1975       out = annotate(List[int], [])
1976       for _6 in range(torch.len(input)):
1977         elem = input[_6]
1978         _7 = torch.append(out, elem)
1979       _5 = out
1980     else:
1981       _8 = torch.__range_length(start_dim0, torch.add(end_dim0, 1), 1)
1982       slice_numel = 1
1983       for _9 in range(_8):
1984         i = torch.__derive_index(_9, start_dim0, 1)
1985         slice_numel0 = torch.mul(slice_numel, input[i])
1986         slice_numel = slice_numel0
1987       shape = annotate(List[int], [])
1988       for i0 in range(start_dim0):
1989         _10 = torch.append(shape, input[i0])
1990       _11 = torch.append(shape, slice_numel)
1991       _12 = torch.add(end_dim0, 1)
1992       _13 = torch.__range_length(_12, torch.len(input), 1)
1993       for _14 in range(_13):
1994         i1 = torch.__derive_index(_14, _12, 1)
1995         _15 = torch.append(shape, input[i1])
1996       _5 = shape
1997     _4 = _5
1998   return _4
1999 
2000 )=====")
2001 + std::string(R"=====(def cat(tensors: List[List[int]],
2002     dim: int) -> List[int]:
2003   _0 = "AssertionError: Tensors must have same number of dimensions"
2004   _1 = "AssertionError: Sizes of tensors must match except in dimension"
2005   for _2 in range(torch.len(tensors)):
2006     tensor = tensors[_2]
2007     if torch.gt(torch.len(tensor), 0):
2008       pass
2009     else:
2010       ops.prim.RaiseException("AssertionError: ")
2011   out_dim: Optional[int] = None
2012   for _3 in range(torch.len(tensors)):
2013     size = tensors[_3]
2014     if torch.eq(torch.len(size), 1):
2015       _4 = torch.eq(size[0], 0)
2016     else:
2017       _4 = False
2018     if torch.__not__(_4):
2019       if torch.__is__(out_dim, None):
2020         _5 = torch.len(size)
2021         if torch.le(_5, 0):
2022           dim_post_expr = 1
2023         else:
2024           dim_post_expr = _5
2025         min = torch.neg(dim_post_expr)
2026         max = torch.sub(dim_post_expr, 1)
2027         if torch.lt(dim, min):
2028           _6 = True
2029         else:
2030           _6 = torch.gt(dim, max)
2031         if torch.__not__(_6):
2032           pass
2033         else:
2034           ops.prim.RaiseException("AssertionError: ")
2035         if torch.lt(dim, 0):
2036           out_dim2 = torch.add(dim, dim_post_expr)
2037         else:
2038           out_dim2 = dim
2039         out_dim1 = out_dim2
2040       else:
2041         out_dim1 = unchecked_cast(int, out_dim)
2042       out_dim0 : Optional[int] = out_dim1
2043     else:
2044       out_dim0 = out_dim
2045     out_dim = out_dim0
2046   if torch.__is__(out_dim, None):
2047     dim0 = dim
2048   else:
2049     dim0 = unchecked_cast(int, out_dim)
2050   if torch.gt(torch.len(tensors), 0):
2051     pass
2052   else:
2053     ops.prim.RaiseException("AssertionError: ")
2054   not_skipped_tensor: Optional[List[int]] = None
2055   for _7 in range(torch.len(tensors)):
2056     tensor0 = tensors[_7]
2057     numel = 1
2058     for _8 in range(torch.len(tensor0)):
2059       elem = tensor0[_8]
2060       numel = torch.mul(numel, elem)
2061     if torch.eq(numel, 0):
2062       _9 = torch.eq(torch.len(tensor0), 1)
2063     else:
2064       _9 = False
2065     if torch.__not__(_9):
2066       not_skipped_tensor0 : Optional[List[int]] = tensor0
2067     else:
2068       not_skipped_tensor0 = not_skipped_tensor
2069     not_skipped_tensor = not_skipped_tensor0
2070   _10 = torch.__is__(not_skipped_tensor, None)
2071   if _10:
2072     _11 = [0]
2073   else:
2074     not_skipped_tensor1 = unchecked_cast(List[int], not_skipped_tensor)
2075     cat_dim_size = 0
2076     for i in range(torch.len(tensors)):
2077       tensor1 = tensors[i]
2078       numel0 = 1
2079       for _12 in range(torch.len(tensor1)):
2080         elem0 = tensor1[_12]
2081         numel0 = torch.mul(numel0, elem0)
2082       if torch.eq(numel0, 0):
2083         _13 = torch.eq(torch.len(tensor1), 1)
2084       else:
2085         _13 = False
2086       if torch.__not__(_13):
2087         first_dims = torch.len(not_skipped_tensor1)
2088         second_dims = torch.len(tensor1)
2089         _14 = torch.eq(first_dims, second_dims)
2090         if _14:
2091           pass
2092         else:
2093           ops.prim.RaiseException(_0)
2094         _15 = torch.__range_length(0, first_dims, 1)
2095         for _16 in range(_15):
2096           dim1 = torch.__derive_index(_16, 0, 1)
2097           if torch.ne(dim1, dim0):
2098             _17 = torch.eq(not_skipped_tensor1[dim1], tensor1[dim1])
2099             if _17:
2100               pass
2101             else:
2102               ops.prim.RaiseException(_1)
2103           else:
2104             pass
2105         cat_dim_size1 = torch.add(cat_dim_size, tensor1[dim0])
2106         cat_dim_size0 = cat_dim_size1
2107       else:
2108         cat_dim_size0 = cat_dim_size
2109       cat_dim_size = cat_dim_size0
2110     result_size = annotate(List[int], [])
2111     for _18 in range(torch.len(not_skipped_tensor1)):
2112       elem1 = not_skipped_tensor1[_18]
2113       _19 = torch.append(result_size, elem1)
2114     _20 = torch._set_item(result_size, dim0, cat_dim_size)
2115     _11 = result_size
2116   return _11
2117 
2118 )=====")
2119 + std::string(R"=====(def stack(tensors: List[List[int]],
2120     dim: int) -> List[int]:
2121   _0 = "AssertionError: Tensors must have same number of dimensions"
2122   _1 = "AssertionError: Sizes of tensors must match except in dimension"
2123   unsqueezed_tensors = annotate(List[List[int]], [])
2124   for _2 in range(torch.len(tensors)):
2125     tensor = tensors[_2]
2126     _3 = torch.add(torch.len(tensor), 1)
2127     if torch.le(_3, 0):
2128       dim_post_expr = 1
2129     else:
2130       dim_post_expr = _3
2131     min = torch.neg(dim_post_expr)
2132     max = torch.sub(dim_post_expr, 1)
2133     if torch.lt(dim, min):
2134       _4 = True
2135     else:
2136       _4 = torch.gt(dim, max)
2137     if torch.__not__(_4):
2138       pass
2139     else:
2140       ops.prim.RaiseException("AssertionError: ")
2141     if torch.lt(dim, 0):
2142       dim0 = torch.add(dim, dim_post_expr)
2143     else:
2144       dim0 = dim
2145     unsqueezed = annotate(List[int], [])
2146     for _5 in range(torch.len(tensor)):
2147       elem = tensor[_5]
2148       _6 = torch.append(unsqueezed, elem)
2149     torch.insert(unsqueezed, dim0, 1)
2150     _7 = torch.append(unsqueezed_tensors, unsqueezed)
2151   for _8 in range(torch.len(unsqueezed_tensors)):
2152     tensor0 = unsqueezed_tensors[_8]
2153     if torch.gt(torch.len(tensor0), 0):
2154       pass
2155     else:
2156       ops.prim.RaiseException("AssertionError: ")
2157   out_dim: Optional[int] = None
2158   for _9 in range(torch.len(unsqueezed_tensors)):
2159     size = unsqueezed_tensors[_9]
2160     if torch.eq(torch.len(size), 1):
2161       _10 = torch.eq(size[0], 0)
2162     else:
2163       _10 = False
2164     if torch.__not__(_10):
2165       if torch.__is__(out_dim, None):
2166         _11 = torch.len(size)
2167         if torch.le(_11, 0):
2168           dim_post_expr0 = 1
2169         else:
2170           dim_post_expr0 = _11
2171         min0 = torch.neg(dim_post_expr0)
2172         max0 = torch.sub(dim_post_expr0, 1)
2173         if torch.lt(dim, min0):
2174           _12 = True
2175         else:
2176           _12 = torch.gt(dim, max0)
2177         if torch.__not__(_12):
2178           pass
2179         else:
2180           ops.prim.RaiseException("AssertionError: ")
2181         if torch.lt(dim, 0):
2182           dim1 = torch.add(dim, dim_post_expr0)
2183           out_dim2 = dim1
2184         else:
2185           out_dim2 = dim
2186         out_dim1 = out_dim2
2187       else:
2188         out_dim1 = unchecked_cast(int, out_dim)
2189       out_dim0 : Optional[int] = out_dim1
2190     else:
2191       out_dim0 = out_dim
2192     out_dim = out_dim0
2193   if torch.__is__(out_dim, None):
2194     dim2 = dim
2195   else:
2196     dim2 = unchecked_cast(int, out_dim)
2197   _13 = torch.gt(torch.len(unsqueezed_tensors), 0)
2198   if _13:
2199     pass
2200   else:
2201     ops.prim.RaiseException("AssertionError: ")
2202   not_skipped_tensor: Optional[List[int]] = None
2203   for _14 in range(torch.len(unsqueezed_tensors)):
2204     tensor1 = unsqueezed_tensors[_14]
2205     numel = 1
2206     for _15 in range(torch.len(tensor1)):
2207       elem0 = tensor1[_15]
2208       numel = torch.mul(numel, elem0)
2209     if torch.eq(numel, 0):
2210       _16 = torch.eq(torch.len(tensor1), 1)
2211     else:
2212       _16 = False
2213     if torch.__not__(_16):
2214       not_skipped_tensor0 : Optional[List[int]] = tensor1
2215     else:
2216       not_skipped_tensor0 = not_skipped_tensor
2217     not_skipped_tensor = not_skipped_tensor0
2218   _17 = torch.__is__(not_skipped_tensor, None)
2219   if _17:
2220     _18 = [0]
2221   else:
2222     not_skipped_tensor1 = unchecked_cast(List[int], not_skipped_tensor)
2223     cat_dim_size = 0
2224     for i in range(torch.len(unsqueezed_tensors)):
2225       tensor2 = unsqueezed_tensors[i]
2226       numel0 = 1
2227       for _19 in range(torch.len(tensor2)):
2228         elem1 = tensor2[_19]
2229         numel0 = torch.mul(numel0, elem1)
2230       if torch.eq(numel0, 0):
2231         _20 = torch.eq(torch.len(tensor2), 1)
2232       else:
2233         _20 = False
2234       if torch.__not__(_20):
2235         first_dims = torch.len(not_skipped_tensor1)
2236         second_dims = torch.len(tensor2)
2237         _21 = torch.eq(first_dims, second_dims)
2238         if _21:
2239           pass
2240         else:
2241           ops.prim.RaiseException(_0)
2242         _22 = torch.__range_length(0, first_dims, 1)
2243         for _23 in range(_22):
2244           dim3 = torch.__derive_index(_23, 0, 1)
2245           if torch.ne(dim3, dim2):
2246             _24 = torch.eq(not_skipped_tensor1[dim3], tensor2[dim3])
2247             if _24:
2248               pass
2249             else:
2250               ops.prim.RaiseException(_1)
2251           else:
2252             pass
2253         cat_dim_size1 = torch.add(cat_dim_size, tensor2[dim2])
2254         cat_dim_size0 = cat_dim_size1
2255       else:
2256         cat_dim_size0 = cat_dim_size
2257       cat_dim_size = cat_dim_size0
2258     result_size = annotate(List[int], [])
2259     for _25 in range(torch.len(not_skipped_tensor1)):
2260       elem2 = not_skipped_tensor1[_25]
2261       _26 = torch.append(result_size, elem2)
2262     _27 = torch._set_item(result_size, dim2, cat_dim_size)
2263     _18 = result_size
2264   return _18
2265 
2266 )=====")
2267 + std::string(R"=====(def permute(input: List[int],
2268     dims: List[int]) -> List[int]:
2269   _0 = torch.eq(torch.len(input), torch.len(dims))
2270   if _0:
2271     pass
2272   else:
2273     ops.prim.RaiseException("AssertionError: ")
2274   ndim = torch.len(dims)
2275   seen_dims = annotate(List[int], [])
2276   newSizes = annotate(List[int], [])
2277   for i in range(ndim):
2278     _1 = dims[i]
2279     if torch.le(ndim, 0):
2280       dim_post_expr = 1
2281     else:
2282       dim_post_expr = ndim
2283     min = torch.neg(dim_post_expr)
2284     max = torch.sub(dim_post_expr, 1)
2285     if torch.lt(_1, min):
2286       _2 = True
2287     else:
2288       _2 = torch.gt(_1, max)
2289     if torch.__not__(_2):
2290       pass
2291     else:
2292       ops.prim.RaiseException("AssertionError: ")
2293     if torch.lt(_1, 0):
2294       dim = torch.add(_1, dim_post_expr)
2295     else:
2296       dim = _1
2297     _3 = torch.append(seen_dims, dim)
2298     _4 = torch.append(newSizes, input[dim])
2299   for _5 in range(torch.__range_length(1, ndim, 1)):
2300     i0 = torch.__derive_index(_5, 1, 1)
2301     for j in range(i0):
2302       _6 = torch.ne(seen_dims[i0], seen_dims[j])
2303       if _6:
2304         pass
2305       else:
2306         ops.prim.RaiseException("AssertionError: ")
2307   return newSizes
2308 
2309 )=====")
2310 + std::string(R"=====(def movedim(self: List[int],
2311     source: List[int],
2312     destination: List[int]) -> List[int]:
2313   self_dim = torch.len(self)
2314   if torch.le(self_dim, 1):
2315     _0 = self
2316   else:
2317     normalized_src = annotate(List[int], [])
2318     normalized_dst = annotate(List[int], [])
2319     for i in range(torch.len(source)):
2320       _1 = source[i]
2321       if torch.le(self_dim, 0):
2322         dim_post_expr = 1
2323       else:
2324         dim_post_expr = self_dim
2325       min = torch.neg(dim_post_expr)
2326       max = torch.sub(dim_post_expr, 1)
2327       if torch.lt(_1, min):
2328         _2 = True
2329       else:
2330         _2 = torch.gt(_1, max)
2331       if torch.__not__(_2):
2332         pass
2333       else:
2334         ops.prim.RaiseException("AssertionError: ")
2335       if torch.lt(_1, 0):
2336         dim = torch.add(_1, dim_post_expr)
2337       else:
2338         dim = _1
2339       _3 = torch.append(normalized_src, dim)
2340       _4 = destination[i]
2341       if torch.le(self_dim, 0):
2342         dim_post_expr0 = 1
2343       else:
2344         dim_post_expr0 = self_dim
2345       min0 = torch.neg(dim_post_expr0)
2346       max0 = torch.sub(dim_post_expr0, 1)
2347       if torch.lt(_4, min0):
2348         _5 = True
2349       else:
2350         _5 = torch.gt(_4, max0)
2351       if torch.__not__(_5):
2352         pass
2353       else:
2354         ops.prim.RaiseException("AssertionError: ")
2355       if torch.lt(_4, 0):
2356         dim0 = torch.add(_4, dim_post_expr0)
2357       else:
2358         dim0 = _4
2359       _6 = torch.append(normalized_dst, dim0)
2360     order = annotate(List[int], [])
2361     for i0 in range(self_dim):
2362       _7 = torch.append(order, -1)
2363     src_dims = annotate(List[int], [])
2364     for i1 in range(self_dim):
2365       _8 = torch.append(src_dims, i1)
2366     dst_dims = annotate(List[int], [])
2367     for i2 in range(self_dim):
2368       _9 = torch.append(dst_dims, i2)
2369     for i3 in range(torch.len(source)):
2370       _10 = normalized_src[i3]
2371       _11 = torch._set_item(order, normalized_dst[i3], _10)
2372       _12 = torch._set_item(src_dims, normalized_src[i3], -1)
2373       _13 = torch._set_item(dst_dims, normalized_dst[i3], -1)
2374     source_dims = annotate(List[int], [])
2375     destination_dims = annotate(List[int], [])
2376     for _14 in range(torch.len(src_dims)):
2377       ele = src_dims[_14]
2378       if torch.ne(ele, -1):
2379         _15 = torch.append(source_dims, ele)
2380       else:
2381         pass
2382     for _16 in range(torch.len(dst_dims)):
2383       ele0 = dst_dims[_16]
2384       if torch.ne(ele0, -1):
2385         _17 = torch.append(destination_dims, ele0)
2386       else:
2387         pass
2388     rest_dim = torch.sub(self_dim, torch.len(source))
2389     for i4 in range(rest_dim):
2390       _18 = source_dims[i4]
2391       _19 = torch._set_item(order, destination_dims[i4], _18)
2392     _20 = torch.eq(torch.len(self), torch.len(order))
2393     if _20:
2394       pass
2395     else:
2396       ops.prim.RaiseException("AssertionError: ")
2397     ndim = torch.len(order)
2398     seen_dims = annotate(List[int], [])
2399     newSizes = annotate(List[int], [])
2400     for i5 in range(ndim):
2401       _21 = order[i5]
2402       if torch.le(ndim, 0):
2403         dim_post_expr1 = 1
2404       else:
2405         dim_post_expr1 = ndim
2406       min1 = torch.neg(dim_post_expr1)
2407       max1 = torch.sub(dim_post_expr1, 1)
2408       if torch.lt(_21, min1):
2409         _22 = True
2410       else:
2411         _22 = torch.gt(_21, max1)
2412       if torch.__not__(_22):
2413         pass
2414       else:
2415         ops.prim.RaiseException("AssertionError: ")
2416       if torch.lt(_21, 0):
2417         dim1 = torch.add(_21, dim_post_expr1)
2418       else:
2419         dim1 = _21
2420       _23 = torch.append(seen_dims, dim1)
2421       _24 = torch.append(newSizes, self[dim1])
2422     for _25 in range(torch.__range_length(1, ndim, 1)):
2423       i6 = torch.__derive_index(_25, 1, 1)
2424       for j in range(i6):
2425         _26 = torch.ne(seen_dims[i6], seen_dims[j])
2426         if _26:
2427           pass
2428         else:
2429           ops.prim.RaiseException("AssertionError: ")
2430     _0 = newSizes
2431   return _0
2432 
2433 )=====")
2434 + std::string(R"=====(def view(self: List[int],
2435     sizes: List[int]) -> List[int]:
2436   _0 = "AssertionError: only one dimension can be inferred"
2437   _1 = "AssertionError: invalid shape dimensions"
2438   numel = 1
2439   for _2 in range(torch.len(self)):
2440     elem = self[_2]
2441     numel = torch.mul(numel, elem)
2442   _3 = uninitialized(int)
2443   newsize = 1
2444   infer_dim: Optional[int] = None
2445   for dim in range(torch.len(sizes)):
2446     if torch.eq(sizes[dim], -1):
2447       if torch.__isnot__(infer_dim, None):
2448         ops.prim.RaiseException(_0)
2449       else:
2450         pass
2451       newsize0, infer_dim0 = newsize, dim
2452     else:
2453       if torch.ge(sizes[dim], 0):
2454         newsize1 = torch.mul(newsize, sizes[dim])
2455       else:
2456         ops.prim.RaiseException(_1)
2457         newsize1 = _3
2458       newsize0, infer_dim0 = newsize1, infer_dim
2459     newsize, infer_dim = newsize0, infer_dim0
2460   if torch.eq(numel, newsize):
2461     _4, infer_dim1 = True, infer_dim
2462   else:
2463     if torch.__isnot__(infer_dim, None):
2464       infer_dim3 = unchecked_cast(int, infer_dim)
2465       _5, infer_dim2 = torch.gt(newsize, 0), infer_dim3
2466     else:
2467       _5, infer_dim2 = False, infer_dim
2468     if _5:
2469       infer_dim5 = unchecked_cast(int, infer_dim2)
2470       _7 = torch.eq(torch.remainder(numel, newsize), 0)
2471       _6, infer_dim4 = _7, infer_dim5
2472     else:
2473       _6, infer_dim4 = False, infer_dim2
2474     _4, infer_dim1 = _6, infer_dim4
2475   if torch.__not__(_4):
2476     ops.prim.RaiseException("AssertionError: invalid shape")
2477   else:
2478     pass
2479   out = annotate(List[int], [])
2480   for _8 in range(torch.len(sizes)):
2481     elem0 = sizes[_8]
2482     _9 = torch.append(out, elem0)
2483   if torch.__isnot__(infer_dim1, None):
2484     infer_dim6 = unchecked_cast(int, infer_dim1)
2485     _10 = torch._set_item(out, infer_dim6, torch.floordiv(numel, newsize))
2486   else:
2487     pass
2488   return out
2489 
2490 )=====")
2491 + std::string(R"=====(def expand(self: List[int],
2492     sizes: List[int]) -> List[int]:
2493   _0 = torch.ge(torch.len(sizes), torch.len(self))
2494   if _0:
2495     pass
2496   else:
2497     ops.prim.RaiseException("AssertionError: ")
2498   ndim = torch.len(sizes)
2499   tensor_dim = torch.len(self)
2500   if torch.eq(ndim, 0):
2501     out = annotate(List[int], [])
2502     for _2 in range(torch.len(sizes)):
2503       elem = sizes[_2]
2504       _3 = torch.append(out, elem)
2505     _1 = out
2506   else:
2507     out0 = annotate(List[int], [])
2508     for i in range(ndim):
2509       offset = torch.sub(torch.sub(ndim, 1), i)
2510       dim = torch.sub(torch.sub(tensor_dim, 1), offset)
2511       if torch.ge(dim, 0):
2512         size = self[dim]
2513       else:
2514         size = 1
2515       targetSize = sizes[i]
2516       if torch.eq(targetSize, -1):
2517         if torch.ge(dim, 0):
2518           pass
2519         else:
2520           ops.prim.RaiseException("AssertionError: ")
2521         targetSize0 = size
2522       else:
2523         targetSize0 = targetSize
2524       if torch.ne(size, targetSize0):
2525         if torch.eq(size, 1):
2526           pass
2527         else:
2528           ops.prim.RaiseException("AssertionError: ")
2529         size0 = targetSize0
2530       else:
2531         size0 = size
2532       _4 = torch.append(out0, size0)
2533     _1 = out0
2534   return _1
2535 
2536 )=====")
2537 + std::string(R"=====(def expand_one_unused(self: List[int],
2538     sizes: List[int],
2539     inp0: Any) -> List[int]:
2540   _0 = torch.ge(torch.len(sizes), torch.len(self))
2541   if _0:
2542     pass
2543   else:
2544     ops.prim.RaiseException("AssertionError: ")
2545   ndim = torch.len(sizes)
2546   tensor_dim = torch.len(self)
2547   if torch.eq(ndim, 0):
2548     out = annotate(List[int], [])
2549     for _2 in range(torch.len(sizes)):
2550       elem = sizes[_2]
2551       _3 = torch.append(out, elem)
2552     _1 = out
2553   else:
2554     out0 = annotate(List[int], [])
2555     for i in range(ndim):
2556       offset = torch.sub(torch.sub(ndim, 1), i)
2557       dim = torch.sub(torch.sub(tensor_dim, 1), offset)
2558       if torch.ge(dim, 0):
2559         size = self[dim]
2560       else:
2561         size = 1
2562       targetSize = sizes[i]
2563       if torch.eq(targetSize, -1):
2564         if torch.ge(dim, 0):
2565           pass
2566         else:
2567           ops.prim.RaiseException("AssertionError: ")
2568         targetSize0 = size
2569       else:
2570         targetSize0 = targetSize
2571       if torch.ne(size, targetSize0):
2572         if torch.eq(size, 1):
2573           pass
2574         else:
2575           ops.prim.RaiseException("AssertionError: ")
2576         size0 = targetSize0
2577       else:
2578         size0 = size
2579       _4 = torch.append(out0, size0)
2580     _1 = out0
2581   return _1
2582 
2583 )=====")
2584 + std::string(R"=====(def sum_mean_dim(self: List[int],
2585     opt_dims: Optional[List[int]],
2586     keep_dim: bool,
2587     dt: Any) -> List[int]:
2588   out = annotate(List[int], [])
2589   if torch.__is__(opt_dims, None):
2590     _0, opt_dims0 = True, opt_dims
2591   else:
2592     opt_dims1 = unchecked_cast(List[int], opt_dims)
2593     _0, opt_dims0 = torch.eq(torch.len(opt_dims1), 0), opt_dims1
2594   if _0:
2595     _1 = torch.len(self)
2596     dims0 = annotate(List[int], [])
2597     for _2 in range(_1):
2598       _3 = torch.append(dims0, _2)
2599     dims = dims0
2600   else:
2601     opt_dims2 = unchecked_cast(List[int], opt_dims0)
2602     dims = opt_dims2
2603   for idx in range(torch.len(self)):
2604     is_mean_dim = False
2605     for _4 in range(torch.len(dims)):
2606       reduce_dim = dims[_4]
2607       _5 = torch.len(self)
2608       if torch.le(_5, 0):
2609         dim_post_expr = 1
2610       else:
2611         dim_post_expr = _5
2612       min = torch.neg(dim_post_expr)
2613       max = torch.sub(dim_post_expr, 1)
2614       if torch.lt(reduce_dim, min):
2615         _6 = True
2616       else:
2617         _6 = torch.gt(reduce_dim, max)
2618       if torch.__not__(_6):
2619         pass
2620       else:
2621         ops.prim.RaiseException("AssertionError: ")
2622       if torch.lt(reduce_dim, 0):
2623         dim0 = torch.add(reduce_dim, dim_post_expr)
2624         dim = dim0
2625       else:
2626         dim = reduce_dim
2627       if torch.eq(idx, dim):
2628         is_mean_dim0 = True
2629       else:
2630         is_mean_dim0 = is_mean_dim
2631       is_mean_dim = is_mean_dim0
2632     if is_mean_dim:
2633       if keep_dim:
2634         _7 = torch.append(out, 1)
2635       else:
2636         pass
2637     else:
2638       _8 = torch.append(out, self[idx])
2639   return out
2640 
2641 )=====")
2642 + std::string(R"=====(def max_dim(self: List[int],
2643     dim: int,
2644     keep_dim: bool) -> Tuple[List[int], List[int]]:
2645   _0 = [dim]
2646   out = annotate(List[int], [])
2647   for idx in range(torch.len(self)):
2648     is_mean_dim = False
2649     for _1 in range(torch.len(_0)):
2650       reduce_dim = _0[_1]
2651       _2 = torch.len(self)
2652       if torch.le(_2, 0):
2653         dim_post_expr = 1
2654       else:
2655         dim_post_expr = _2
2656       min = torch.neg(dim_post_expr)
2657       max = torch.sub(dim_post_expr, 1)
2658       if torch.lt(reduce_dim, min):
2659         _3 = True
2660       else:
2661         _3 = torch.gt(reduce_dim, max)
2662       if torch.__not__(_3):
2663         pass
2664       else:
2665         ops.prim.RaiseException("AssertionError: ")
2666       if torch.lt(reduce_dim, 0):
2667         dim1 = torch.add(reduce_dim, dim_post_expr)
2668         dim0 = dim1
2669       else:
2670         dim0 = reduce_dim
2671       if torch.eq(idx, dim0):
2672         is_mean_dim0 = True
2673       else:
2674         is_mean_dim0 = is_mean_dim
2675       is_mean_dim = is_mean_dim0
2676     if is_mean_dim:
2677       if keep_dim:
2678         _4 = torch.append(out, 1)
2679       else:
2680         pass
2681     else:
2682       _5 = torch.append(out, self[idx])
2683   return (out, out)
2684 
2685 )=====")
2686 + std::string(R"=====(def addmm(self: List[int],
2687     mat1: List[int],
2688     mat2: List[int],
2689     beta: Any,
2690     alpha: Any) -> List[int]:
2691   _0 = "AssertionError: self must be a matrix"
2692   _1 = "AssertionError: mat2 must be a matrix"
2693   _2 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
2694   if torch.eq(torch.len(mat1), 2):
2695     pass
2696   else:
2697     ops.prim.RaiseException(_0)
2698   if torch.eq(torch.len(mat2), 2):
2699     pass
2700   else:
2701     ops.prim.RaiseException(_1)
2702   if torch.eq(mat1[1], mat2[0]):
2703     pass
2704   else:
2705     ops.prim.RaiseException("AssertionError: ")
2706   _3 = [mat1[0], mat2[1]]
2707   dimsA = torch.len(self)
2708   ndim = ops.prim.max(dimsA, 2)
2709   expandedSizes = annotate(List[int], [])
2710   for i in range(ndim):
2711     offset = torch.sub(torch.sub(ndim, 1), i)
2712     dimA = torch.sub(torch.sub(dimsA, 1), offset)
2713     dimB = torch.sub(1, offset)
2714     if torch.ge(dimA, 0):
2715       sizeA = self[dimA]
2716     else:
2717       sizeA = 1
2718     if torch.ge(dimB, 0):
2719       sizeB = _3[dimB]
2720     else:
2721       sizeB = 1
2722     if torch.ne(sizeA, sizeB):
2723       _4 = torch.ne(sizeA, 1)
2724     else:
2725       _4 = False
2726     if _4:
2727       _5 = torch.ne(sizeB, 1)
2728     else:
2729       _5 = False
2730     if _5:
2731       _6 = torch.add("AssertionError: ", torch.format(_2, sizeA, sizeB, i))
2732       ops.prim.RaiseException(_6)
2733     else:
2734       pass
2735     if torch.eq(sizeA, 1):
2736       _7 = sizeB
2737     else:
2738       _7 = sizeA
2739     _8 = torch.append(expandedSizes, _7)
2740   return expandedSizes
2741 
2742 )=====")
2743 + std::string(R"=====(def upsample_nearest2d(input: List[int],
2744     output_size: Optional[List[int]],
2745     scale_factors: Optional[List[float]]) -> List[int]:
2746   _0 = "AssertionError: Either output_size or scale_factors must be presented"
2747   _1 = "AssertionError: Must specify exactly one of output_size and scale_factors"
2748   _2 = uninitialized(Optional[List[float]])
2749   out = annotate(List[int], [])
2750   _3 = torch.append(out, input[0])
2751   _4 = torch.append(out, input[1])
2752   if torch.__is__(scale_factors, None):
2753     _5, scale_factors0 = torch.__is__(output_size, None), scale_factors
2754   else:
2755     scale_factors1 = unchecked_cast(List[float], scale_factors)
2756     _5, scale_factors0 = False, scale_factors1
2757   if _5:
2758     ops.prim.RaiseException(_0)
2759   else:
2760     pass
2761   if torch.__isnot__(output_size, None):
2762     output_size1 = unchecked_cast(List[int], output_size)
2763     if torch.__is__(scale_factors0, None):
2764       scale_factors3 : Optional[List[float]] = scale_factors0
2765     else:
2766       ops.prim.RaiseException(_1)
2767       scale_factors3 = _2
2768     _6 = torch.eq(torch.len(output_size1), 2)
2769     if _6:
2770       pass
2771     else:
2772       ops.prim.RaiseException("AssertionError: ")
2773     _7 = torch.append(out, output_size1[0])
2774     _8 = torch.append(out, output_size1[1])
2775     scale_factors2, output_size0 = scale_factors3, output_size1
2776   else:
2777     scale_factors2, output_size0 = scale_factors0, output_size
2778   if torch.__isnot__(scale_factors2, None):
2779     scale_factors4 = unchecked_cast(List[float], scale_factors2)
2780     if torch.__is__(output_size0, None):
2781       pass
2782     else:
2783       ops.prim.RaiseException(_1)
2784     _9 = torch.eq(torch.len(scale_factors4), 2)
2785     if _9:
2786       pass
2787     else:
2788       ops.prim.RaiseException("AssertionError: ")
2789     _10 = torch.mul(input[2], scale_factors4[0])
2790     _11 = torch.append(out, int(_10))
2791     _12 = torch.mul(input[3], scale_factors4[1])
2792     _13 = torch.append(out, int(_12))
2793   else:
2794     pass
2795   return out
2796 
2797 )=====")
2798 + std::string(R"=====(def broadcast(a: List[int],
2799     b: List[int]) -> List[int]:
2800   _0 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
2801   dimsA = torch.len(a)
2802   dimsB = torch.len(b)
2803   ndim = ops.prim.max(dimsA, dimsB)
2804   expandedSizes = annotate(List[int], [])
2805   for i in range(ndim):
2806     offset = torch.sub(torch.sub(ndim, 1), i)
2807     dimA = torch.sub(torch.sub(dimsA, 1), offset)
2808     dimB = torch.sub(torch.sub(dimsB, 1), offset)
2809     if torch.ge(dimA, 0):
2810       sizeA = a[dimA]
2811     else:
2812       sizeA = 1
2813     if torch.ge(dimB, 0):
2814       sizeB = b[dimB]
2815     else:
2816       sizeB = 1
2817     if torch.ne(sizeA, sizeB):
2818       _1 = torch.ne(sizeA, 1)
2819     else:
2820       _1 = False
2821     if _1:
2822       _2 = torch.ne(sizeB, 1)
2823     else:
2824       _2 = False
2825     if _2:
2826       _3 = torch.add("AssertionError: ", torch.format(_0, sizeA, sizeB, i))
2827       ops.prim.RaiseException(_3)
2828     else:
2829       pass
2830     if torch.eq(sizeA, 1):
2831       _4 = sizeB
2832     else:
2833       _4 = sizeA
2834     _5 = torch.append(expandedSizes, _4)
2835   return expandedSizes
2836 
2837 )=====")
2838 + std::string(R"=====(def argmax(self: List[int],
2839     dim: Optional[int]=None,
2840     keepdim: bool=False) -> List[int]:
2841   if torch.__is__(dim, None):
2842     _0 = annotate(List[int], [])
2843   else:
2844     dim0 = unchecked_cast(int, dim)
2845     _1 = torch.len(self)
2846     if torch.le(_1, 0):
2847       dim_post_expr = 1
2848     else:
2849       dim_post_expr = _1
2850     min = torch.neg(dim_post_expr)
2851     max = torch.sub(dim_post_expr, 1)
2852     if torch.lt(dim0, min):
2853       _2 = True
2854     else:
2855       _2 = torch.gt(dim0, max)
2856     if torch.__not__(_2):
2857       pass
2858     else:
2859       ops.prim.RaiseException("AssertionError: ")
2860     if torch.lt(dim0, 0):
2861       dim1 = torch.add(dim0, dim_post_expr)
2862     else:
2863       dim1 = dim0
2864     out = annotate(List[int], [])
2865     _3 = [9223372036854775807, torch.len(self)]
2866     for i in range(ops.prim.min(_3)):
2867       self_dim = self[i]
2868       if torch.eq(i, dim1):
2869         if keepdim:
2870           _4 = torch.append(out, 1)
2871         else:
2872           pass
2873       else:
2874         _5 = torch.append(out, self_dim)
2875     _0 = out
2876   return _0
2877 
2878 def bmm(self: List[int],
2879     mat2: List[int]) -> List[int]:
2880   _0 = "AssertionError: bmm only supports 3D tensors"
2881   _1 = "AssertionError: mismatching batch dimension"
2882   _2 = "AssertionError: mismatching contracting dimension"
2883   if torch.eq(torch.len(self), 3):
2884     pass
2885   else:
2886     ops.prim.RaiseException(_0)
2887   if torch.eq(torch.len(mat2), 3):
2888     pass
2889   else:
2890     ops.prim.RaiseException(_0)
2891   if torch.eq(self[0], mat2[0]):
2892     pass
2893   else:
2894     ops.prim.RaiseException(_1)
2895   if torch.eq(self[2], mat2[1]):
2896     pass
2897   else:
2898     ops.prim.RaiseException(_2)
2899   return [self[0], self[1], mat2[2]]
2900 
2901 def _shape_as_tensor(self: List[int]) -> List[int]:
2902   return [torch.len(self)]
2903 
2904 )=====")
2905 + std::string(R"=====(def topk(self: List[int],
2906     k: int,
2907     dim: int=-1) -> Tuple[List[int], List[int]]:
2908   _0 = "k ({}) is too big for dimension {} of size {}"
2909   if torch.eq(torch.len(self), 0):
2910     result = annotate(List[int], [])
2911   else:
2912     if torch.le(k, self[dim]):
2913       pass
2914     else:
2915       _1 = torch.format(_0, k, dim, self[dim])
2916       ops.prim.RaiseException(torch.add("AssertionError: ", _1))
2917     result0 = annotate(List[int], [])
2918     for _2 in range(torch.len(self)):
2919       elem = self[_2]
2920       _3 = torch.append(result0, elem)
2921     _4 = torch._set_item(result0, dim, k)
2922     result = result0
2923   return (result, result)
2924 
2925 def nll_loss_forward(self: List[int],
2926     target: List[int],
2927     weight: Optional[List[int]],
2928     reduction: int) -> Tuple[List[int], List[int]]:
2929   self_dim = torch.len(self)
2930   target_dim = torch.len(target)
2931   if torch.lt(0, self_dim):
2932     _0 = torch.le(self_dim, 2)
2933   else:
2934     _0 = False
2935   if _0:
2936     pass
2937   else:
2938     ops.prim.RaiseException("AssertionError: ")
2939   if torch.le(target_dim, 1):
2940     pass
2941   else:
2942     ops.prim.RaiseException("AssertionError: ")
2943   if torch.eq(self_dim, 1):
2944     no_batch_dim = torch.eq(target_dim, 0)
2945   else:
2946     no_batch_dim = False
2947   if no_batch_dim:
2948     _1 = True
2949   else:
2950     _1 = torch.eq(self[0], target[0])
2951   if _1:
2952     pass
2953   else:
2954     ops.prim.RaiseException("AssertionError: ")
2955   n_classes = self[-1]
2956   if torch.__is__(weight, None):
2957     _2 = True
2958   else:
2959     weight0 = unchecked_cast(List[int], weight)
2960     if torch.eq(torch.len(weight0), 1):
2961       _3 = torch.eq(weight0[0], n_classes)
2962     else:
2963       _3 = False
2964     _2 = _3
2965   if _2:
2966     pass
2967   else:
2968     ops.prim.RaiseException("AssertionError: ")
2969   if torch.eq(reduction, 0):
2970     _4 = torch.eq(self_dim, 2)
2971   else:
2972     _4 = False
2973   if _4:
2974     reduction_shape = [self[0]]
2975   else:
2976     reduction_shape = annotate(List[int], [])
2977   _5 = (reduction_shape, annotate(List[int], []))
2978   return _5
2979 
2980 )=====")
2981 + std::string(R"=====(def native_layer_norm(input: List[int],
2982     normalized_shape: List[int]) -> Tuple[List[int], List[int], List[int]]:
2983   reduction_shape = annotate(List[int], [])
2984   num_unreduced_dimensions = torch.sub(torch.len(input), torch.len(normalized_shape))
2985   if torch.ge(num_unreduced_dimensions, 0):
2986     pass
2987   else:
2988     ops.prim.RaiseException("AssertionError: ")
2989   for i in range(num_unreduced_dimensions):
2990     _0 = torch.append(reduction_shape, input[i])
2991   _1 = torch.__range_length(num_unreduced_dimensions, torch.len(input), 1)
2992   for _2 in range(_1):
2993     _3 = torch.append(reduction_shape, 1)
2994   out = annotate(List[int], [])
2995   for _4 in range(torch.len(input)):
2996     elem = input[_4]
2997     _5 = torch.append(out, elem)
2998   _6 = (out, reduction_shape, reduction_shape)
2999   return _6
3000 
3001 def native_batch_norm(input: List[int],
3002     weight: Optional[List[int]],
3003     bias: Optional[List[int]],
3004     running_mean: Optional[List[int]],
3005     running_var: Optional[List[int]],
3006     training: bool) -> Tuple[List[int], List[int], List[int]]:
3007   if training:
3008     _size = [input[1]]
3009   else:
3010     _size = [0]
3011   out = annotate(List[int], [])
3012   for _0 in range(torch.len(input)):
3013     elem = input[_0]
3014     _1 = torch.append(out, elem)
3015   return (out, _size, _size)
3016 
3017 def _batch_norm_with_update(input: List[int],
3018     weight: Optional[List[int]],
3019     bias: Optional[List[int]],
3020     running_mean: Optional[List[int]],
3021     running_var: Optional[List[int]]) -> Tuple[List[int], List[int], List[int], List[int]]:
3022   _size = [input[1]]
3023   out = annotate(List[int], [])
3024   for _0 in range(torch.len(input)):
3025     elem = input[_0]
3026     _1 = torch.append(out, elem)
3027   return (out, _size, _size, [0])
3028 
3029 )=====")
3030 + std::string(R"=====(def cross_entropy_loss(self: List[int],
3031     target: List[int],
3032     weight: Optional[List[int]]=None,
3033     reduction: int=1,
3034     ignore_index: int=-100,
3035     label_smoothing: float=0.) -> List[int]:
3036   self_dim = torch.len(self)
3037   target_dim = torch.len(target)
3038   if torch.lt(0, self_dim):
3039     _0 = torch.le(self_dim, 2)
3040   else:
3041     _0 = False
3042   if _0:
3043     pass
3044   else:
3045     ops.prim.RaiseException("AssertionError: ")
3046   if torch.le(target_dim, 1):
3047     pass
3048   else:
3049     ops.prim.RaiseException("AssertionError: ")
3050   if torch.eq(self_dim, 1):
3051     no_batch_dim = torch.eq(target_dim, 0)
3052   else:
3053     no_batch_dim = False
3054   if no_batch_dim:
3055     _1 = True
3056   else:
3057     _1 = torch.eq(self[0], target[0])
3058   if _1:
3059     pass
3060   else:
3061     ops.prim.RaiseException("AssertionError: ")
3062   n_classes = self[-1]
3063   if torch.__is__(weight, None):
3064     _2 = True
3065   else:
3066     weight0 = unchecked_cast(List[int], weight)
3067     if torch.eq(torch.len(weight0), 1):
3068       _3 = torch.eq(weight0[0], n_classes)
3069     else:
3070       _3 = False
3071     _2 = _3
3072   if _2:
3073     pass
3074   else:
3075     ops.prim.RaiseException("AssertionError: ")
3076   if torch.eq(reduction, 0):
3077     _4 = torch.eq(self_dim, 2)
3078   else:
3079     _4 = False
3080   if _4:
3081     reduction_shape = [self[0]]
3082   else:
3083     reduction_shape = annotate(List[int], [])
3084   _5 = (reduction_shape, annotate(List[int], []))
3085   return (_5)[0]
3086 
3087 )=====")
3088 + std::string(R"=====(def broadcast_three(a: List[int],
3089     b: List[int],
3090     c: List[int]) -> List[int]:
3091   _0 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
3092   _1 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
3093   dimsA = torch.len(a)
3094   dimsB = torch.len(b)
3095   ndim = ops.prim.max(dimsA, dimsB)
3096   expandedSizes = annotate(List[int], [])
3097   for i in range(ndim):
3098     offset = torch.sub(torch.sub(ndim, 1), i)
3099     dimA = torch.sub(torch.sub(dimsA, 1), offset)
3100     dimB = torch.sub(torch.sub(dimsB, 1), offset)
3101     if torch.ge(dimA, 0):
3102       sizeA = a[dimA]
3103     else:
3104       sizeA = 1
3105     if torch.ge(dimB, 0):
3106       sizeB = b[dimB]
3107     else:
3108       sizeB = 1
3109     if torch.ne(sizeA, sizeB):
3110       _2 = torch.ne(sizeA, 1)
3111     else:
3112       _2 = False
3113     if _2:
3114       _3 = torch.ne(sizeB, 1)
3115     else:
3116       _3 = False
3117     if _3:
3118       _4 = torch.add("AssertionError: ", torch.format(_0, sizeA, sizeB, i))
3119       ops.prim.RaiseException(_4)
3120     else:
3121       pass
3122     if torch.eq(sizeA, 1):
3123       _5 = sizeB
3124     else:
3125       _5 = sizeA
3126     _6 = torch.append(expandedSizes, _5)
3127   dimsA0 = torch.len(expandedSizes)
3128   dimsB0 = torch.len(c)
3129   ndim0 = ops.prim.max(dimsA0, dimsB0)
3130   expandedSizes0 = annotate(List[int], [])
3131   for i0 in range(ndim0):
3132     offset0 = torch.sub(torch.sub(ndim0, 1), i0)
3133     dimA0 = torch.sub(torch.sub(dimsA0, 1), offset0)
3134     dimB0 = torch.sub(torch.sub(dimsB0, 1), offset0)
3135     if torch.ge(dimA0, 0):
3136       sizeA0 = expandedSizes[dimA0]
3137     else:
3138       sizeA0 = 1
3139     if torch.ge(dimB0, 0):
3140       sizeB0 = c[dimB0]
3141     else:
3142       sizeB0 = 1
3143     if torch.ne(sizeA0, sizeB0):
3144       _7 = torch.ne(sizeA0, 1)
3145     else:
3146       _7 = False
3147     if _7:
3148       _8 = torch.ne(sizeB0, 1)
3149     else:
3150       _8 = False
3151     if _8:
3152       _9 = torch.format(_1, sizeA0, sizeB0, i0)
3153       ops.prim.RaiseException(torch.add("AssertionError: ", _9))
3154     else:
3155       pass
3156     if torch.eq(sizeA0, 1):
3157       _10 = sizeB0
3158     else:
3159       _10 = sizeA0
3160     _11 = torch.append(expandedSizes0, _10)
3161   return expandedSizes0
3162 
3163 )=====")
3164 + std::string(R"=====(def broadcast_one_three(a: List[int],
3165     b: Any,
3166     c: List[int]) -> List[int]:
3167   _0 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
3168   dimsA = torch.len(a)
3169   dimsB = torch.len(c)
3170   ndim = ops.prim.max(dimsA, dimsB)
3171   expandedSizes = annotate(List[int], [])
3172   for i in range(ndim):
3173     offset = torch.sub(torch.sub(ndim, 1), i)
3174     dimA = torch.sub(torch.sub(dimsA, 1), offset)
3175     dimB = torch.sub(torch.sub(dimsB, 1), offset)
3176     if torch.ge(dimA, 0):
3177       sizeA = a[dimA]
3178     else:
3179       sizeA = 1
3180     if torch.ge(dimB, 0):
3181       sizeB = c[dimB]
3182     else:
3183       sizeB = 1
3184     if torch.ne(sizeA, sizeB):
3185       _1 = torch.ne(sizeA, 1)
3186     else:
3187       _1 = False
3188     if _1:
3189       _2 = torch.ne(sizeB, 1)
3190     else:
3191       _2 = False
3192     if _2:
3193       _3 = torch.add("AssertionError: ", torch.format(_0, sizeA, sizeB, i))
3194       ops.prim.RaiseException(_3)
3195     else:
3196       pass
3197     if torch.eq(sizeA, 1):
3198       _4 = sizeB
3199     else:
3200       _4 = sizeA
3201     _5 = torch.append(expandedSizes, _4)
3202   return expandedSizes
3203 
3204 )=====")
3205 + std::string(R"=====(def broadcast_inplace(a: List[int],
3206     b: List[int]) -> List[int]:
3207   _0 = "The dims of tensor b ({}) must be less than or equal tothe dims of tensor a ({}) "
3208   _1 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
3209   dimsA = torch.len(a)
3210   dimsB = torch.len(b)
3211   if torch.gt(dimsB, dimsA):
3212     _2 = torch.add("AssertionError: ", torch.format(_0, dimsB, dimsA))
3213     ops.prim.RaiseException(_2)
3214   else:
3215     pass
3216   for dimA in range(dimsA):
3217     dimB = torch.add(torch.sub(dimsB, dimsA), dimA)
3218     sizeA = a[dimA]
3219     if torch.ge(dimB, 0):
3220       sizeB = b[dimB]
3221     else:
3222       sizeB = 1
3223     if torch.ne(sizeA, sizeB):
3224       _3 = torch.ne(sizeB, 1)
3225     else:
3226       _3 = False
3227     if _3:
3228       _4 = torch.format(_1, sizeA, sizeB, dimA)
3229       ops.prim.RaiseException(torch.add("AssertionError: ", _4))
3230     else:
3231       pass
3232   out = annotate(List[int], [])
3233   for _5 in range(torch.len(a)):
3234     elem = a[_5]
3235     _6 = torch.append(out, elem)
3236   return out
3237 
3238 def nonzero_lower_bound(input: List[int]) -> List[int]:
3239   return [0, torch.len(input)]
3240 
3241 def nonzero_upper_bound(input: List[int]) -> List[int]:
3242   numel = 1
3243   for _0 in range(torch.len(input)):
3244     elem = input[_0]
3245     numel = torch.mul(numel, elem)
3246   return [numel, torch.len(input)]
3247 
3248 )=====")
3249 ;
3250 
3251 
// Returns the complete serialized TorchScript source for all JIT shape
// functions (the concatenated raw-string literal `shape_funcs` defined above).
// Callers parse this text to build the shape-function graphs at runtime.
const std::string& GetSerializedShapeFunctions() {
  return shape_funcs;
}
3255 
3256 
// Maps each supported operator schema to the name of the shape function
// (a `def` inside the serialized `shape_funcs` string above) that computes
// its output shape from input shapes. Many ops share a generic function
// (e.g. "unary" for shape-preserving ops, "zero_dim_tensor" for reductions
// to a scalar).
const OperatorMap<std::string>& GetShapeFunctionMappings() {
 // Function-local static: built once on first use (thread-safe since C++11).
 static const OperatorMap<std::string> shape_mappings {
    {"aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)", "unary"},
    {"aten::rsub.Tensor(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"},
    {"aten::dropout(Tensor input, float p, bool train) -> Tensor", "unary"},
    {"aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor", "adaptive_avg_pool2d"},
    {"prim::NumToTensor.Scalar(Scalar a) -> Tensor", "zero_dim_tensor"},
    {"prim::NumToTensor.bool(bool a) -> Tensor", "zero_dim_tensor"},
    {"aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", "unary"},
    {"aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))", "unary"},
    {"aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", "arange_end"},
    {"aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "arange_start"},
    {"aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "arange_start_step"},
    {"aten::squeeze(Tensor(a) self) -> Tensor(a)", "squeeze_nodim"},
    {"aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", "squeeze"},
    {"aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)", "squeeze_dims"},
    {"aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", "unsqueeze"},
    {"aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)", "slice"},
    {"aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)", "select"},
    {"aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", "index_select"},
    {"aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor", "unary"},
    {"aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "unary"},
    {"aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor", "unary"},
    {"aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)", "unary"},
    {"aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor", "embedding"},
    {"aten::mm(Tensor self, Tensor mat2) -> Tensor", "mm"},
    {"aten::dot(Tensor self, Tensor tensor) -> Tensor", "dot"},
    {"aten::mv(Tensor self, Tensor vec) -> Tensor", "mv"},
    {"aten::matmul(Tensor self, Tensor other) -> Tensor", "matmul"},
    {"aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", "linear"},
    {"aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "max_pool2d"},
    {"aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", "max_pool2d_with_indices"},
    {"aten::t(Tensor(a) self) -> Tensor(a)", "t"},
    {"aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", "transpose"},
    {"aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor", "conv1d"},
    {"aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor", "conv2d"},
    {"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", "batch_norm"},
    {"aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", "conv3d"},
    {"aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "conv_backwards"},
    {"aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor", "conv_forwards"},
    {"aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor", "_conv_forwards"},
    {"aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor", "conv_transpose2d_input"},
    {"aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", "flatten"},
    {"aten::cat(Tensor[] tensors, int dim=0) -> Tensor", "cat"},
    {"aten::stack(Tensor[] tensors, int dim=0) -> Tensor", "stack"},
    {"aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", "permute"},
    {"aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", "movedim"},
    {"aten::view(Tensor(a) self, int[] size) -> Tensor(a)", "view"},
    {"aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", "expand"},
    {"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", "expand_one_unused"},
    {"aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "sum_mean_dim"},
    {"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "sum_mean_dim"},
    {"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "max_dim"},
    {"aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", "zero_dim_tensor"},
    {"aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", "zero_dim_tensor"},
    {"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "addmm"},
    {"aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)", "upsample_nearest2d"},
    {"aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor", "unary"},
    {"aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor", "unary"},
    {"aten::dequantize(Tensor self) -> Tensor", "unary"},
    {"quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc", "broadcast"},
    {"aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", "argmax"},
    {"aten::bmm(Tensor self, Tensor mat2) -> Tensor", "bmm"},
    {"aten::_shape_as_tensor(Tensor self) -> Tensor", "_shape_as_tensor"},
    {"aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)", "topk"},
    {"aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)", "nll_loss_forward"},
    {"aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)", "native_layer_norm"},
    {"aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
    {"aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
    {"aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
    {"aten::_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)", "_batch_norm_with_update"},
    {"aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", "cross_entropy_loss"},
    {"aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", "broadcast_three"},
    {"aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "broadcast_one_three"},
    {"aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", "broadcast_inplace"},
  };

  return shape_mappings;
}
3336 
// Maps operators whose output shape is data-dependent (and thus not exactly
// computable from input shapes) to a {lower-bound, upper-bound} pair of shape
// functions from the serialized source above. E.g. aten::nonzero can return
// anywhere between 0 and numel(input) rows.
const OperatorMap<std::pair<std::string, std::string>>& GetBoundedShapeMappings() {
 // Function-local static: built once on first use (thread-safe since C++11).
 static const OperatorMap<std::pair<std::string, std::string>> shape_mappings {
    {"aten::nonzero(Tensor self) -> (Tensor)", {"nonzero_lower_bound", "nonzero_upper_bound"}},
  };

  return shape_mappings;
}
3344 
3345 // clang-format on
3346 
3347 } // namespace torch::jit
3348