1
2 /**
3 * @generated
4 * This is an auto-generated file. Please do not modify it by hand.
5 * To re-generate, please run:
6 * cd ~/pytorch && python
7 * torchgen/shape_functions/gen_jit_shape_functions.py
8 */
9 #include <torch/csrc/jit/jit_log.h>
10 #include <torch/csrc/jit/passes/inliner.h>
11 #include <torch/csrc/jit/runtime/operator.h>
12 #include <torch/csrc/jit/runtime/serialized_shape_function_registry.h>
13
14 // clang-format off
15
16 namespace torch::jit {
17
18
19 std::string shape_funcs = ""
20 + std::string(R"=====(
21 def unary(self: List[int]) -> List[int]:
22 out = annotate(List[int], [])
23 for _0 in range(torch.len(self)):
24 elem = self[_0]
25 _1 = torch.append(out, elem)
26 return out
27
28 def adaptive_avg_pool2d(self: List[int],
29 out: List[int]) -> List[int]:
30 if torch.eq(torch.len(out), 2):
31 pass
32 else:
33 ops.prim.RaiseException("AssertionError: ")
34 if torch.eq(torch.len(self), 3):
35 _0 = True
36 else:
37 _0 = torch.eq(torch.len(self), 4)
38 if _0:
39 pass
40 else:
41 ops.prim.RaiseException("AssertionError: ")
42 _1 = torch.__range_length(1, torch.len(self), 1)
43 for _2 in range(_1):
44 i = torch.__derive_index(_2, 1, 1)
45 if torch.ne(self[i], 0):
46 pass
47 else:
48 ops.prim.RaiseException("AssertionError: ")
49 shape = annotate(List[int], [])
50 _3 = torch.__range_length(0, torch.sub(torch.len(self), 2), 1)
51 for _4 in range(_3):
52 i0 = torch.__derive_index(_4, 0, 1)
53 _5 = torch.append(shape, self[i0])
54 for _6 in range(torch.len(out)):
55 elem = out[_6]
56 _7 = torch.append(shape, elem)
57 return shape
58
59 def zero_dim_tensor(input: Any) -> List[int]:
60 return annotate(List[int], [])
61
62 def arange_end(end: Union[float, int],
63 inp0: Any,
64 inp1: Any,
65 inp2: Any,
66 inp3: Any) -> List[int]:
67 if torch.ge(end, 0):
68 pass
69 else:
70 ops.prim.RaiseException("AssertionError: ")
71 return [int(torch.ceil(end))]
72
73 def arange_start(start: Union[float, int],
74 end: Union[float, int],
75 inp0: Any,
76 inp1: Any,
77 inp2: Any,
78 inp3: Any) -> List[int]:
79 if torch.ge(end, 0):
80 pass
81 else:
82 ops.prim.RaiseException("AssertionError: ")
83 if torch.ge(end, start):
84 pass
85 else:
86 ops.prim.RaiseException("AssertionError: ")
87 _0 = int(torch.ceil(torch.sub(end, start)))
88 return [_0]
89
90 )=====")
91 + std::string(R"=====(def arange_start_step(start: Union[float, int],
92 end: Union[float, int],
93 step: Union[float, int],
94 inp0: Any,
95 inp1: Any,
96 inp2: Any,
97 inp3: Any) -> List[int]:
98 if torch.ne(step, 0):
99 pass
100 else:
101 ops.prim.RaiseException("AssertionError: ")
102 if torch.lt(step, 0):
103 if torch.ge(start, end):
104 pass
105 else:
106 ops.prim.RaiseException("AssertionError: ")
107 else:
108 if torch.ge(end, start):
109 pass
110 else:
111 ops.prim.RaiseException("AssertionError: ")
112 _0 = torch.div(torch.sub(end, start), step)
113 return [torch.ceil(_0)]
114
115 def squeeze_nodim(li: List[int]) -> List[int]:
116 out = annotate(List[int], [])
117 for i in range(torch.len(li)):
118 if torch.ne(li[i], 1):
119 _0 = torch.append(out, li[i])
120 else:
121 pass
122 return out
123
124 def squeeze(li: List[int],
125 dim: int) -> List[int]:
126 out = annotate(List[int], [])
127 _0 = torch.len(li)
128 if torch.le(_0, 0):
129 dim_post_expr = 1
130 else:
131 dim_post_expr = _0
132 min = torch.neg(dim_post_expr)
133 max = torch.sub(dim_post_expr, 1)
134 if torch.lt(dim, min):
135 _1 = True
136 else:
137 _1 = torch.gt(dim, max)
138 if torch.__not__(_1):
139 pass
140 else:
141 ops.prim.RaiseException("AssertionError: ")
142 if torch.lt(dim, 0):
143 wrapped_dim = torch.add(dim, dim_post_expr)
144 else:
145 wrapped_dim = dim
146 for i in range(torch.len(li)):
147 if torch.eq(i, wrapped_dim):
148 if torch.ne(li[i], 1):
149 _2 = torch.append(out, li[i])
150 else:
151 pass
152 else:
153 _3 = torch.append(out, li[i])
154 return out
155
156 )=====")
157 + std::string(R"=====(def squeeze_dims(li: List[int],
158 dims: List[int]) -> List[int]:
159 if torch.eq(torch.len(dims), 0):
160 _0 = li
161 else:
162 wrapped_dims = annotate(List[int], [])
163 for _1 in range(torch.len(dims)):
164 elem = dims[_1]
165 _2 = torch.append(wrapped_dims, elem)
166 for i in range(torch.len(dims)):
167 _3 = wrapped_dims[i]
168 _4 = torch.len(li)
169 if torch.le(_4, 0):
170 dim_post_expr = 1
171 else:
172 dim_post_expr = _4
173 min = torch.neg(dim_post_expr)
174 max = torch.sub(dim_post_expr, 1)
175 if torch.lt(_3, min):
176 _5 = True
177 else:
178 _5 = torch.gt(_3, max)
179 if torch.__not__(_5):
180 pass
181 else:
182 ops.prim.RaiseException("AssertionError: ")
183 if torch.lt(_3, 0):
184 dim = torch.add(_3, dim_post_expr)
185 else:
186 dim = _3
187 _6 = torch._set_item(wrapped_dims, i, dim)
188 result = annotate(List[int], [])
189 for i0 in range(torch.len(li)):
190 if torch.eq(li[i0], 1):
191 _7 = torch.__contains__(wrapped_dims, i0)
192 if torch.__not__(_7):
193 _8 = torch.append(result, li[i0])
194 else:
195 pass
196 else:
197 _9 = torch.append(result, li[i0])
198 _0 = result
199 return _0
200
201 def unsqueeze(li: List[int],
202 dim: int) -> List[int]:
203 _0 = torch.add(torch.len(li), 1)
204 if torch.le(_0, 0):
205 dim_post_expr = 1
206 else:
207 dim_post_expr = _0
208 min = torch.neg(dim_post_expr)
209 max = torch.sub(dim_post_expr, 1)
210 if torch.lt(dim, min):
211 _1 = True
212 else:
213 _1 = torch.gt(dim, max)
214 if torch.__not__(_1):
215 pass
216 else:
217 ops.prim.RaiseException("AssertionError: ")
218 if torch.lt(dim, 0):
219 dim0 = torch.add(dim, dim_post_expr)
220 else:
221 dim0 = dim
222 out = annotate(List[int], [])
223 for _2 in range(torch.len(li)):
224 elem = li[_2]
225 _3 = torch.append(out, elem)
226 torch.insert(out, dim0, 1)
227 return out
228
229 )=====")
230 + std::string(R"=====(def slice(self: List[int],
231 dim: int,
232 start: Optional[int],
233 end: Optional[int],
234 step: int) -> List[int]:
235 ndim = torch.len(self)
236 if torch.ne(ndim, 0):
237 pass
238 else:
239 ops.prim.RaiseException("AssertionError: ")
240 if torch.le(ndim, 0):
241 dim_post_expr = 1
242 else:
243 dim_post_expr = ndim
244 min = torch.neg(dim_post_expr)
245 max = torch.sub(dim_post_expr, 1)
246 if torch.lt(dim, min):
247 _0 = True
248 else:
249 _0 = torch.gt(dim, max)
250 if torch.__not__(_0):
251 pass
252 else:
253 ops.prim.RaiseException("AssertionError: ")
254 if torch.lt(dim, 0):
255 dim0 = torch.add(dim, dim_post_expr)
256 else:
257 dim0 = dim
258 if torch.__isnot__(start, None):
259 start_val = unchecked_cast(int, start)
260 else:
261 start_val = 0
262 if torch.__isnot__(end, None):
263 end_val = unchecked_cast(int, end)
264 else:
265 end_val = 9223372036854775807
266 if torch.gt(step, 0):
267 pass
268 else:
269 ops.prim.RaiseException("AssertionError: ")
270 _1 = torch.eq(start_val, 9223372036854775807)
271 if _1:
272 start_val0 = 0
273 else:
274 start_val0 = start_val
275 if torch.lt(start_val0, 0):
276 start_val1 = torch.add(start_val0, self[dim0])
277 else:
278 start_val1 = start_val0
279 if torch.lt(end_val, 0):
280 end_val0 = torch.add(end_val, self[dim0])
281 else:
282 end_val0 = end_val
283 if torch.lt(start_val1, 0):
284 start_val2 = 0
285 else:
286 if torch.gt(start_val1, self[dim0]):
287 start_val3 = self[dim0]
288 else:
289 start_val3 = start_val1
290 start_val2 = start_val3
291 if torch.lt(end_val0, start_val2):
292 end_val1 = start_val2
293 else:
294 if torch.ge(end_val0, self[dim0]):
295 end_val2 = self[dim0]
296 else:
297 end_val2 = end_val0
298 end_val1 = end_val2
299 slice_len = torch.sub(end_val1, start_val2)
300 out = annotate(List[int], [])
301 for _2 in range(torch.len(self)):
302 elem = self[_2]
303 _3 = torch.append(out, elem)
304 _4 = torch.sub(torch.add(slice_len, step), 1)
305 _5 = torch._set_item(out, dim0, torch.floordiv(_4, step))
306 return out
307
308 )=====")
309 + std::string(R"=====(def select(self: List[int],
310 dim: int,
311 index: int) -> List[int]:
312 ndim = torch.len(self)
313 if torch.ne(ndim, 0):
314 pass
315 else:
316 ops.prim.RaiseException("AssertionError: ")
317 if torch.le(ndim, 0):
318 dim_post_expr = 1
319 else:
320 dim_post_expr = ndim
321 min = torch.neg(dim_post_expr)
322 max = torch.sub(dim_post_expr, 1)
323 if torch.lt(dim, min):
324 _0 = True
325 else:
326 _0 = torch.gt(dim, max)
327 if torch.__not__(_0):
328 pass
329 else:
330 ops.prim.RaiseException("AssertionError: ")
331 if torch.lt(dim, 0):
332 dim0 = torch.add(dim, dim_post_expr)
333 else:
334 dim0 = dim
335 size = self[dim0]
336 if torch.lt(index, torch.neg(size)):
337 _1 = True
338 else:
339 _1 = torch.ge(index, size)
340 if torch.__not__(_1):
341 pass
342 else:
343 ops.prim.RaiseException("AssertionError: ")
344 out = annotate(List[int], [])
345 for i in range(ndim):
346 if torch.ne(i, dim0):
347 _2 = torch.append(out, self[i])
348 else:
349 pass
350 return out
351
352 )=====")
353 + std::string(R"=====(def index_select(self: List[int],
354 dim: int,
355 index: List[int]) -> List[int]:
356 _0 = torch.len(self)
357 if torch.le(_0, 0):
358 dim_post_expr = 1
359 else:
360 dim_post_expr = _0
361 min = torch.neg(dim_post_expr)
362 max = torch.sub(dim_post_expr, 1)
363 if torch.lt(dim, min):
364 _1 = True
365 else:
366 _1 = torch.gt(dim, max)
367 if torch.__not__(_1):
368 pass
369 else:
370 ops.prim.RaiseException("AssertionError: ")
371 if torch.lt(dim, 0):
372 dim0 = torch.add(dim, dim_post_expr)
373 else:
374 dim0 = dim
375 numel = 1
376 for _2 in range(torch.len(index)):
377 elem = index[_2]
378 numel = torch.mul(numel, elem)
379 if torch.le(torch.len(index), 1):
380 pass
381 else:
382 ops.prim.RaiseException("AssertionError: ")
383 if torch.eq(dim0, 0):
384 _3 = True
385 else:
386 _3 = torch.lt(dim0, torch.len(self))
387 if _3:
388 pass
389 else:
390 ops.prim.RaiseException("AssertionError: ")
391 result_size = annotate(List[int], [])
392 for i in range(torch.len(self)):
393 if torch.eq(dim0, i):
394 _4 = torch.append(result_size, numel)
395 else:
396 _5 = torch.append(result_size, self[i])
397 return result_size
398
399 )=====")
400 + std::string(R"=====(def embedding(weight: List[int],
401 indices: List[int],
402 padding_idx: int=-1,
403 scale_grad_by_freq: bool=False,
404 sparse: bool=False) -> List[int]:
405 if torch.eq(torch.len(weight), 2):
406 pass
407 else:
408 ops.prim.RaiseException("AssertionError: ")
409 if torch.eq(torch.len(indices), 1):
410 _1 = torch.len(weight)
411 if torch.le(_1, 0):
412 dim_post_expr = 1
413 else:
414 dim_post_expr = _1
415 min = torch.neg(dim_post_expr)
416 max = torch.sub(dim_post_expr, 1)
417 if torch.lt(0, min):
418 _2 = True
419 else:
420 _2 = torch.gt(0, max)
421 if torch.__not__(_2):
422 pass
423 else:
424 ops.prim.RaiseException("AssertionError: ")
425 numel = 1
426 for _3 in range(torch.len(indices)):
427 elem = indices[_3]
428 numel = torch.mul(numel, elem)
429 if torch.le(torch.len(indices), 1):
430 pass
431 else:
432 ops.prim.RaiseException("AssertionError: ")
433 result_size = annotate(List[int], [])
434 for i in range(torch.len(weight)):
435 if torch.eq(0, i):
436 _4 = torch.append(result_size, numel)
437 else:
438 _5 = torch.append(result_size, weight[i])
439 _0 = result_size
440 else:
441 size = annotate(List[int], [])
442 for _6 in range(torch.len(indices)):
443 elem0 = indices[_6]
444 _7 = torch.append(size, elem0)
445 _8 = torch.append(size, weight[1])
446 _0 = size
447 return _0
448
449 def mm(self: List[int],
450 mat2: List[int]) -> List[int]:
451 _0 = "AssertionError: self must be a matrix"
452 _1 = "AssertionError: mat2 must be a matrix"
453 if torch.eq(torch.len(self), 2):
454 pass
455 else:
456 ops.prim.RaiseException(_0)
457 if torch.eq(torch.len(mat2), 2):
458 pass
459 else:
460 ops.prim.RaiseException(_1)
461 if torch.eq(self[1], mat2[0]):
462 pass
463 else:
464 ops.prim.RaiseException("AssertionError: ")
465 return [self[0], mat2[1]]
466
467 )=====")
468 + std::string(R"=====(def dot(self: List[int],
469 tensor: List[int]) -> List[int]:
470 if torch.eq(torch.len(self), 1):
471 _0 = torch.eq(torch.len(tensor), 1)
472 else:
473 _0 = False
474 if _0:
475 pass
476 else:
477 ops.prim.RaiseException("AssertionError: ")
478 if torch.eq(self[0], tensor[0]):
479 pass
480 else:
481 ops.prim.RaiseException("AssertionError: ")
482 return annotate(List[int], [])
483
484 def mv(self: List[int],
485 vec: List[int]) -> List[int]:
486 if torch.eq(torch.len(self), 2):
487 _0 = torch.eq(torch.len(vec), 1)
488 else:
489 _0 = False
490 if _0:
491 pass
492 else:
493 ops.prim.RaiseException("AssertionError: ")
494 if torch.eq(self[1], vec[0]):
495 pass
496 else:
497 ops.prim.RaiseException("AssertionError: ")
498 return [self[0]]
499
500 )=====")
501 + std::string(R"=====(def matmul(tensor1: List[int],
502 tensor2: List[int]) -> List[int]:
503 _0 = "AssertionError: self must be a matrix"
504 _1 = "AssertionError: mat2 must be a matrix"
505 _2 = "AssertionError: self must be a matrix"
506 _3 = "AssertionError: mat2 must be a matrix"
507 _4 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
508 _5 = "AssertionError: both arguments to matmul need to be at least 1D"
509 _6 = uninitialized(List[int])
510 dim_tensor1 = torch.len(tensor1)
511 dim_tensor2 = torch.len(tensor2)
512 if torch.eq(dim_tensor1, 1):
513 _7 = torch.eq(dim_tensor2, 1)
514 else:
515 _7 = False
516 if _7:
517 if torch.eq(torch.len(tensor1), 1):
518 _9 = torch.eq(torch.len(tensor2), 1)
519 else:
520 _9 = False
521 if _9:
522 pass
523 else:
524 ops.prim.RaiseException("AssertionError: ")
525 if torch.eq(tensor1[0], tensor2[0]):
526 pass
527 else:
528 ops.prim.RaiseException("AssertionError: ")
529 _8 = annotate(List[int], [])
530 else:
531 if torch.eq(dim_tensor1, 2):
532 _10 = torch.eq(dim_tensor2, 1)
533 else:
534 _10 = False
535 if _10:
536 if torch.eq(torch.len(tensor1), 2):
537 _12 = torch.eq(torch.len(tensor2), 1)
538 else:
539 _12 = False
540 if _12:
541 pass
542 else:
543 ops.prim.RaiseException("AssertionError: ")
544 if torch.eq(tensor1[1], tensor2[0]):
545 pass
546 else:
547 ops.prim.RaiseException("AssertionError: ")
548 _11 = [tensor1[0]]
549 else:
550 if torch.eq(dim_tensor1, 1):
551 _13 = torch.eq(dim_tensor2, 2)
552 else:
553 _13 = False
554 if _13:
555 _15 = torch.add(torch.len(tensor1), 1)
556 if torch.le(_15, 0):
557 dim_post_expr = 1
558 else:
559 dim_post_expr = _15
560 min = torch.neg(dim_post_expr)
561 max = torch.sub(dim_post_expr, 1)
562 if torch.lt(0, min):
563 _16 = True
564 else:
565 _16 = torch.gt(0, max)
566 if torch.__not__(_16):
567 pass
568 else:
569 ops.prim.RaiseException("AssertionError: ")
570 out = annotate(List[int], [])
571 for _17 in range(torch.len(tensor1)):
572 elem = tensor1[_17]
573 _18 = torch.append(out, elem)
574 torch.insert(out, 0, 1)
575 if torch.eq(torch.len(out), 2):
576 pass
577 else:
578 ops.prim.RaiseException(_0)
579 if torch.eq(torch.len(tensor2), 2):
580 pass
581 else:
582 ops.prim.RaiseException(_1)
583 if torch.eq(out[1], tensor2[0]):
584 pass
585 else:
586 ops.prim.RaiseException("AssertionError: ")
587 _19 = [out[0], tensor2[1]]
588 out0 = annotate(List[int], [])
589 for i in range(2):
590 if torch.eq(i, 0):
591 if torch.ne(_19[i], 1):
592 _20 = torch.append(out0, _19[i])
593 else:
594 pass
595 else:
596 _21 = torch.append(out0, _19[i])
597 _14 = out0
598 else:
599 if torch.eq(dim_tensor1, 2):
600 _22 = torch.eq(dim_tensor2, 2)
601 else:
602 _22 = False
603 if _22:
604 _24 = torch.eq(torch.len(tensor1), 2)
605 if _24:
606 pass
607 else:
608 ops.prim.RaiseException(_2)
609 _25 = torch.eq(torch.len(tensor2), 2)
610 if _25:
611 pass
612 else:
613 ops.prim.RaiseException(_3)
614 _26 = torch.eq(tensor1[1], tensor2[0])
615 if _26:
616 pass
617 else:
618 ops.prim.RaiseException("AssertionError: ")
619 _23 = [tensor1[0], tensor2[1]]
620 else:
621 if torch.ge(dim_tensor1, 1):
622 _27 = torch.ge(dim_tensor2, 1)
623 else:
624 _27 = False
625 if _27:
626 if torch.gt(dim_tensor1, 1):
627 n = tensor1[-2]
628 else:
629 n = 1
630 batch_tensor1 = annotate(List[int], [])
631 for i0 in range(torch.sub(dim_tensor1, 2)):
632 _29 = torch.append(batch_tensor1, tensor1[i0])
633 p = tensor2[-1]
634 batch_tensor2 = annotate(List[int], [])
635 for i1 in range(torch.sub(dim_tensor2, 2)):
636 _30 = torch.append(batch_tensor2, tensor2[i1])
637 dimsA = torch.len(batch_tensor1)
638 dimsB = torch.len(batch_tensor2)
639 ndim = ops.prim.max(dimsA, dimsB)
640 expand_batch_portion = annotate(List[int], [])
641 for i2 in range(ndim):
642 offset = torch.sub(torch.sub(ndim, 1), i2)
643 dimA = torch.sub(torch.sub(dimsA, 1), offset)
644 dimB = torch.sub(torch.sub(dimsB, 1), offset)
645 if torch.ge(dimA, 0):
646 sizeA = batch_tensor1[dimA]
647 else:
648 sizeA = 1
649 if torch.ge(dimB, 0):
650 sizeB = batch_tensor2[dimB]
651 else:
652 sizeB = 1
653 if torch.ne(sizeA, sizeB):
654 _31 = torch.ne(sizeA, 1)
655 else:
656 _31 = False
657 if _31:
658 _32 = torch.ne(sizeB, 1)
659 else:
660 _32 = False
661 if _32:
662 _33 = torch.format(_4, sizeA, sizeB, i2)
663 _34 = torch.add("AssertionError: ", _33)
664 ops.prim.RaiseException(_34)
665 else:
666 pass
667 if torch.eq(sizeA, 1):
668 _35 = sizeB
669 else:
670 _35 = sizeA
671 _36 = torch.append(expand_batch_portion, _35)
672 if torch.gt(dim_tensor1, 1):
673 _37 = torch.append(expand_batch_portion, n)
674 else:
675 pass
676 if torch.gt(dim_tensor2, 1):
677 _38 = torch.append(expand_batch_portion, p)
678 else:
679 pass
680 _28 = expand_batch_portion
681 else:
682 ops.prim.RaiseException(_5)
683 _28 = _6
684 _23 = _28
685 _14 = _23
686 _11 = _14
687 _8 = _11
688 return _8
689
690 )=====")
691 + std::string(R"=====(def linear(input: List[int],
692 weight: List[int],
693 bias: Optional[List[int]]) -> List[int]:
694 _0 = "AssertionError: self must be a matrix"
695 _1 = "AssertionError: mat2 must be a matrix"
696 _2 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
697 _3 = "AssertionError: both arguments to matmul need to be at least 1D"
698 _4 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
699 if torch.le(torch.len(weight), 2):
700 pass
701 else:
702 ops.prim.RaiseException("AssertionError: ")
703 self_len = torch.len(weight)
704 if torch.eq(self_len, 0):
705 _5 = annotate(List[int], [])
706 else:
707 if torch.eq(self_len, 1):
708 _6 = [weight[0]]
709 else:
710 _6 = [weight[1], weight[0]]
711 _5 = _6
712 _7 = uninitialized(List[int])
713 dim_tensor1 = torch.len(input)
714 dim_tensor2 = torch.len(_5)
715 if torch.eq(dim_tensor1, 1):
716 _8 = torch.eq(dim_tensor2, 1)
717 else:
718 _8 = False
719 if _8:
720 if torch.eq(torch.len(input), 1):
721 _9 = torch.eq(torch.len(_5), 1)
722 else:
723 _9 = False
724 if _9:
725 pass
726 else:
727 ops.prim.RaiseException("AssertionError: ")
728 if torch.eq(input[0], _5[0]):
729 pass
730 else:
731 ops.prim.RaiseException("AssertionError: ")
732 out = annotate(List[int], [])
733 else:
734 if torch.eq(dim_tensor1, 2):
735 _10 = torch.eq(dim_tensor2, 1)
736 else:
737 _10 = False
738 if _10:
739 if torch.eq(torch.len(input), 2):
740 _12 = torch.eq(torch.len(_5), 1)
741 else:
742 _12 = False
743 if _12:
744 pass
745 else:
746 ops.prim.RaiseException("AssertionError: ")
747 if torch.eq(input[1], _5[0]):
748 pass
749 else:
750 ops.prim.RaiseException("AssertionError: ")
751 _11 = [input[0]]
752 else:
753 if torch.eq(dim_tensor1, 1):
754 _13 = torch.eq(dim_tensor2, 2)
755 else:
756 _13 = False
757 if _13:
758 _15 = torch.add(torch.len(input), 1)
759 if torch.le(_15, 0):
760 dim_post_expr = 1
761 else:
762 dim_post_expr = _15
763 min = torch.neg(dim_post_expr)
764 max = torch.sub(dim_post_expr, 1)
765 if torch.lt(0, min):
766 _16 = True
767 else:
768 _16 = torch.gt(0, max)
769 if torch.__not__(_16):
770 pass
771 else:
772 ops.prim.RaiseException("AssertionError: ")
773 out0 = annotate(List[int], [])
774 for _17 in range(torch.len(input)):
775 elem = input[_17]
776 _18 = torch.append(out0, elem)
777 torch.insert(out0, 0, 1)
778 if torch.eq(torch.len(out0), 2):
779 pass
780 else:
781 ops.prim.RaiseException(_0)
782 if torch.eq(torch.len(_5), 2):
783 pass
784 else:
785 ops.prim.RaiseException(_1)
786 if torch.eq(out0[1], _5[0]):
787 pass
788 else:
789 ops.prim.RaiseException("AssertionError: ")
790 _19 = [out0[0], _5[1]]
791 out1 = annotate(List[int], [])
792 for i in range(2):
793 if torch.eq(i, 0):
794 if torch.ne(_19[i], 1):
795 _20 = torch.append(out1, _19[i])
796 else:
797 pass
798 else:
799 _21 = torch.append(out1, _19[i])
800 _14 = out1
801 else:
802 if torch.eq(dim_tensor1, 2):
803 _22 = torch.eq(dim_tensor2, 2)
804 else:
805 _22 = False
806 if _22:
807 if torch.eq(torch.len(input), 2):
808 pass
809 else:
810 ops.prim.RaiseException(_0)
811 if torch.eq(torch.len(_5), 2):
812 pass
813 else:
814 ops.prim.RaiseException(_1)
815 if torch.eq(input[1], _5[0]):
816 pass
817 else:
818 ops.prim.RaiseException("AssertionError: ")
819 _23 = [input[0], _5[1]]
820 else:
821 if torch.ge(dim_tensor1, 1):
822 _24 = torch.ge(dim_tensor2, 1)
823 else:
824 _24 = False
825 if _24:
826 if torch.gt(dim_tensor1, 1):
827 n = input[-2]
828 else:
829 n = 1
830 batch_tensor1 = annotate(List[int], [])
831 for i0 in range(torch.sub(dim_tensor1, 2)):
832 _26 = torch.append(batch_tensor1, input[i0])
833 p = _5[-1]
834 batch_tensor2 = annotate(List[int], [])
835 for i1 in range(torch.sub(dim_tensor2, 2)):
836 _27 = torch.append(batch_tensor2, _5[i1])
837 dimsA = torch.len(batch_tensor1)
838 dimsB = torch.len(batch_tensor2)
839 ndim = ops.prim.max(dimsA, dimsB)
840 expand_batch_portion = annotate(List[int], [])
841 for i2 in range(ndim):
842 offset = torch.sub(torch.sub(ndim, 1), i2)
843 dimA = torch.sub(torch.sub(dimsA, 1), offset)
844 dimB = torch.sub(torch.sub(dimsB, 1), offset)
845 if torch.ge(dimA, 0):
846 sizeA = batch_tensor1[dimA]
847 else:
848 sizeA = 1
849 if torch.ge(dimB, 0):
850 sizeB = batch_tensor2[dimB]
851 else:
852 sizeB = 1
853 if torch.ne(sizeA, sizeB):
854 _28 = torch.ne(sizeA, 1)
855 else:
856 _28 = False
857 if _28:
858 _29 = torch.ne(sizeB, 1)
859 else:
860 _29 = False
861 if _29:
862 _30 = torch.format(_2, sizeA, sizeB, i2)
863 _31 = torch.add("AssertionError: ", _30)
864 ops.prim.RaiseException(_31)
865 else:
866 pass
867 if torch.eq(sizeA, 1):
868 _32 = sizeB
869 else:
870 _32 = sizeA
871 _33 = torch.append(expand_batch_portion, _32)
872 if torch.gt(dim_tensor1, 1):
873 _34 = torch.append(expand_batch_portion, n)
874 else:
875 pass
876 if torch.gt(dim_tensor2, 1):
877 _35 = torch.append(expand_batch_portion, p)
878 else:
879 pass
880 _25 = expand_batch_portion
881 else:
882 ops.prim.RaiseException(_3)
883 _25 = _7
884 _23 = _25
885 _14 = _23
886 _11 = _14
887 out = _11
888 if torch.__isnot__(bias, None):
889 bias0 = unchecked_cast(List[int], bias)
890 dimsA0 = torch.len(bias0)
891 dimsB0 = torch.len(out)
892 ndim0 = ops.prim.max(dimsA0, dimsB0)
893 expandedSizes = annotate(List[int], [])
894 for i3 in range(ndim0):
895 offset0 = torch.sub(torch.sub(ndim0, 1), i3)
896 dimA0 = torch.sub(torch.sub(dimsA0, 1), offset0)
897 dimB0 = torch.sub(torch.sub(dimsB0, 1), offset0)
898 if torch.ge(dimA0, 0):
899 sizeA0 = bias0[dimA0]
900 else:
901 sizeA0 = 1
902 if torch.ge(dimB0, 0):
903 sizeB0 = out[dimB0]
904 else:
905 sizeB0 = 1
906 if torch.ne(sizeA0, sizeB0):
907 _36 = torch.ne(sizeA0, 1)
908 else:
909 _36 = False
910 if _36:
911 _37 = torch.ne(sizeB0, 1)
912 else:
913 _37 = False
914 if _37:
915 _38 = torch.format(_4, sizeA0, sizeB0, i3)
916 _39 = torch.add("AssertionError: ", _38)
917 ops.prim.RaiseException(_39)
918 else:
919 pass
920 if torch.eq(sizeA0, 1):
921 _40 = sizeB0
922 else:
923 _40 = sizeA0
924 _41 = torch.append(expandedSizes, _40)
925 if torch.eq(expandedSizes, out):
926 pass
927 else:
928 ops.prim.RaiseException("AssertionError: ")
929 else:
930 pass
931 return out
932
933 )=====")
934 + std::string(R"=====(def max_pool2d(input: List[int],
935 kernel_size: List[int],
936 stride: List[int],
937 padding: List[int],
938 dilation: List[int],
939 ceil_mode: bool) -> List[int]:
940 _0 = "AssertionError: max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
941 _1 = "AssertionError: max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
942 _2 = "AssertionError: max_pool2d: padding must either be a single int, or a tuple of two ints"
943 _3 = "AssertionError: max_pool2d: dilation must be either a single int, or a tuple of two ints"
944 _4 = "AssertionError: stride should not be zeero"
945 _5 = "AssertionError: stride should not be zeero"
946 if torch.eq(torch.len(kernel_size), 1):
947 _6 = True
948 else:
949 _6 = torch.eq(torch.len(kernel_size), 2)
950 if _6:
951 pass
952 else:
953 ops.prim.RaiseException(_0)
954 kH = kernel_size[0]
955 if torch.eq(torch.len(kernel_size), 1):
956 kW = kH
957 else:
958 kW = kernel_size[1]
959 if torch.eq(torch.len(stride), 0):
960 _7 = True
961 else:
962 _7 = torch.eq(torch.len(stride), 1)
963 if _7:
964 _8 = True
965 else:
966 _8 = torch.eq(torch.len(stride), 2)
967 if _8:
968 pass
969 else:
970 ops.prim.RaiseException(_1)
971 if torch.eq(torch.len(stride), 0):
972 dH = kH
973 else:
974 dH = stride[0]
975 if torch.eq(torch.len(stride), 0):
976 dW = kW
977 else:
978 if torch.eq(torch.len(stride), 1):
979 dW0 = dH
980 else:
981 dW0 = stride[1]
982 dW = dW0
983 if torch.eq(torch.len(padding), 1):
984 _9 = True
985 else:
986 _9 = torch.eq(torch.len(padding), 2)
987 if _9:
988 pass
989 else:
990 ops.prim.RaiseException(_2)
991 padH = padding[0]
992 if torch.eq(torch.len(padding), 1):
993 padW = padH
994 else:
995 padW = padding[1]
996 if torch.eq(torch.len(dilation), 1):
997 _10 = True
998 else:
999 _10 = torch.eq(torch.len(dilation), 2)
1000 if _10:
1001 pass
1002 else:
1003 ops.prim.RaiseException(_3)
1004 dilationH = dilation[0]
1005 if torch.eq(torch.len(dilation), 1):
1006 dilationW = dilationH
1007 else:
1008 dilationW = dilation[1]
1009 if torch.eq(torch.len(input), 3):
1010 _11 = True
1011 else:
1012 _11 = torch.eq(torch.len(input), 4)
1013 if _11:
1014 pass
1015 else:
1016 ops.prim.RaiseException("AssertionError: ")
1017 if torch.eq(torch.len(input), 4):
1018 nbatch = input[-4]
1019 else:
1020 nbatch = 1
1021 nInputPlane = input[-3]
1022 inputHeight = input[-2]
1023 inputWidth = input[-1]
1024 if torch.ne(dH, 0):
1025 pass
1026 else:
1027 ops.prim.RaiseException(_4)
1028 _12 = torch.add(torch.add(inputHeight, padH), padH)
1029 _13 = torch.mul(dilationH, torch.sub(kH, 1))
1030 _14 = torch.sub(torch.sub(_12, _13), 1)
1031 if ceil_mode:
1032 _15 = torch.sub(dH, 1)
1033 else:
1034 _15 = 0
1035 _16 = torch.floordiv(torch.add(_14, _15), dH)
1036 outputSize = torch.add(_16, 1)
1037 if ceil_mode:
1038 _17 = torch.ge(torch.mul(_16, dH), torch.add(inputHeight, padH))
1039 if _17:
1040 outputSize0 = _16
1041 else:
1042 outputSize0 = outputSize
1043 outputHeight = outputSize0
1044 else:
1045 outputHeight = outputSize
1046 if torch.ne(dW, 0):
1047 pass
1048 else:
1049 ops.prim.RaiseException(_5)
1050 _18 = torch.add(torch.add(inputWidth, padW), padW)
1051 _19 = torch.mul(dilationW, torch.sub(kW, 1))
1052 _20 = torch.sub(torch.sub(_18, _19), 1)
1053 if ceil_mode:
1054 _21 = torch.sub(dW, 1)
1055 else:
1056 _21 = 0
1057 _22 = torch.floordiv(torch.add(_20, _21), dW)
1058 outputSize1 = torch.add(_22, 1)
1059 if ceil_mode:
1060 _23 = torch.ge(torch.mul(_22, dW), torch.add(inputWidth, padW))
1061 if _23:
1062 outputSize2 = _22
1063 else:
1064 outputSize2 = outputSize1
1065 outputWidth = outputSize2
1066 else:
1067 outputWidth = outputSize1
1068 ndim = torch.len(input)
1069 if torch.gt(kW, 0):
1070 _24 = torch.gt(kH, 0)
1071 else:
1072 _24 = False
1073 if _24:
1074 pass
1075 else:
1076 ops.prim.RaiseException("AssertionError: ")
1077 if torch.gt(dW, 0):
1078 _25 = torch.gt(dH, 0)
1079 else:
1080 _25 = False
1081 if _25:
1082 pass
1083 else:
1084 ops.prim.RaiseException("AssertionError: ")
1085 if torch.gt(dilationH, 0):
1086 _26 = torch.gt(dilationW, 0)
1087 else:
1088 _26 = False
1089 if _26:
1090 pass
1091 else:
1092 ops.prim.RaiseException("AssertionError: ")
1093 if torch.ne(input[1], 0):
1094 valid_dims = torch.ne(input[2], 0)
1095 else:
1096 valid_dims = False
1097 if torch.eq(ndim, 3):
1098 _27 = torch.ne(input[0], 0)
1099 else:
1100 _27 = False
1101 if _27:
1102 _28 = valid_dims
1103 else:
1104 _28 = False
1105 if _28:
1106 _29 = True
1107 else:
1108 if torch.eq(ndim, 4):
1109 _30 = valid_dims
1110 else:
1111 _30 = False
1112 if _30:
1113 _31 = torch.ne(input[3], 0)
1114 else:
1115 _31 = False
1116 _29 = _31
1117 if _29:
1118 pass
1119 else:
1120 ops.prim.RaiseException("AssertionError: ")
1121 if torch.ge(torch.floordiv(kW, 2), padW):
1122 _33 = torch.ge(torch.floordiv(kH, 2), padH)
1123 _32 = _33
1124 else:
1125 _32 = False
1126 if _32:
1127 pass
1128 else:
1129 ops.prim.RaiseException("AssertionError: ")
1130 if torch.ge(outputWidth, 1):
1131 _34 = torch.ge(outputHeight, 1)
1132 else:
1133 _34 = False
1134 if _34:
1135 pass
1136 else:
1137 ops.prim.RaiseException("AssertionError: ")
1138 if torch.eq(torch.len(input), 3):
1139 _36 = [nInputPlane, outputHeight, outputWidth]
1140 _35 = _36
1141 else:
1142 _37 = [nbatch, nInputPlane, outputHeight, outputWidth]
1143 _35 = _37
1144 return _35
1145
1146 )=====")
1147 + std::string(R"=====(def max_pool2d_with_indices(input: List[int],
1148 kernel_size: List[int],
1149 stride: List[int],
1150 padding: List[int],
1151 dilation: List[int],
1152 ceil_mode: bool) -> Tuple[List[int], List[int]]:
1153 _0 = "AssertionError: max_pool2d: kernel_size must either be a single int, or a tuple of two ints"
1154 _1 = "AssertionError: max_pool2d: stride must either be omitted, a single int, or a tuple of two ints"
1155 _2 = "AssertionError: max_pool2d: padding must either be a single int, or a tuple of two ints"
1156 _3 = "AssertionError: max_pool2d: dilation must be either a single int, or a tuple of two ints"
1157 _4 = "AssertionError: stride should not be zeero"
1158 if torch.eq(torch.len(kernel_size), 1):
1159 _5 = True
1160 else:
1161 _5 = torch.eq(torch.len(kernel_size), 2)
1162 if _5:
1163 pass
1164 else:
1165 ops.prim.RaiseException(_0)
1166 kH = kernel_size[0]
1167 if torch.eq(torch.len(kernel_size), 1):
1168 kW = kH
1169 else:
1170 kW = kernel_size[1]
1171 if torch.eq(torch.len(stride), 0):
1172 _6 = True
1173 else:
1174 _6 = torch.eq(torch.len(stride), 1)
1175 if _6:
1176 _7 = True
1177 else:
1178 _7 = torch.eq(torch.len(stride), 2)
1179 if _7:
1180 pass
1181 else:
1182 ops.prim.RaiseException(_1)
1183 if torch.eq(torch.len(stride), 0):
1184 dH = kH
1185 else:
1186 dH = stride[0]
1187 if torch.eq(torch.len(stride), 0):
1188 dW = kW
1189 else:
1190 if torch.eq(torch.len(stride), 1):
1191 dW0 = dH
1192 else:
1193 dW0 = stride[1]
1194 dW = dW0
1195 if torch.eq(torch.len(padding), 1):
1196 _8 = True
1197 else:
1198 _8 = torch.eq(torch.len(padding), 2)
1199 if _8:
1200 pass
1201 else:
1202 ops.prim.RaiseException(_2)
1203 padH = padding[0]
1204 if torch.eq(torch.len(padding), 1):
1205 padW = padH
1206 else:
1207 padW = padding[1]
1208 if torch.eq(torch.len(dilation), 1):
1209 _9 = True
1210 else:
1211 _9 = torch.eq(torch.len(dilation), 2)
1212 if _9:
1213 pass
1214 else:
1215 ops.prim.RaiseException(_3)
1216 dilationH = dilation[0]
1217 if torch.eq(torch.len(dilation), 1):
1218 dilationW = dilationH
1219 else:
1220 dilationW = dilation[1]
1221 if torch.eq(torch.len(input), 3):
1222 _10 = True
1223 else:
1224 _10 = torch.eq(torch.len(input), 4)
1225 if _10:
1226 pass
1227 else:
1228 ops.prim.RaiseException("AssertionError: ")
1229 if torch.eq(torch.len(input), 4):
1230 nbatch = input[-4]
1231 else:
1232 nbatch = 1
1233 nInputPlane = input[-3]
1234 inputHeight = input[-2]
1235 inputWidth = input[-1]
1236 if torch.ne(dH, 0):
1237 pass
1238 else:
1239 ops.prim.RaiseException(_4)
1240 _11 = torch.add(torch.add(inputHeight, padH), padH)
1241 _12 = torch.mul(dilationH, torch.sub(kH, 1))
1242 _13 = torch.sub(torch.sub(_11, _12), 1)
1243 if ceil_mode:
1244 _14 = torch.sub(dH, 1)
1245 else:
1246 _14 = 0
1247 _15 = torch.floordiv(torch.add(_13, _14), dH)
1248 outputSize = torch.add(_15, 1)
1249 if ceil_mode:
1250 _16 = torch.ge(torch.mul(_15, dH), torch.add(inputHeight, padH))
1251 if _16:
1252 outputSize0 = _15
1253 else:
1254 outputSize0 = outputSize
1255 outputHeight = outputSize0
1256 else:
1257 outputHeight = outputSize
1258 if torch.ne(dW, 0):
1259 pass
1260 else:
1261 ops.prim.RaiseException(_4)
1262 _17 = torch.add(torch.add(inputWidth, padW), padW)
1263 _18 = torch.mul(dilationW, torch.sub(kW, 1))
1264 _19 = torch.sub(torch.sub(_17, _18), 1)
1265 if ceil_mode:
1266 _20 = torch.sub(dW, 1)
1267 else:
1268 _20 = 0
1269 _21 = torch.floordiv(torch.add(_19, _20), dW)
1270 outputSize1 = torch.add(_21, 1)
1271 if ceil_mode:
1272 _22 = torch.ge(torch.mul(_21, dW), torch.add(inputWidth, padW))
1273 if _22:
1274 outputSize2 = _21
1275 else:
1276 outputSize2 = outputSize1
1277 outputWidth = outputSize2
1278 else:
1279 outputWidth = outputSize1
1280 ndim = torch.len(input)
1281 if torch.gt(kW, 0):
1282 _23 = torch.gt(kH, 0)
1283 else:
1284 _23 = False
1285 if _23:
1286 pass
1287 else:
1288 ops.prim.RaiseException("AssertionError: ")
1289 if torch.gt(dW, 0):
1290 _24 = torch.gt(dH, 0)
1291 else:
1292 _24 = False
1293 if _24:
1294 pass
1295 else:
1296 ops.prim.RaiseException("AssertionError: ")
1297 if torch.gt(dilationH, 0):
1298 _25 = torch.gt(dilationW, 0)
1299 else:
1300 _25 = False
1301 if _25:
1302 pass
1303 else:
1304 ops.prim.RaiseException("AssertionError: ")
1305 if torch.ne(input[1], 0):
1306 valid_dims = torch.ne(input[2], 0)
1307 else:
1308 valid_dims = False
1309 if torch.eq(ndim, 3):
1310 _26 = torch.ne(input[0], 0)
1311 else:
1312 _26 = False
1313 if _26:
1314 _27 = valid_dims
1315 else:
1316 _27 = False
1317 if _27:
1318 _28 = True
1319 else:
1320 if torch.eq(ndim, 4):
1321 _29 = valid_dims
1322 else:
1323 _29 = False
1324 if _29:
1325 _30 = torch.ne(input[3], 0)
1326 else:
1327 _30 = False
1328 _28 = _30
1329 if _28:
1330 pass
1331 else:
1332 ops.prim.RaiseException("AssertionError: ")
1333 if torch.ge(torch.floordiv(kW, 2), padW):
1334 _32 = torch.ge(torch.floordiv(kH, 2), padH)
1335 _31 = _32
1336 else:
1337 _31 = False
1338 if _31:
1339 pass
1340 else:
1341 ops.prim.RaiseException("AssertionError: ")
1342 if torch.ge(outputWidth, 1):
1343 _33 = torch.ge(outputHeight, 1)
1344 else:
1345 _33 = False
1346 if _33:
1347 pass
1348 else:
1349 ops.prim.RaiseException("AssertionError: ")
1350 if torch.eq(torch.len(input), 3):
1351 _34 = [nInputPlane, outputHeight, outputWidth]
1352 out = _34
1353 else:
1354 _35 = [nbatch, nInputPlane, outputHeight, outputWidth]
1355 out = _35
1356 return (out, out)
1357
1358 )=====")
// ---------------------------------------------------------------------------
// Generated TorchScript source (see the file header for the regeneration
// command); the raw string is compiled at runtime, so do not edit it by hand.
//
// t(self): asserts len(self) <= 2 and returns [] for a 0-d shape, [self[0]]
//   for a 1-d shape, and [self[1], self[0]] for a 2-d shape.
// transpose(self, dim0, dim1): wraps each dim against max(len(self), 1),
//   raising "AssertionError: " when a dim is out of range. If the wrapped dims
//   are equal it returns a copy of self; otherwise it returns a copy of self
//   with the sizes at the two wrapped dims swapped.
// ---------------------------------------------------------------------------
1359 + std::string(R"=====(def t(self: List[int]) -> List[int]:
1360 if torch.le(torch.len(self), 2):
1361 pass
1362 else:
1363 ops.prim.RaiseException("AssertionError: ")
1364 self_len = torch.len(self)
1365 if torch.eq(self_len, 0):
1366 _0 = annotate(List[int], [])
1367 else:
1368 if torch.eq(self_len, 1):
1369 _1 = [self[0]]
1370 else:
1371 _1 = [self[1], self[0]]
1372 _0 = _1
1373 return _0
1374
1375 def transpose(self: List[int],
1376 dim0: int,
1377 dim1: int) -> List[int]:
1378 ndims = torch.len(self)
1379 if torch.le(ndims, 0):
1380 dim_post_expr = 1
1381 else:
1382 dim_post_expr = ndims
1383 min = torch.neg(dim_post_expr)
1384 max = torch.sub(dim_post_expr, 1)
1385 if torch.lt(dim0, min):
1386 _0 = True
1387 else:
1388 _0 = torch.gt(dim0, max)
1389 if torch.__not__(_0):
1390 pass
1391 else:
1392 ops.prim.RaiseException("AssertionError: ")
1393 if torch.lt(dim0, 0):
1394 dim00 = torch.add(dim0, dim_post_expr)
1395 else:
1396 dim00 = dim0
1397 if torch.le(ndims, 0):
1398 dim_post_expr0 = 1
1399 else:
1400 dim_post_expr0 = ndims
1401 min0 = torch.neg(dim_post_expr0)
1402 max0 = torch.sub(dim_post_expr0, 1)
1403 if torch.lt(dim1, min0):
1404 _1 = True
1405 else:
1406 _1 = torch.gt(dim1, max0)
1407 if torch.__not__(_1):
1408 pass
1409 else:
1410 ops.prim.RaiseException("AssertionError: ")
1411 if torch.lt(dim1, 0):
1412 dim10 = torch.add(dim1, dim_post_expr0)
1413 else:
1414 dim10 = dim1
1415 if torch.eq(dim00, dim10):
1416 out = annotate(List[int], [])
1417 for _3 in range(torch.len(self)):
1418 elem = self[_3]
1419 _4 = torch.append(out, elem)
1420 _2 = out
1421 else:
1422 out0 = annotate(List[int], [])
1423 for i in range(ndims):
1424 if torch.eq(i, dim00):
1425 _5 = torch.append(out0, self[dim10])
1426 else:
1427 if torch.eq(i, dim10):
1428 _6 = torch.append(out0, self[dim00])
1429 else:
1430 _7 = torch.append(out0, self[i])
1431 _2 = out0
1432 return _2
1433
1434 )=====")
// ---------------------------------------------------------------------------
// Generated TorchScript source (see the file header for the regeneration
// command); the raw string is compiled at runtime, so do not edit it by hand.
//
// conv1d(input, weight, bias, stride, padding, dilation, groups):
//   Checks, each raising "AssertionError: " on failure:
//   - len(weight) == 3 and len(input) == 3;
//   - no negative value in padding or in stride;
//   - weight[0] >= groups and weight[0] % groups == 0;
//   - input[1] == weight[1] * groups;
//   - bias is None or equals [weight[0]];
//   - each padded spatial size input[i] + 2*padding[i-2] is at least
//     dilation[i-2]*(weight[i]-1) + 1.
//   Returns [input[0], weight[0], out_2, ...] where
//     out_d = (input[d] + 2*padding[d-2]
//              - (dilation_*(weight[d]-1) + 1)) // stride[d-2] + 1.
//   The local flag `non_negative` is misnamed: it becomes True when a
//   negative value is seen, and the code asserts it stays False.
// NOTE(review): the validation loop indexes dilation[i-2] unconditionally.
//   Given the rank assertion, the has_dilation == False fallback
//   (dilation_ = 1) therefore appears unreachable; an empty dilation list
//   fails in that loop instead.
// ---------------------------------------------------------------------------
1435 + std::string(R"=====(def conv1d(input: List[int],
1436 weight: List[int],
1437 bias: Optional[List[int]],
1438 stride: List[int],
1439 padding: List[int],
1440 dilation: List[int],
1441 groups: int) -> List[int]:
1442 if torch.eq(torch.len(weight), 3):
1443 pass
1444 else:
1445 ops.prim.RaiseException("AssertionError: ")
1446 if torch.eq(torch.len(input), 3):
1447 pass
1448 else:
1449 ops.prim.RaiseException("AssertionError: ")
1450 k = torch.len(input)
1451 weight_dim = torch.len(weight)
1452 non_negative = False
1453 for _0 in range(torch.len(padding)):
1454 val = padding[_0]
1455 if torch.lt(val, 0):
1456 non_negative0 = True
1457 else:
1458 non_negative0 = non_negative
1459 non_negative = non_negative0
1460 if torch.__not__(non_negative):
1461 pass
1462 else:
1463 ops.prim.RaiseException("AssertionError: ")
1464 non_negative1 = False
1465 for _1 in range(torch.len(stride)):
1466 val0 = stride[_1]
1467 if torch.lt(val0, 0):
1468 non_negative2 = True
1469 else:
1470 non_negative2 = non_negative1
1471 non_negative1 = non_negative2
1472 if torch.__not__(non_negative1):
1473 pass
1474 else:
1475 ops.prim.RaiseException("AssertionError: ")
1476 if torch.eq(weight_dim, k):
1477 pass
1478 else:
1479 ops.prim.RaiseException("AssertionError: ")
1480 if torch.ge(weight[0], groups):
1481 pass
1482 else:
1483 ops.prim.RaiseException("AssertionError: ")
1484 _2 = torch.eq(torch.remainder(weight[0], groups), 0)
1485 if _2:
1486 pass
1487 else:
1488 ops.prim.RaiseException("AssertionError: ")
1489 _3 = torch.eq(input[1], torch.mul(weight[1], groups))
1490 if _3:
1491 pass
1492 else:
1493 ops.prim.RaiseException("AssertionError: ")
1494 if torch.__is__(bias, None):
1495 _4 = True
1496 else:
1497 bias0 = unchecked_cast(List[int], bias)
1498 if torch.eq(torch.len(bias0), 1):
1499 _5 = torch.eq(bias0[0], weight[0])
1500 else:
1501 _5 = False
1502 _4 = _5
1503 if _4:
1504 pass
1505 else:
1506 ops.prim.RaiseException("AssertionError: ")
1507 for _6 in range(torch.__range_length(2, k, 1)):
1508 i = torch.__derive_index(_6, 2, 1)
1509 _7 = input[i]
1510 _8 = torch.mul(padding[torch.sub(i, 2)], 2)
1511 _9 = torch.add(_7, _8)
1512 _10 = torch.mul(dilation[torch.sub(i, 2)], torch.sub(weight[i], 1))
1513 if torch.ge(_9, torch.add(_10, 1)):
1514 pass
1515 else:
1516 ops.prim.RaiseException("AssertionError: ")
1517 has_dilation = torch.gt(torch.len(dilation), 0)
1518 dim = torch.len(input)
1519 output_size = annotate(List[int], [])
1520 _11 = torch.append(output_size, input[0])
1521 _12 = torch.append(output_size, weight[0])
1522 for _13 in range(torch.__range_length(2, dim, 1)):
1523 d = torch.__derive_index(_13, 2, 1)
1524 if has_dilation:
1525 dilation_ = dilation[torch.sub(d, 2)]
1526 else:
1527 dilation_ = 1
1528 _14 = torch.mul(dilation_, torch.sub(weight[d], 1))
1529 kernel = torch.add(_14, 1)
1530 _15 = input[d]
1531 _16 = torch.mul(padding[torch.sub(d, 2)], 2)
1532 _17 = torch.sub(torch.add(_15, _16), kernel)
1533 _18 = torch.floordiv(_17, stride[torch.sub(d, 2)])
1534 _19 = torch.append(output_size, torch.add(_18, 1))
1535 return output_size
1536
1537 )=====")
// ---------------------------------------------------------------------------
// Generated TorchScript source (see the file header for the regeneration
// command); the raw string is compiled at runtime, so do not edit it by hand.
//
// conv2d(input, weight, bias, stride, padding, dilation, groups):
//   Checks, each raising "AssertionError: " on failure:
//   - len(weight) == 4 and len(input) == 4;
//   - no negative value in padding or in stride;
//   - weight[0] >= groups and weight[0] % groups == 0;
//   - input[1] == weight[1] * groups;
//   - bias is None or equals [weight[0]];
//   - each padded spatial size input[i] + 2*padding[i-2] is at least
//     dilation[i-2]*(weight[i]-1) + 1.
//   Returns [input[0], weight[0], out_2, out_3] where
//     out_d = (input[d] + 2*padding[d-2]
//              - (dilation_*(weight[d]-1) + 1)) // stride[d-2] + 1.
//   The local flag `non_negative` is misnamed: it becomes True when a
//   negative value is seen, and the code asserts it stays False.
// NOTE(review): the validation loop indexes dilation[i-2] unconditionally,
//   so the has_dilation == False fallback (dilation_ = 1) appears
//   unreachable here.
// ---------------------------------------------------------------------------
1538 + std::string(R"=====(def conv2d(input: List[int],
1539 weight: List[int],
1540 bias: Optional[List[int]],
1541 stride: List[int],
1542 padding: List[int],
1543 dilation: List[int],
1544 groups: int) -> List[int]:
1545 if torch.eq(torch.len(weight), 4):
1546 pass
1547 else:
1548 ops.prim.RaiseException("AssertionError: ")
1549 if torch.eq(torch.len(input), 4):
1550 pass
1551 else:
1552 ops.prim.RaiseException("AssertionError: ")
1553 k = torch.len(input)
1554 weight_dim = torch.len(weight)
1555 non_negative = False
1556 for _0 in range(torch.len(padding)):
1557 val = padding[_0]
1558 if torch.lt(val, 0):
1559 non_negative0 = True
1560 else:
1561 non_negative0 = non_negative
1562 non_negative = non_negative0
1563 if torch.__not__(non_negative):
1564 pass
1565 else:
1566 ops.prim.RaiseException("AssertionError: ")
1567 non_negative1 = False
1568 for _1 in range(torch.len(stride)):
1569 val0 = stride[_1]
1570 if torch.lt(val0, 0):
1571 non_negative2 = True
1572 else:
1573 non_negative2 = non_negative1
1574 non_negative1 = non_negative2
1575 if torch.__not__(non_negative1):
1576 pass
1577 else:
1578 ops.prim.RaiseException("AssertionError: ")
1579 if torch.eq(weight_dim, k):
1580 pass
1581 else:
1582 ops.prim.RaiseException("AssertionError: ")
1583 if torch.ge(weight[0], groups):
1584 pass
1585 else:
1586 ops.prim.RaiseException("AssertionError: ")
1587 _2 = torch.eq(torch.remainder(weight[0], groups), 0)
1588 if _2:
1589 pass
1590 else:
1591 ops.prim.RaiseException("AssertionError: ")
1592 _3 = torch.eq(input[1], torch.mul(weight[1], groups))
1593 if _3:
1594 pass
1595 else:
1596 ops.prim.RaiseException("AssertionError: ")
1597 if torch.__is__(bias, None):
1598 _4 = True
1599 else:
1600 bias0 = unchecked_cast(List[int], bias)
1601 if torch.eq(torch.len(bias0), 1):
1602 _5 = torch.eq(bias0[0], weight[0])
1603 else:
1604 _5 = False
1605 _4 = _5
1606 if _4:
1607 pass
1608 else:
1609 ops.prim.RaiseException("AssertionError: ")
1610 for _6 in range(torch.__range_length(2, k, 1)):
1611 i = torch.__derive_index(_6, 2, 1)
1612 _7 = input[i]
1613 _8 = torch.mul(padding[torch.sub(i, 2)], 2)
1614 _9 = torch.add(_7, _8)
1615 _10 = torch.mul(dilation[torch.sub(i, 2)], torch.sub(weight[i], 1))
1616 if torch.ge(_9, torch.add(_10, 1)):
1617 pass
1618 else:
1619 ops.prim.RaiseException("AssertionError: ")
1620 has_dilation = torch.gt(torch.len(dilation), 0)
1621 dim = torch.len(input)
1622 output_size = annotate(List[int], [])
1623 _11 = torch.append(output_size, input[0])
1624 _12 = torch.append(output_size, weight[0])
1625 for _13 in range(torch.__range_length(2, dim, 1)):
1626 d = torch.__derive_index(_13, 2, 1)
1627 if has_dilation:
1628 dilation_ = dilation[torch.sub(d, 2)]
1629 else:
1630 dilation_ = 1
1631 _14 = torch.mul(dilation_, torch.sub(weight[d], 1))
1632 kernel = torch.add(_14, 1)
1633 _15 = input[d]
1634 _16 = torch.mul(padding[torch.sub(d, 2)], 2)
1635 _17 = torch.sub(torch.add(_15, _16), kernel)
1636 _18 = torch.floordiv(_17, stride[torch.sub(d, 2)])
1637 _19 = torch.append(output_size, torch.add(_18, 1))
1638 return output_size
1639
1640 )=====")
// ---------------------------------------------------------------------------
// Generated TorchScript source (see the file header for the regeneration
// command); the raw string is compiled at runtime, so do not edit it by hand.
//
// batch_norm(input, weight, bias, running_mean, running_var, training,
//            momentum, eps, cudnn_enabled):
//   Returns a fresh copy of the input shape. None of the other arguments are
//   read, and no validation is performed.
// ---------------------------------------------------------------------------
1641 + std::string(R"=====(def batch_norm(input: List[int],
1642 weight: Optional[List[int]],
1643 bias: Optional[List[int]],
1644 running_mean: Optional[List[int]],
1645 running_var: Optional[List[int]],
1646 training: bool,
1647 momentum: float,
1648 eps: float,
1649 cudnn_enabled: bool) -> List[int]:
1650 out = annotate(List[int], [])
1651 for _0 in range(torch.len(input)):
1652 elem = input[_0]
1653 _1 = torch.append(out, elem)
1654 return out
1655
1656 )=====")
// ---------------------------------------------------------------------------
// Generated TorchScript source (see the file header for the regeneration
// command); the raw string is compiled at runtime, so do not edit it by hand.
//
// conv3d(input, weight, bias, stride, padding, dilation, groups):
//   Checks, each raising "AssertionError: " on failure:
//   - len(weight) == 5 and len(input) == 5;
//   - no negative value in padding or in stride;
//   - weight[0] >= groups and weight[0] % groups == 0;
//   - input[1] == weight[1] * groups;
//   - bias is None or equals [weight[0]];
//   - each padded spatial size input[i] + 2*padding[i-2] is at least
//     dilation[i-2]*(weight[i]-1) + 1.
//   Returns [input[0], weight[0], out_2, out_3, out_4] where
//     out_d = (input[d] + 2*padding[d-2]
//              - (dilation_*(weight[d]-1) + 1)) // stride[d-2] + 1.
//   The local flag `non_negative` is misnamed: it becomes True when a
//   negative value is seen, and the code asserts it stays False.
// NOTE(review): the validation loop indexes dilation[i-2] unconditionally,
//   so the has_dilation == False fallback (dilation_ = 1) appears
//   unreachable here.
// ---------------------------------------------------------------------------
1657 + std::string(R"=====(def conv3d(input: List[int],
1658 weight: List[int],
1659 bias: Optional[List[int]],
1660 stride: List[int],
1661 padding: List[int],
1662 dilation: List[int],
1663 groups: int) -> List[int]:
1664 if torch.eq(torch.len(weight), 5):
1665 pass
1666 else:
1667 ops.prim.RaiseException("AssertionError: ")
1668 if torch.eq(torch.len(input), 5):
1669 pass
1670 else:
1671 ops.prim.RaiseException("AssertionError: ")
1672 k = torch.len(input)
1673 weight_dim = torch.len(weight)
1674 non_negative = False
1675 for _0 in range(torch.len(padding)):
1676 val = padding[_0]
1677 if torch.lt(val, 0):
1678 non_negative0 = True
1679 else:
1680 non_negative0 = non_negative
1681 non_negative = non_negative0
1682 if torch.__not__(non_negative):
1683 pass
1684 else:
1685 ops.prim.RaiseException("AssertionError: ")
1686 non_negative1 = False
1687 for _1 in range(torch.len(stride)):
1688 val0 = stride[_1]
1689 if torch.lt(val0, 0):
1690 non_negative2 = True
1691 else:
1692 non_negative2 = non_negative1
1693 non_negative1 = non_negative2
1694 if torch.__not__(non_negative1):
1695 pass
1696 else:
1697 ops.prim.RaiseException("AssertionError: ")
1698 if torch.eq(weight_dim, k):
1699 pass
1700 else:
1701 ops.prim.RaiseException("AssertionError: ")
1702 if torch.ge(weight[0], groups):
1703 pass
1704 else:
1705 ops.prim.RaiseException("AssertionError: ")
1706 _2 = torch.eq(torch.remainder(weight[0], groups), 0)
1707 if _2:
1708 pass
1709 else:
1710 ops.prim.RaiseException("AssertionError: ")
1711 _3 = torch.eq(input[1], torch.mul(weight[1], groups))
1712 if _3:
1713 pass
1714 else:
1715 ops.prim.RaiseException("AssertionError: ")
1716 if torch.__is__(bias, None):
1717 _4 = True
1718 else:
1719 bias0 = unchecked_cast(List[int], bias)
1720 if torch.eq(torch.len(bias0), 1):
1721 _5 = torch.eq(bias0[0], weight[0])
1722 else:
1723 _5 = False
1724 _4 = _5
1725 if _4:
1726 pass
1727 else:
1728 ops.prim.RaiseException("AssertionError: ")
1729 for _6 in range(torch.__range_length(2, k, 1)):
1730 i = torch.__derive_index(_6, 2, 1)
1731 _7 = input[i]
1732 _8 = torch.mul(padding[torch.sub(i, 2)], 2)
1733 _9 = torch.add(_7, _8)
1734 _10 = torch.mul(dilation[torch.sub(i, 2)], torch.sub(weight[i], 1))
1735 if torch.ge(_9, torch.add(_10, 1)):
1736 pass
1737 else:
1738 ops.prim.RaiseException("AssertionError: ")
1739 has_dilation = torch.gt(torch.len(dilation), 0)
1740 dim = torch.len(input)
1741 output_size = annotate(List[int], [])
1742 _11 = torch.append(output_size, input[0])
1743 _12 = torch.append(output_size, weight[0])
1744 for _13 in range(torch.__range_length(2, dim, 1)):
1745 d = torch.__derive_index(_13, 2, 1)
1746 if has_dilation:
1747 dilation_ = dilation[torch.sub(d, 2)]
1748 else:
1749 dilation_ = 1
1750 _14 = torch.mul(dilation_, torch.sub(weight[d], 1))
1751 kernel = torch.add(_14, 1)
1752 _15 = input[d]
1753 _16 = torch.mul(padding[torch.sub(d, 2)], 2)
1754 _17 = torch.sub(torch.add(_15, _16), kernel)
1755 _18 = torch.floordiv(_17, stride[torch.sub(d, 2)])
1756 _19 = torch.append(output_size, torch.add(_18, 1))
1757 return output_size
1758
1759 )=====")
// ---------------------------------------------------------------------------
// Generated TorchScript source (see the file header for the regeneration
// command); the raw string is compiled at runtime, so do not edit it by hand.
//
// conv_backwards(grad_output, input, weight, biases):
//   Returns a 3-tuple (copy of input, copy of weight, [grad_output[1]]).
//   No validation is performed.
// NOTE(review): `biases` is never inspected, so the third (bias-gradient)
//   shape is [grad_output[1]] even when biases is None.
// ---------------------------------------------------------------------------
1760 + std::string(R"=====(def conv_backwards(grad_output: List[int],
1761 input: List[int],
1762 weight: List[int],
1763 biases: Optional[List[int]]) -> Tuple[List[int], List[int], List[int]]:
1764 out = annotate(List[int], [])
1765 for _0 in range(torch.len(input)):
1766 elem = input[_0]
1767 _1 = torch.append(out, elem)
1768 out0 = annotate(List[int], [])
1769 for _2 in range(torch.len(weight)):
1770 elem0 = weight[_2]
1771 _3 = torch.append(out0, elem0)
1772 return (out, out0, [grad_output[1]])
1773
1774 )=====")
// ---------------------------------------------------------------------------
// Generated TorchScript source (see the file header for the regeneration
// command); the raw string is compiled at runtime, so do not edit it by hand.
//
// conv_forwards(input, weight, bias, stride, padding, dilation, transposed,
//               output_padding, groups):
//   Performs no validation, and `bias` is not read.
//   Output channels are weight[1] * groups when transposed, else weight[0].
//   Returns [input[0], out_channels, out_2, ...] with, per spatial dim d:
//     transposed:     (input[d]-1)*stride[d-2] - 2*padding[d-2]
//                     + dilation_*(weight[d]-1) + output_padding_ + 1
//     not transposed: (input[d] + 2*padding[d-2]
//                     - (dilation_*(weight[d]-1) + 1)) // stride[d-2] + 1
//   dilation_ falls back to 1 when `dilation` is empty; output_padding_
//   falls back to 0 when `output_padding` is empty.
// ---------------------------------------------------------------------------
1775 + std::string(R"=====(def conv_forwards(input: List[int],
1776 weight: List[int],
1777 bias: Optional[List[int]],
1778 stride: List[int],
1779 padding: List[int],
1780 dilation: List[int],
1781 transposed: bool,
1782 output_padding: List[int],
1783 groups: int) -> List[int]:
1784 has_dilation = torch.gt(torch.len(dilation), 0)
1785 has_output_padding = torch.gt(torch.len(output_padding), 0)
1786 dim = torch.len(input)
1787 output_size = annotate(List[int], [])
1788 if transposed:
1789 weight_output_channels_dim = 1
1790 else:
1791 weight_output_channels_dim = 0
1792 _0 = torch.append(output_size, input[0])
1793 if transposed:
1794 _1 = torch.mul(weight[weight_output_channels_dim], groups)
1795 _2 = torch.append(output_size, _1)
1796 else:
1797 _3 = torch.append(output_size, weight[weight_output_channels_dim])
1798 for _4 in range(torch.__range_length(2, dim, 1)):
1799 d = torch.__derive_index(_4, 2, 1)
1800 if has_dilation:
1801 dilation_ = dilation[torch.sub(d, 2)]
1802 else:
1803 dilation_ = 1
1804 if has_output_padding:
1805 output_padding_ = output_padding[torch.sub(d, 2)]
1806 else:
1807 output_padding_ = 0
1808 if transposed:
1809 kernel = torch.mul(dilation_, torch.sub(weight[d], 1))
1810 _5 = torch.mul(torch.sub(input[d], 1), stride[torch.sub(d, 2)])
1811 _6 = torch.mul(padding[torch.sub(d, 2)], 2)
1812 _7 = torch.add(torch.sub(_5, _6), kernel)
1813 _8 = torch.add(torch.add(_7, output_padding_), 1)
1814 _9 = torch.append(output_size, _8)
1815 else:
1816 _10 = torch.mul(dilation_, torch.sub(weight[d], 1))
1817 kernel0 = torch.add(_10, 1)
1818 _11 = input[d]
1819 _12 = torch.mul(padding[torch.sub(d, 2)], 2)
1820 _13 = torch.sub(torch.add(_11, _12), kernel0)
1821 _14 = torch.floordiv(_13, stride[torch.sub(d, 2)])
1822 _15 = torch.append(output_size, torch.add(_14, 1))
1823 return output_size
1824
1825 )=====")
// ---------------------------------------------------------------------------
// Generated TorchScript source (see the file header for the regeneration
// command); the raw string is compiled at runtime, so do not edit it by hand.
//
// _conv_forwards(input, weight, bias, stride, padding, dilation, transposed,
//                output_padding, groups, benchmark, deterministic,
//                cudnn_enabled, allow_tf32):
//   The body is the same computation as conv_forwards. The trailing
//   backend flags and `bias` are not read, and no validation is performed.
//   Output channels are weight[1] * groups when transposed, else weight[0].
//   Per spatial dim d:
//     transposed:     (input[d]-1)*stride[d-2] - 2*padding[d-2]
//                     + dilation_*(weight[d]-1) + output_padding_ + 1
//     not transposed: (input[d] + 2*padding[d-2]
//                     - (dilation_*(weight[d]-1) + 1)) // stride[d-2] + 1
// ---------------------------------------------------------------------------
1826 + std::string(R"=====(def _conv_forwards(input: List[int],
1827 weight: List[int],
1828 bias: Optional[List[int]],
1829 stride: List[int],
1830 padding: List[int],
1831 dilation: List[int],
1832 transposed: bool,
1833 output_padding: List[int],
1834 groups: int,
1835 benchmark: bool,
1836 deterministic: bool,
1837 cudnn_enabled: bool,
1838 allow_tf32: bool) -> List[int]:
1839 has_dilation = torch.gt(torch.len(dilation), 0)
1840 has_output_padding = torch.gt(torch.len(output_padding), 0)
1841 dim = torch.len(input)
1842 output_size = annotate(List[int], [])
1843 if transposed:
1844 weight_output_channels_dim = 1
1845 else:
1846 weight_output_channels_dim = 0
1847 _0 = torch.append(output_size, input[0])
1848 if transposed:
1849 _1 = torch.mul(weight[weight_output_channels_dim], groups)
1850 _2 = torch.append(output_size, _1)
1851 else:
1852 _3 = torch.append(output_size, weight[weight_output_channels_dim])
1853 for _4 in range(torch.__range_length(2, dim, 1)):
1854 d = torch.__derive_index(_4, 2, 1)
1855 if has_dilation:
1856 dilation_ = dilation[torch.sub(d, 2)]
1857 else:
1858 dilation_ = 1
1859 if has_output_padding:
1860 output_padding_ = output_padding[torch.sub(d, 2)]
1861 else:
1862 output_padding_ = 0
1863 if transposed:
1864 kernel = torch.mul(dilation_, torch.sub(weight[d], 1))
1865 _5 = torch.mul(torch.sub(input[d], 1), stride[torch.sub(d, 2)])
1866 _6 = torch.mul(padding[torch.sub(d, 2)], 2)
1867 _7 = torch.add(torch.sub(_5, _6), kernel)
1868 _8 = torch.add(torch.add(_7, output_padding_), 1)
1869 _9 = torch.append(output_size, _8)
1870 else:
1871 _10 = torch.mul(dilation_, torch.sub(weight[d], 1))
1872 kernel0 = torch.add(_10, 1)
1873 _11 = input[d]
1874 _12 = torch.mul(padding[torch.sub(d, 2)], 2)
1875 _13 = torch.sub(torch.add(_11, _12), kernel0)
1876 _14 = torch.floordiv(_13, stride[torch.sub(d, 2)])
1877 _15 = torch.append(output_size, torch.add(_14, 1))
1878 return output_size
1879
1880 )=====")
// ---------------------------------------------------------------------------
// Generated TorchScript source (see the file header for the regeneration
// command); the raw string is compiled at runtime, so do not edit it by hand.
//
// conv_transpose2d_input(input, weight, bias=None, stride=None, padding=None,
//                        output_padding=None, groups=1, dilation=None):
//   None arguments are replaced by defaults: stride=[1, 1], padding=[0, 0],
//   output_padding=[0, 0], dilation=[1, 1]. `bias` is not read.
//   Returns [input[0], weight[1] * groups, out_2, ...] with
//     out_d = (input[d]-1)*stride[d-2] - 2*padding[d-2]
//             + dilation_*(weight[d]-1) + output_padding[d-2] + 1.
// NOTE(review): no rank assertion is made on input. The defaults are
//   2-element lists, so an input with more than two spatial dims would index
//   past them when those arguments are None.
// ---------------------------------------------------------------------------
1881 + std::string(R"=====(def conv_transpose2d_input(input: List[int],
1882 weight: List[int],
1883 bias: Optional[List[int]]=None,
1884 stride: Optional[List[int]]=None,
1885 padding: Optional[List[int]]=None,
1886 output_padding: Optional[List[int]]=None,
1887 groups: int=1,
1888 dilation: Optional[List[int]]=None) -> List[int]:
1889 if torch.__is__(stride, None):
1890 stride0 = [1, 1]
1891 else:
1892 stride0 = unchecked_cast(List[int], stride)
1893 if torch.__is__(padding, None):
1894 padding0 = [0, 0]
1895 else:
1896 padding0 = unchecked_cast(List[int], padding)
1897 if torch.__is__(output_padding, None):
1898 output_padding0 = [0, 0]
1899 else:
1900 output_padding1 = unchecked_cast(List[int], output_padding)
1901 output_padding0 = output_padding1
1902 if torch.__is__(dilation, None):
1903 dilation0 = [1, 1]
1904 else:
1905 dilation0 = unchecked_cast(List[int], dilation)
1906 has_dilation = torch.gt(torch.len(dilation0), 0)
1907 dim = torch.len(input)
1908 output_size = annotate(List[int], [])
1909 _0 = torch.append(output_size, input[0])
1910 _1 = torch.append(output_size, torch.mul(weight[1], groups))
1911 for _2 in range(torch.__range_length(2, dim, 1)):
1912 d = torch.__derive_index(_2, 2, 1)
1913 if has_dilation:
1914 dilation_ = dilation0[torch.sub(d, 2)]
1915 else:
1916 dilation_ = 1
1917 kernel = torch.mul(dilation_, torch.sub(weight[d], 1))
1918 _3 = torch.mul(torch.sub(input[d], 1), stride0[torch.sub(d, 2)])
1919 _4 = torch.mul(padding0[torch.sub(d, 2)], 2)
1920 _5 = torch.add(torch.sub(_3, _4), kernel)
1921 _6 = torch.add(_5, output_padding0[torch.sub(d, 2)])
1922 _7 = torch.append(output_size, torch.add(_6, 1))
1923 return output_size
1924
1925 )=====")
// ---------------------------------------------------------------------------
// Generated TorchScript source (see the file header for the regeneration
// command); the raw string is compiled at runtime, so do not edit it by hand.
//
// flatten(input, start_dim, end_dim):
//   Wraps start_dim and end_dim against max(len(input), 1), raising
//   "AssertionError: " when either is out of range, then asserts
//   start <= end. Results:
//   - 0-d input: [1];
//   - start == end: a copy of input;
//   - otherwise: input[:start] + [prod(input[start..end])] + input[end+1:],
//     where the product covers both endpoints.
// ---------------------------------------------------------------------------
1926 + std::string(R"=====(def flatten(input: List[int],
1927 start_dim: int,
1928 end_dim: int) -> List[int]:
1929 _0 = torch.len(input)
1930 if torch.le(_0, 0):
1931 dim_post_expr = 1
1932 else:
1933 dim_post_expr = _0
1934 min = torch.neg(dim_post_expr)
1935 max = torch.sub(dim_post_expr, 1)
1936 if torch.lt(start_dim, min):
1937 _1 = True
1938 else:
1939 _1 = torch.gt(start_dim, max)
1940 if torch.__not__(_1):
1941 pass
1942 else:
1943 ops.prim.RaiseException("AssertionError: ")
1944 if torch.lt(start_dim, 0):
1945 start_dim0 = torch.add(start_dim, dim_post_expr)
1946 else:
1947 start_dim0 = start_dim
1948 _2 = torch.len(input)
1949 if torch.le(_2, 0):
1950 dim_post_expr0 = 1
1951 else:
1952 dim_post_expr0 = _2
1953 min0 = torch.neg(dim_post_expr0)
1954 max0 = torch.sub(dim_post_expr0, 1)
1955 if torch.lt(end_dim, min0):
1956 _3 = True
1957 else:
1958 _3 = torch.gt(end_dim, max0)
1959 if torch.__not__(_3):
1960 pass
1961 else:
1962 ops.prim.RaiseException("AssertionError: ")
1963 if torch.lt(end_dim, 0):
1964 end_dim0 = torch.add(end_dim, dim_post_expr0)
1965 else:
1966 end_dim0 = end_dim
1967 if torch.le(start_dim0, end_dim0):
1968 pass
1969 else:
1970 ops.prim.RaiseException("AssertionError: ")
1971 if torch.eq(torch.len(input), 0):
1972 _4 = [1]
1973 else:
1974 if torch.eq(start_dim0, end_dim0):
1975 out = annotate(List[int], [])
1976 for _6 in range(torch.len(input)):
1977 elem = input[_6]
1978 _7 = torch.append(out, elem)
1979 _5 = out
1980 else:
1981 _8 = torch.__range_length(start_dim0, torch.add(end_dim0, 1), 1)
1982 slice_numel = 1
1983 for _9 in range(_8):
1984 i = torch.__derive_index(_9, start_dim0, 1)
1985 slice_numel0 = torch.mul(slice_numel, input[i])
1986 slice_numel = slice_numel0
1987 shape = annotate(List[int], [])
1988 for i0 in range(start_dim0):
1989 _10 = torch.append(shape, input[i0])
1990 _11 = torch.append(shape, slice_numel)
1991 _12 = torch.add(end_dim0, 1)
1992 _13 = torch.__range_length(_12, torch.len(input), 1)
1993 for _14 in range(_13):
1994 i1 = torch.__derive_index(_14, _12, 1)
1995 _15 = torch.append(shape, input[i1])
1996 _5 = shape
1997 _4 = _5
1998 return _4
1999
2000 )=====")
// ---------------------------------------------------------------------------
// Generated TorchScript source (see the file header for the regeneration
// command); the raw string is compiled at runtime, so do not edit it by hand.
//
// cat(tensors, dim):
//   1. Asserts every shape has rank > 0.
//   2. Wraps `dim` against the rank of the first shape that is not exactly
//      [0]. If every shape is [0], `dim` is used unwrapped.
//   3. Asserts the list of shapes is non-empty.
//   4. Picks the last shape that is not a 1-d zero-element shape as the
//      reference. Returns [0] if there is none.
//   5. For every non-skipped shape, checks that its rank equals the
//      reference's ("Tensors must have same number of dimensions") and that
//      all non-cat dims match ("Sizes of tensors must match except in
//      dimension"), and sums the sizes along the cat dim.
//   Returns a copy of the reference shape with the cat dim replaced by that
//   sum.
// ---------------------------------------------------------------------------
2001 + std::string(R"=====(def cat(tensors: List[List[int]],
2002 dim: int) -> List[int]:
2003 _0 = "AssertionError: Tensors must have same number of dimensions"
2004 _1 = "AssertionError: Sizes of tensors must match except in dimension"
2005 for _2 in range(torch.len(tensors)):
2006 tensor = tensors[_2]
2007 if torch.gt(torch.len(tensor), 0):
2008 pass
2009 else:
2010 ops.prim.RaiseException("AssertionError: ")
2011 out_dim: Optional[int] = None
2012 for _3 in range(torch.len(tensors)):
2013 size = tensors[_3]
2014 if torch.eq(torch.len(size), 1):
2015 _4 = torch.eq(size[0], 0)
2016 else:
2017 _4 = False
2018 if torch.__not__(_4):
2019 if torch.__is__(out_dim, None):
2020 _5 = torch.len(size)
2021 if torch.le(_5, 0):
2022 dim_post_expr = 1
2023 else:
2024 dim_post_expr = _5
2025 min = torch.neg(dim_post_expr)
2026 max = torch.sub(dim_post_expr, 1)
2027 if torch.lt(dim, min):
2028 _6 = True
2029 else:
2030 _6 = torch.gt(dim, max)
2031 if torch.__not__(_6):
2032 pass
2033 else:
2034 ops.prim.RaiseException("AssertionError: ")
2035 if torch.lt(dim, 0):
2036 out_dim2 = torch.add(dim, dim_post_expr)
2037 else:
2038 out_dim2 = dim
2039 out_dim1 = out_dim2
2040 else:
2041 out_dim1 = unchecked_cast(int, out_dim)
2042 out_dim0 : Optional[int] = out_dim1
2043 else:
2044 out_dim0 = out_dim
2045 out_dim = out_dim0
2046 if torch.__is__(out_dim, None):
2047 dim0 = dim
2048 else:
2049 dim0 = unchecked_cast(int, out_dim)
2050 if torch.gt(torch.len(tensors), 0):
2051 pass
2052 else:
2053 ops.prim.RaiseException("AssertionError: ")
2054 not_skipped_tensor: Optional[List[int]] = None
2055 for _7 in range(torch.len(tensors)):
2056 tensor0 = tensors[_7]
2057 numel = 1
2058 for _8 in range(torch.len(tensor0)):
2059 elem = tensor0[_8]
2060 numel = torch.mul(numel, elem)
2061 if torch.eq(numel, 0):
2062 _9 = torch.eq(torch.len(tensor0), 1)
2063 else:
2064 _9 = False
2065 if torch.__not__(_9):
2066 not_skipped_tensor0 : Optional[List[int]] = tensor0
2067 else:
2068 not_skipped_tensor0 = not_skipped_tensor
2069 not_skipped_tensor = not_skipped_tensor0
2070 _10 = torch.__is__(not_skipped_tensor, None)
2071 if _10:
2072 _11 = [0]
2073 else:
2074 not_skipped_tensor1 = unchecked_cast(List[int], not_skipped_tensor)
2075 cat_dim_size = 0
2076 for i in range(torch.len(tensors)):
2077 tensor1 = tensors[i]
2078 numel0 = 1
2079 for _12 in range(torch.len(tensor1)):
2080 elem0 = tensor1[_12]
2081 numel0 = torch.mul(numel0, elem0)
2082 if torch.eq(numel0, 0):
2083 _13 = torch.eq(torch.len(tensor1), 1)
2084 else:
2085 _13 = False
2086 if torch.__not__(_13):
2087 first_dims = torch.len(not_skipped_tensor1)
2088 second_dims = torch.len(tensor1)
2089 _14 = torch.eq(first_dims, second_dims)
2090 if _14:
2091 pass
2092 else:
2093 ops.prim.RaiseException(_0)
2094 _15 = torch.__range_length(0, first_dims, 1)
2095 for _16 in range(_15):
2096 dim1 = torch.__derive_index(_16, 0, 1)
2097 if torch.ne(dim1, dim0):
2098 _17 = torch.eq(not_skipped_tensor1[dim1], tensor1[dim1])
2099 if _17:
2100 pass
2101 else:
2102 ops.prim.RaiseException(_1)
2103 else:
2104 pass
2105 cat_dim_size1 = torch.add(cat_dim_size, tensor1[dim0])
2106 cat_dim_size0 = cat_dim_size1
2107 else:
2108 cat_dim_size0 = cat_dim_size
2109 cat_dim_size = cat_dim_size0
2110 result_size = annotate(List[int], [])
2111 for _18 in range(torch.len(not_skipped_tensor1)):
2112 elem1 = not_skipped_tensor1[_18]
2113 _19 = torch.append(result_size, elem1)
2114 _20 = torch._set_item(result_size, dim0, cat_dim_size)
2115 _11 = result_size
2116 return _11
2117
2118 )=====")
// ---------------------------------------------------------------------------
// Generated TorchScript source (see the file header for the regeneration
// command); the raw string is compiled at runtime, so do not edit it by hand.
//
// stack(tensors, dim):
//   For each shape, wraps `dim` against len(shape) + 1 (asserting it is in
//   range) and inserts a size-1 dim there, producing `unsqueezed_tensors`.
//   It then runs the same logic as cat(unsqueezed_tensors, dim), inlined:
//   - rank > 0 checks;
//   - cat-dim wrapping from the first shape that is not [0];
//   - choice of the last non-skipped reference shape ([0] result if none);
//   - rank and non-cat-dim consistency checks;
//   - summing the cat-dim sizes into a copy of the reference shape.
// ---------------------------------------------------------------------------
2119 + std::string(R"=====(def stack(tensors: List[List[int]],
2120 dim: int) -> List[int]:
2121 _0 = "AssertionError: Tensors must have same number of dimensions"
2122 _1 = "AssertionError: Sizes of tensors must match except in dimension"
2123 unsqueezed_tensors = annotate(List[List[int]], [])
2124 for _2 in range(torch.len(tensors)):
2125 tensor = tensors[_2]
2126 _3 = torch.add(torch.len(tensor), 1)
2127 if torch.le(_3, 0):
2128 dim_post_expr = 1
2129 else:
2130 dim_post_expr = _3
2131 min = torch.neg(dim_post_expr)
2132 max = torch.sub(dim_post_expr, 1)
2133 if torch.lt(dim, min):
2134 _4 = True
2135 else:
2136 _4 = torch.gt(dim, max)
2137 if torch.__not__(_4):
2138 pass
2139 else:
2140 ops.prim.RaiseException("AssertionError: ")
2141 if torch.lt(dim, 0):
2142 dim0 = torch.add(dim, dim_post_expr)
2143 else:
2144 dim0 = dim
2145 unsqueezed = annotate(List[int], [])
2146 for _5 in range(torch.len(tensor)):
2147 elem = tensor[_5]
2148 _6 = torch.append(unsqueezed, elem)
2149 torch.insert(unsqueezed, dim0, 1)
2150 _7 = torch.append(unsqueezed_tensors, unsqueezed)
2151 for _8 in range(torch.len(unsqueezed_tensors)):
2152 tensor0 = unsqueezed_tensors[_8]
2153 if torch.gt(torch.len(tensor0), 0):
2154 pass
2155 else:
2156 ops.prim.RaiseException("AssertionError: ")
2157 out_dim: Optional[int] = None
2158 for _9 in range(torch.len(unsqueezed_tensors)):
2159 size = unsqueezed_tensors[_9]
2160 if torch.eq(torch.len(size), 1):
2161 _10 = torch.eq(size[0], 0)
2162 else:
2163 _10 = False
2164 if torch.__not__(_10):
2165 if torch.__is__(out_dim, None):
2166 _11 = torch.len(size)
2167 if torch.le(_11, 0):
2168 dim_post_expr0 = 1
2169 else:
2170 dim_post_expr0 = _11
2171 min0 = torch.neg(dim_post_expr0)
2172 max0 = torch.sub(dim_post_expr0, 1)
2173 if torch.lt(dim, min0):
2174 _12 = True
2175 else:
2176 _12 = torch.gt(dim, max0)
2177 if torch.__not__(_12):
2178 pass
2179 else:
2180 ops.prim.RaiseException("AssertionError: ")
2181 if torch.lt(dim, 0):
2182 dim1 = torch.add(dim, dim_post_expr0)
2183 out_dim2 = dim1
2184 else:
2185 out_dim2 = dim
2186 out_dim1 = out_dim2
2187 else:
2188 out_dim1 = unchecked_cast(int, out_dim)
2189 out_dim0 : Optional[int] = out_dim1
2190 else:
2191 out_dim0 = out_dim
2192 out_dim = out_dim0
2193 if torch.__is__(out_dim, None):
2194 dim2 = dim
2195 else:
2196 dim2 = unchecked_cast(int, out_dim)
2197 _13 = torch.gt(torch.len(unsqueezed_tensors), 0)
2198 if _13:
2199 pass
2200 else:
2201 ops.prim.RaiseException("AssertionError: ")
2202 not_skipped_tensor: Optional[List[int]] = None
2203 for _14 in range(torch.len(unsqueezed_tensors)):
2204 tensor1 = unsqueezed_tensors[_14]
2205 numel = 1
2206 for _15 in range(torch.len(tensor1)):
2207 elem0 = tensor1[_15]
2208 numel = torch.mul(numel, elem0)
2209 if torch.eq(numel, 0):
2210 _16 = torch.eq(torch.len(tensor1), 1)
2211 else:
2212 _16 = False
2213 if torch.__not__(_16):
2214 not_skipped_tensor0 : Optional[List[int]] = tensor1
2215 else:
2216 not_skipped_tensor0 = not_skipped_tensor
2217 not_skipped_tensor = not_skipped_tensor0
2218 _17 = torch.__is__(not_skipped_tensor, None)
2219 if _17:
2220 _18 = [0]
2221 else:
2222 not_skipped_tensor1 = unchecked_cast(List[int], not_skipped_tensor)
2223 cat_dim_size = 0
2224 for i in range(torch.len(unsqueezed_tensors)):
2225 tensor2 = unsqueezed_tensors[i]
2226 numel0 = 1
2227 for _19 in range(torch.len(tensor2)):
2228 elem1 = tensor2[_19]
2229 numel0 = torch.mul(numel0, elem1)
2230 if torch.eq(numel0, 0):
2231 _20 = torch.eq(torch.len(tensor2), 1)
2232 else:
2233 _20 = False
2234 if torch.__not__(_20):
2235 first_dims = torch.len(not_skipped_tensor1)
2236 second_dims = torch.len(tensor2)
2237 _21 = torch.eq(first_dims, second_dims)
2238 if _21:
2239 pass
2240 else:
2241 ops.prim.RaiseException(_0)
2242 _22 = torch.__range_length(0, first_dims, 1)
2243 for _23 in range(_22):
2244 dim3 = torch.__derive_index(_23, 0, 1)
2245 if torch.ne(dim3, dim2):
2246 _24 = torch.eq(not_skipped_tensor1[dim3], tensor2[dim3])
2247 if _24:
2248 pass
2249 else:
2250 ops.prim.RaiseException(_1)
2251 else:
2252 pass
2253 cat_dim_size1 = torch.add(cat_dim_size, tensor2[dim2])
2254 cat_dim_size0 = cat_dim_size1
2255 else:
2256 cat_dim_size0 = cat_dim_size
2257 cat_dim_size = cat_dim_size0
2258 result_size = annotate(List[int], [])
2259 for _25 in range(torch.len(not_skipped_tensor1)):
2260 elem2 = not_skipped_tensor1[_25]
2261 _26 = torch.append(result_size, elem2)
2262 _27 = torch._set_item(result_size, dim2, cat_dim_size)
2263 _18 = result_size
2264 return _18
2265
2266 )=====")
// ---------------------------------------------------------------------------
// Generated TorchScript source (see the file header for the regeneration
// command); the raw string is compiled at runtime, so do not edit it by hand.
//
// permute(input, dims):
//   Asserts len(input) == len(dims). Wraps each entry of dims against
//   max(len(dims), 1), raising "AssertionError: " when out of range, and
//   builds newSizes[i] = input[wrapped dims[i]]. It then asserts that all
//   wrapped dims are distinct, using a pairwise comparison that is quadratic
//   in the rank. Returns newSizes.
// ---------------------------------------------------------------------------
2267 + std::string(R"=====(def permute(input: List[int],
2268 dims: List[int]) -> List[int]:
2269 _0 = torch.eq(torch.len(input), torch.len(dims))
2270 if _0:
2271 pass
2272 else:
2273 ops.prim.RaiseException("AssertionError: ")
2274 ndim = torch.len(dims)
2275 seen_dims = annotate(List[int], [])
2276 newSizes = annotate(List[int], [])
2277 for i in range(ndim):
2278 _1 = dims[i]
2279 if torch.le(ndim, 0):
2280 dim_post_expr = 1
2281 else:
2282 dim_post_expr = ndim
2283 min = torch.neg(dim_post_expr)
2284 max = torch.sub(dim_post_expr, 1)
2285 if torch.lt(_1, min):
2286 _2 = True
2287 else:
2288 _2 = torch.gt(_1, max)
2289 if torch.__not__(_2):
2290 pass
2291 else:
2292 ops.prim.RaiseException("AssertionError: ")
2293 if torch.lt(_1, 0):
2294 dim = torch.add(_1, dim_post_expr)
2295 else:
2296 dim = _1
2297 _3 = torch.append(seen_dims, dim)
2298 _4 = torch.append(newSizes, input[dim])
2299 for _5 in range(torch.__range_length(1, ndim, 1)):
2300 i0 = torch.__derive_index(_5, 1, 1)
2301 for j in range(i0):
2302 _6 = torch.ne(seen_dims[i0], seen_dims[j])
2303 if _6:
2304 pass
2305 else:
2306 ops.prim.RaiseException("AssertionError: ")
2307 return newSizes
2308
2309 )=====")
2310 + std::string(R"=====(def movedim(self: List[int],
2311 source: List[int],
2312 destination: List[int]) -> List[int]:
2313 self_dim = torch.len(self)
2314 if torch.le(self_dim, 1):
2315 _0 = self
2316 else:
2317 normalized_src = annotate(List[int], [])
2318 normalized_dst = annotate(List[int], [])
2319 for i in range(torch.len(source)):
2320 _1 = source[i]
2321 if torch.le(self_dim, 0):
2322 dim_post_expr = 1
2323 else:
2324 dim_post_expr = self_dim
2325 min = torch.neg(dim_post_expr)
2326 max = torch.sub(dim_post_expr, 1)
2327 if torch.lt(_1, min):
2328 _2 = True
2329 else:
2330 _2 = torch.gt(_1, max)
2331 if torch.__not__(_2):
2332 pass
2333 else:
2334 ops.prim.RaiseException("AssertionError: ")
2335 if torch.lt(_1, 0):
2336 dim = torch.add(_1, dim_post_expr)
2337 else:
2338 dim = _1
2339 _3 = torch.append(normalized_src, dim)
2340 _4 = destination[i]
2341 if torch.le(self_dim, 0):
2342 dim_post_expr0 = 1
2343 else:
2344 dim_post_expr0 = self_dim
2345 min0 = torch.neg(dim_post_expr0)
2346 max0 = torch.sub(dim_post_expr0, 1)
2347 if torch.lt(_4, min0):
2348 _5 = True
2349 else:
2350 _5 = torch.gt(_4, max0)
2351 if torch.__not__(_5):
2352 pass
2353 else:
2354 ops.prim.RaiseException("AssertionError: ")
2355 if torch.lt(_4, 0):
2356 dim0 = torch.add(_4, dim_post_expr0)
2357 else:
2358 dim0 = _4
2359 _6 = torch.append(normalized_dst, dim0)
2360 order = annotate(List[int], [])
2361 for i0 in range(self_dim):
2362 _7 = torch.append(order, -1)
2363 src_dims = annotate(List[int], [])
2364 for i1 in range(self_dim):
2365 _8 = torch.append(src_dims, i1)
2366 dst_dims = annotate(List[int], [])
2367 for i2 in range(self_dim):
2368 _9 = torch.append(dst_dims, i2)
2369 for i3 in range(torch.len(source)):
2370 _10 = normalized_src[i3]
2371 _11 = torch._set_item(order, normalized_dst[i3], _10)
2372 _12 = torch._set_item(src_dims, normalized_src[i3], -1)
2373 _13 = torch._set_item(dst_dims, normalized_dst[i3], -1)
2374 source_dims = annotate(List[int], [])
2375 destination_dims = annotate(List[int], [])
2376 for _14 in range(torch.len(src_dims)):
2377 ele = src_dims[_14]
2378 if torch.ne(ele, -1):
2379 _15 = torch.append(source_dims, ele)
2380 else:
2381 pass
2382 for _16 in range(torch.len(dst_dims)):
2383 ele0 = dst_dims[_16]
2384 if torch.ne(ele0, -1):
2385 _17 = torch.append(destination_dims, ele0)
2386 else:
2387 pass
2388 rest_dim = torch.sub(self_dim, torch.len(source))
2389 for i4 in range(rest_dim):
2390 _18 = source_dims[i4]
2391 _19 = torch._set_item(order, destination_dims[i4], _18)
2392 _20 = torch.eq(torch.len(self), torch.len(order))
2393 if _20:
2394 pass
2395 else:
2396 ops.prim.RaiseException("AssertionError: ")
2397 ndim = torch.len(order)
2398 seen_dims = annotate(List[int], [])
2399 newSizes = annotate(List[int], [])
2400 for i5 in range(ndim):
2401 _21 = order[i5]
2402 if torch.le(ndim, 0):
2403 dim_post_expr1 = 1
2404 else:
2405 dim_post_expr1 = ndim
2406 min1 = torch.neg(dim_post_expr1)
2407 max1 = torch.sub(dim_post_expr1, 1)
2408 if torch.lt(_21, min1):
2409 _22 = True
2410 else:
2411 _22 = torch.gt(_21, max1)
2412 if torch.__not__(_22):
2413 pass
2414 else:
2415 ops.prim.RaiseException("AssertionError: ")
2416 if torch.lt(_21, 0):
2417 dim1 = torch.add(_21, dim_post_expr1)
2418 else:
2419 dim1 = _21
2420 _23 = torch.append(seen_dims, dim1)
2421 _24 = torch.append(newSizes, self[dim1])
2422 for _25 in range(torch.__range_length(1, ndim, 1)):
2423 i6 = torch.__derive_index(_25, 1, 1)
2424 for j in range(i6):
2425 _26 = torch.ne(seen_dims[i6], seen_dims[j])
2426 if _26:
2427 pass
2428 else:
2429 ops.prim.RaiseException("AssertionError: ")
2430 _0 = newSizes
2431 return _0
2432
2433 )=====")
2434 + std::string(R"=====(def view(self: List[int],
2435 sizes: List[int]) -> List[int]:
2436 _0 = "AssertionError: only one dimension can be inferred"
2437 _1 = "AssertionError: invalid shape dimensions"
2438 numel = 1
2439 for _2 in range(torch.len(self)):
2440 elem = self[_2]
2441 numel = torch.mul(numel, elem)
2442 _3 = uninitialized(int)
2443 newsize = 1
2444 infer_dim: Optional[int] = None
2445 for dim in range(torch.len(sizes)):
2446 if torch.eq(sizes[dim], -1):
2447 if torch.__isnot__(infer_dim, None):
2448 ops.prim.RaiseException(_0)
2449 else:
2450 pass
2451 newsize0, infer_dim0 = newsize, dim
2452 else:
2453 if torch.ge(sizes[dim], 0):
2454 newsize1 = torch.mul(newsize, sizes[dim])
2455 else:
2456 ops.prim.RaiseException(_1)
2457 newsize1 = _3
2458 newsize0, infer_dim0 = newsize1, infer_dim
2459 newsize, infer_dim = newsize0, infer_dim0
2460 if torch.eq(numel, newsize):
2461 _4, infer_dim1 = True, infer_dim
2462 else:
2463 if torch.__isnot__(infer_dim, None):
2464 infer_dim3 = unchecked_cast(int, infer_dim)
2465 _5, infer_dim2 = torch.gt(newsize, 0), infer_dim3
2466 else:
2467 _5, infer_dim2 = False, infer_dim
2468 if _5:
2469 infer_dim5 = unchecked_cast(int, infer_dim2)
2470 _7 = torch.eq(torch.remainder(numel, newsize), 0)
2471 _6, infer_dim4 = _7, infer_dim5
2472 else:
2473 _6, infer_dim4 = False, infer_dim2
2474 _4, infer_dim1 = _6, infer_dim4
2475 if torch.__not__(_4):
2476 ops.prim.RaiseException("AssertionError: invalid shape")
2477 else:
2478 pass
2479 out = annotate(List[int], [])
2480 for _8 in range(torch.len(sizes)):
2481 elem0 = sizes[_8]
2482 _9 = torch.append(out, elem0)
2483 if torch.__isnot__(infer_dim1, None):
2484 infer_dim6 = unchecked_cast(int, infer_dim1)
2485 _10 = torch._set_item(out, infer_dim6, torch.floordiv(numel, newsize))
2486 else:
2487 pass
2488 return out
2489
2490 )=====")
2491 + std::string(R"=====(def expand(self: List[int],
2492 sizes: List[int]) -> List[int]:
2493 _0 = torch.ge(torch.len(sizes), torch.len(self))
2494 if _0:
2495 pass
2496 else:
2497 ops.prim.RaiseException("AssertionError: ")
2498 ndim = torch.len(sizes)
2499 tensor_dim = torch.len(self)
2500 if torch.eq(ndim, 0):
2501 out = annotate(List[int], [])
2502 for _2 in range(torch.len(sizes)):
2503 elem = sizes[_2]
2504 _3 = torch.append(out, elem)
2505 _1 = out
2506 else:
2507 out0 = annotate(List[int], [])
2508 for i in range(ndim):
2509 offset = torch.sub(torch.sub(ndim, 1), i)
2510 dim = torch.sub(torch.sub(tensor_dim, 1), offset)
2511 if torch.ge(dim, 0):
2512 size = self[dim]
2513 else:
2514 size = 1
2515 targetSize = sizes[i]
2516 if torch.eq(targetSize, -1):
2517 if torch.ge(dim, 0):
2518 pass
2519 else:
2520 ops.prim.RaiseException("AssertionError: ")
2521 targetSize0 = size
2522 else:
2523 targetSize0 = targetSize
2524 if torch.ne(size, targetSize0):
2525 if torch.eq(size, 1):
2526 pass
2527 else:
2528 ops.prim.RaiseException("AssertionError: ")
2529 size0 = targetSize0
2530 else:
2531 size0 = size
2532 _4 = torch.append(out0, size0)
2533 _1 = out0
2534 return _1
2535
2536 )=====")
2537 + std::string(R"=====(def expand_one_unused(self: List[int],
2538 sizes: List[int],
2539 inp0: Any) -> List[int]:
2540 _0 = torch.ge(torch.len(sizes), torch.len(self))
2541 if _0:
2542 pass
2543 else:
2544 ops.prim.RaiseException("AssertionError: ")
2545 ndim = torch.len(sizes)
2546 tensor_dim = torch.len(self)
2547 if torch.eq(ndim, 0):
2548 out = annotate(List[int], [])
2549 for _2 in range(torch.len(sizes)):
2550 elem = sizes[_2]
2551 _3 = torch.append(out, elem)
2552 _1 = out
2553 else:
2554 out0 = annotate(List[int], [])
2555 for i in range(ndim):
2556 offset = torch.sub(torch.sub(ndim, 1), i)
2557 dim = torch.sub(torch.sub(tensor_dim, 1), offset)
2558 if torch.ge(dim, 0):
2559 size = self[dim]
2560 else:
2561 size = 1
2562 targetSize = sizes[i]
2563 if torch.eq(targetSize, -1):
2564 if torch.ge(dim, 0):
2565 pass
2566 else:
2567 ops.prim.RaiseException("AssertionError: ")
2568 targetSize0 = size
2569 else:
2570 targetSize0 = targetSize
2571 if torch.ne(size, targetSize0):
2572 if torch.eq(size, 1):
2573 pass
2574 else:
2575 ops.prim.RaiseException("AssertionError: ")
2576 size0 = targetSize0
2577 else:
2578 size0 = size
2579 _4 = torch.append(out0, size0)
2580 _1 = out0
2581 return _1
2582
2583 )=====")
2584 + std::string(R"=====(def sum_mean_dim(self: List[int],
2585 opt_dims: Optional[List[int]],
2586 keep_dim: bool,
2587 dt: Any) -> List[int]:
2588 out = annotate(List[int], [])
2589 if torch.__is__(opt_dims, None):
2590 _0, opt_dims0 = True, opt_dims
2591 else:
2592 opt_dims1 = unchecked_cast(List[int], opt_dims)
2593 _0, opt_dims0 = torch.eq(torch.len(opt_dims1), 0), opt_dims1
2594 if _0:
2595 _1 = torch.len(self)
2596 dims0 = annotate(List[int], [])
2597 for _2 in range(_1):
2598 _3 = torch.append(dims0, _2)
2599 dims = dims0
2600 else:
2601 opt_dims2 = unchecked_cast(List[int], opt_dims0)
2602 dims = opt_dims2
2603 for idx in range(torch.len(self)):
2604 is_mean_dim = False
2605 for _4 in range(torch.len(dims)):
2606 reduce_dim = dims[_4]
2607 _5 = torch.len(self)
2608 if torch.le(_5, 0):
2609 dim_post_expr = 1
2610 else:
2611 dim_post_expr = _5
2612 min = torch.neg(dim_post_expr)
2613 max = torch.sub(dim_post_expr, 1)
2614 if torch.lt(reduce_dim, min):
2615 _6 = True
2616 else:
2617 _6 = torch.gt(reduce_dim, max)
2618 if torch.__not__(_6):
2619 pass
2620 else:
2621 ops.prim.RaiseException("AssertionError: ")
2622 if torch.lt(reduce_dim, 0):
2623 dim0 = torch.add(reduce_dim, dim_post_expr)
2624 dim = dim0
2625 else:
2626 dim = reduce_dim
2627 if torch.eq(idx, dim):
2628 is_mean_dim0 = True
2629 else:
2630 is_mean_dim0 = is_mean_dim
2631 is_mean_dim = is_mean_dim0
2632 if is_mean_dim:
2633 if keep_dim:
2634 _7 = torch.append(out, 1)
2635 else:
2636 pass
2637 else:
2638 _8 = torch.append(out, self[idx])
2639 return out
2640
2641 )=====")
2642 + std::string(R"=====(def max_dim(self: List[int],
2643 dim: int,
2644 keep_dim: bool) -> Tuple[List[int], List[int]]:
2645 _0 = [dim]
2646 out = annotate(List[int], [])
2647 for idx in range(torch.len(self)):
2648 is_mean_dim = False
2649 for _1 in range(torch.len(_0)):
2650 reduce_dim = _0[_1]
2651 _2 = torch.len(self)
2652 if torch.le(_2, 0):
2653 dim_post_expr = 1
2654 else:
2655 dim_post_expr = _2
2656 min = torch.neg(dim_post_expr)
2657 max = torch.sub(dim_post_expr, 1)
2658 if torch.lt(reduce_dim, min):
2659 _3 = True
2660 else:
2661 _3 = torch.gt(reduce_dim, max)
2662 if torch.__not__(_3):
2663 pass
2664 else:
2665 ops.prim.RaiseException("AssertionError: ")
2666 if torch.lt(reduce_dim, 0):
2667 dim1 = torch.add(reduce_dim, dim_post_expr)
2668 dim0 = dim1
2669 else:
2670 dim0 = reduce_dim
2671 if torch.eq(idx, dim0):
2672 is_mean_dim0 = True
2673 else:
2674 is_mean_dim0 = is_mean_dim
2675 is_mean_dim = is_mean_dim0
2676 if is_mean_dim:
2677 if keep_dim:
2678 _4 = torch.append(out, 1)
2679 else:
2680 pass
2681 else:
2682 _5 = torch.append(out, self[idx])
2683 return (out, out)
2684
2685 )=====")
2686 + std::string(R"=====(def addmm(self: List[int],
2687 mat1: List[int],
2688 mat2: List[int],
2689 beta: Any,
2690 alpha: Any) -> List[int]:
2691 _0 = "AssertionError: self must be a matrix"
2692 _1 = "AssertionError: mat2 must be a matrix"
2693 _2 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
2694 if torch.eq(torch.len(mat1), 2):
2695 pass
2696 else:
2697 ops.prim.RaiseException(_0)
2698 if torch.eq(torch.len(mat2), 2):
2699 pass
2700 else:
2701 ops.prim.RaiseException(_1)
2702 if torch.eq(mat1[1], mat2[0]):
2703 pass
2704 else:
2705 ops.prim.RaiseException("AssertionError: ")
2706 _3 = [mat1[0], mat2[1]]
2707 dimsA = torch.len(self)
2708 ndim = ops.prim.max(dimsA, 2)
2709 expandedSizes = annotate(List[int], [])
2710 for i in range(ndim):
2711 offset = torch.sub(torch.sub(ndim, 1), i)
2712 dimA = torch.sub(torch.sub(dimsA, 1), offset)
2713 dimB = torch.sub(1, offset)
2714 if torch.ge(dimA, 0):
2715 sizeA = self[dimA]
2716 else:
2717 sizeA = 1
2718 if torch.ge(dimB, 0):
2719 sizeB = _3[dimB]
2720 else:
2721 sizeB = 1
2722 if torch.ne(sizeA, sizeB):
2723 _4 = torch.ne(sizeA, 1)
2724 else:
2725 _4 = False
2726 if _4:
2727 _5 = torch.ne(sizeB, 1)
2728 else:
2729 _5 = False
2730 if _5:
2731 _6 = torch.add("AssertionError: ", torch.format(_2, sizeA, sizeB, i))
2732 ops.prim.RaiseException(_6)
2733 else:
2734 pass
2735 if torch.eq(sizeA, 1):
2736 _7 = sizeB
2737 else:
2738 _7 = sizeA
2739 _8 = torch.append(expandedSizes, _7)
2740 return expandedSizes
2741
2742 )=====")
2743 + std::string(R"=====(def upsample_nearest2d(input: List[int],
2744 output_size: Optional[List[int]],
2745 scale_factors: Optional[List[float]]) -> List[int]:
2746 _0 = "AssertionError: Either output_size or scale_factors must be presented"
2747 _1 = "AssertionError: Must specify exactly one of output_size and scale_factors"
2748 _2 = uninitialized(Optional[List[float]])
2749 out = annotate(List[int], [])
2750 _3 = torch.append(out, input[0])
2751 _4 = torch.append(out, input[1])
2752 if torch.__is__(scale_factors, None):
2753 _5, scale_factors0 = torch.__is__(output_size, None), scale_factors
2754 else:
2755 scale_factors1 = unchecked_cast(List[float], scale_factors)
2756 _5, scale_factors0 = False, scale_factors1
2757 if _5:
2758 ops.prim.RaiseException(_0)
2759 else:
2760 pass
2761 if torch.__isnot__(output_size, None):
2762 output_size1 = unchecked_cast(List[int], output_size)
2763 if torch.__is__(scale_factors0, None):
2764 scale_factors3 : Optional[List[float]] = scale_factors0
2765 else:
2766 ops.prim.RaiseException(_1)
2767 scale_factors3 = _2
2768 _6 = torch.eq(torch.len(output_size1), 2)
2769 if _6:
2770 pass
2771 else:
2772 ops.prim.RaiseException("AssertionError: ")
2773 _7 = torch.append(out, output_size1[0])
2774 _8 = torch.append(out, output_size1[1])
2775 scale_factors2, output_size0 = scale_factors3, output_size1
2776 else:
2777 scale_factors2, output_size0 = scale_factors0, output_size
2778 if torch.__isnot__(scale_factors2, None):
2779 scale_factors4 = unchecked_cast(List[float], scale_factors2)
2780 if torch.__is__(output_size0, None):
2781 pass
2782 else:
2783 ops.prim.RaiseException(_1)
2784 _9 = torch.eq(torch.len(scale_factors4), 2)
2785 if _9:
2786 pass
2787 else:
2788 ops.prim.RaiseException("AssertionError: ")
2789 _10 = torch.mul(input[2], scale_factors4[0])
2790 _11 = torch.append(out, int(_10))
2791 _12 = torch.mul(input[3], scale_factors4[1])
2792 _13 = torch.append(out, int(_12))
2793 else:
2794 pass
2795 return out
2796
2797 )=====")
2798 + std::string(R"=====(def broadcast(a: List[int],
2799 b: List[int]) -> List[int]:
2800 _0 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
2801 dimsA = torch.len(a)
2802 dimsB = torch.len(b)
2803 ndim = ops.prim.max(dimsA, dimsB)
2804 expandedSizes = annotate(List[int], [])
2805 for i in range(ndim):
2806 offset = torch.sub(torch.sub(ndim, 1), i)
2807 dimA = torch.sub(torch.sub(dimsA, 1), offset)
2808 dimB = torch.sub(torch.sub(dimsB, 1), offset)
2809 if torch.ge(dimA, 0):
2810 sizeA = a[dimA]
2811 else:
2812 sizeA = 1
2813 if torch.ge(dimB, 0):
2814 sizeB = b[dimB]
2815 else:
2816 sizeB = 1
2817 if torch.ne(sizeA, sizeB):
2818 _1 = torch.ne(sizeA, 1)
2819 else:
2820 _1 = False
2821 if _1:
2822 _2 = torch.ne(sizeB, 1)
2823 else:
2824 _2 = False
2825 if _2:
2826 _3 = torch.add("AssertionError: ", torch.format(_0, sizeA, sizeB, i))
2827 ops.prim.RaiseException(_3)
2828 else:
2829 pass
2830 if torch.eq(sizeA, 1):
2831 _4 = sizeB
2832 else:
2833 _4 = sizeA
2834 _5 = torch.append(expandedSizes, _4)
2835 return expandedSizes
2836
2837 )=====")
2838 + std::string(R"=====(def argmax(self: List[int],
2839 dim: Optional[int]=None,
2840 keepdim: bool=False) -> List[int]:
2841 if torch.__is__(dim, None):
2842 _0 = annotate(List[int], [])
2843 else:
2844 dim0 = unchecked_cast(int, dim)
2845 _1 = torch.len(self)
2846 if torch.le(_1, 0):
2847 dim_post_expr = 1
2848 else:
2849 dim_post_expr = _1
2850 min = torch.neg(dim_post_expr)
2851 max = torch.sub(dim_post_expr, 1)
2852 if torch.lt(dim0, min):
2853 _2 = True
2854 else:
2855 _2 = torch.gt(dim0, max)
2856 if torch.__not__(_2):
2857 pass
2858 else:
2859 ops.prim.RaiseException("AssertionError: ")
2860 if torch.lt(dim0, 0):
2861 dim1 = torch.add(dim0, dim_post_expr)
2862 else:
2863 dim1 = dim0
2864 out = annotate(List[int], [])
2865 _3 = [9223372036854775807, torch.len(self)]
2866 for i in range(ops.prim.min(_3)):
2867 self_dim = self[i]
2868 if torch.eq(i, dim1):
2869 if keepdim:
2870 _4 = torch.append(out, 1)
2871 else:
2872 pass
2873 else:
2874 _5 = torch.append(out, self_dim)
2875 _0 = out
2876 return _0
2877
2878 def bmm(self: List[int],
2879 mat2: List[int]) -> List[int]:
2880 _0 = "AssertionError: bmm only supports 3D tensors"
2881 _1 = "AssertionError: mismatching batch dimension"
2882 _2 = "AssertionError: mismatching contracting dimension"
2883 if torch.eq(torch.len(self), 3):
2884 pass
2885 else:
2886 ops.prim.RaiseException(_0)
2887 if torch.eq(torch.len(mat2), 3):
2888 pass
2889 else:
2890 ops.prim.RaiseException(_0)
2891 if torch.eq(self[0], mat2[0]):
2892 pass
2893 else:
2894 ops.prim.RaiseException(_1)
2895 if torch.eq(self[2], mat2[1]):
2896 pass
2897 else:
2898 ops.prim.RaiseException(_2)
2899 return [self[0], self[1], mat2[2]]
2900
2901 def _shape_as_tensor(self: List[int]) -> List[int]:
2902 return [torch.len(self)]
2903
2904 )=====")
2905 + std::string(R"=====(def topk(self: List[int],
2906 k: int,
2907 dim: int=-1) -> Tuple[List[int], List[int]]:
2908 _0 = "k ({}) is too big for dimension {} of size {}"
2909 if torch.eq(torch.len(self), 0):
2910 result = annotate(List[int], [])
2911 else:
2912 if torch.le(k, self[dim]):
2913 pass
2914 else:
2915 _1 = torch.format(_0, k, dim, self[dim])
2916 ops.prim.RaiseException(torch.add("AssertionError: ", _1))
2917 result0 = annotate(List[int], [])
2918 for _2 in range(torch.len(self)):
2919 elem = self[_2]
2920 _3 = torch.append(result0, elem)
2921 _4 = torch._set_item(result0, dim, k)
2922 result = result0
2923 return (result, result)
2924
2925 def nll_loss_forward(self: List[int],
2926 target: List[int],
2927 weight: Optional[List[int]],
2928 reduction: int) -> Tuple[List[int], List[int]]:
2929 self_dim = torch.len(self)
2930 target_dim = torch.len(target)
2931 if torch.lt(0, self_dim):
2932 _0 = torch.le(self_dim, 2)
2933 else:
2934 _0 = False
2935 if _0:
2936 pass
2937 else:
2938 ops.prim.RaiseException("AssertionError: ")
2939 if torch.le(target_dim, 1):
2940 pass
2941 else:
2942 ops.prim.RaiseException("AssertionError: ")
2943 if torch.eq(self_dim, 1):
2944 no_batch_dim = torch.eq(target_dim, 0)
2945 else:
2946 no_batch_dim = False
2947 if no_batch_dim:
2948 _1 = True
2949 else:
2950 _1 = torch.eq(self[0], target[0])
2951 if _1:
2952 pass
2953 else:
2954 ops.prim.RaiseException("AssertionError: ")
2955 n_classes = self[-1]
2956 if torch.__is__(weight, None):
2957 _2 = True
2958 else:
2959 weight0 = unchecked_cast(List[int], weight)
2960 if torch.eq(torch.len(weight0), 1):
2961 _3 = torch.eq(weight0[0], n_classes)
2962 else:
2963 _3 = False
2964 _2 = _3
2965 if _2:
2966 pass
2967 else:
2968 ops.prim.RaiseException("AssertionError: ")
2969 if torch.eq(reduction, 0):
2970 _4 = torch.eq(self_dim, 2)
2971 else:
2972 _4 = False
2973 if _4:
2974 reduction_shape = [self[0]]
2975 else:
2976 reduction_shape = annotate(List[int], [])
2977 _5 = (reduction_shape, annotate(List[int], []))
2978 return _5
2979
2980 )=====")
2981 + std::string(R"=====(def native_layer_norm(input: List[int],
2982 normalized_shape: List[int]) -> Tuple[List[int], List[int], List[int]]:
2983 reduction_shape = annotate(List[int], [])
2984 num_unreduced_dimensions = torch.sub(torch.len(input), torch.len(normalized_shape))
2985 if torch.ge(num_unreduced_dimensions, 0):
2986 pass
2987 else:
2988 ops.prim.RaiseException("AssertionError: ")
2989 for i in range(num_unreduced_dimensions):
2990 _0 = torch.append(reduction_shape, input[i])
2991 _1 = torch.__range_length(num_unreduced_dimensions, torch.len(input), 1)
2992 for _2 in range(_1):
2993 _3 = torch.append(reduction_shape, 1)
2994 out = annotate(List[int], [])
2995 for _4 in range(torch.len(input)):
2996 elem = input[_4]
2997 _5 = torch.append(out, elem)
2998 _6 = (out, reduction_shape, reduction_shape)
2999 return _6
3000
3001 def native_batch_norm(input: List[int],
3002 weight: Optional[List[int]],
3003 bias: Optional[List[int]],
3004 running_mean: Optional[List[int]],
3005 running_var: Optional[List[int]],
3006 training: bool) -> Tuple[List[int], List[int], List[int]]:
3007 if training:
3008 _size = [input[1]]
3009 else:
3010 _size = [0]
3011 out = annotate(List[int], [])
3012 for _0 in range(torch.len(input)):
3013 elem = input[_0]
3014 _1 = torch.append(out, elem)
3015 return (out, _size, _size)
3016
3017 def _batch_norm_with_update(input: List[int],
3018 weight: Optional[List[int]],
3019 bias: Optional[List[int]],
3020 running_mean: Optional[List[int]],
3021 running_var: Optional[List[int]]) -> Tuple[List[int], List[int], List[int], List[int]]:
3022 _size = [input[1]]
3023 out = annotate(List[int], [])
3024 for _0 in range(torch.len(input)):
3025 elem = input[_0]
3026 _1 = torch.append(out, elem)
3027 return (out, _size, _size, [0])
3028
3029 )=====")
3030 + std::string(R"=====(def cross_entropy_loss(self: List[int],
3031 target: List[int],
3032 weight: Optional[List[int]]=None,
3033 reduction: int=1,
3034 ignore_index: int=-100,
3035 label_smoothing: float=0.) -> List[int]:
3036 self_dim = torch.len(self)
3037 target_dim = torch.len(target)
3038 if torch.lt(0, self_dim):
3039 _0 = torch.le(self_dim, 2)
3040 else:
3041 _0 = False
3042 if _0:
3043 pass
3044 else:
3045 ops.prim.RaiseException("AssertionError: ")
3046 if torch.le(target_dim, 1):
3047 pass
3048 else:
3049 ops.prim.RaiseException("AssertionError: ")
3050 if torch.eq(self_dim, 1):
3051 no_batch_dim = torch.eq(target_dim, 0)
3052 else:
3053 no_batch_dim = False
3054 if no_batch_dim:
3055 _1 = True
3056 else:
3057 _1 = torch.eq(self[0], target[0])
3058 if _1:
3059 pass
3060 else:
3061 ops.prim.RaiseException("AssertionError: ")
3062 n_classes = self[-1]
3063 if torch.__is__(weight, None):
3064 _2 = True
3065 else:
3066 weight0 = unchecked_cast(List[int], weight)
3067 if torch.eq(torch.len(weight0), 1):
3068 _3 = torch.eq(weight0[0], n_classes)
3069 else:
3070 _3 = False
3071 _2 = _3
3072 if _2:
3073 pass
3074 else:
3075 ops.prim.RaiseException("AssertionError: ")
3076 if torch.eq(reduction, 0):
3077 _4 = torch.eq(self_dim, 2)
3078 else:
3079 _4 = False
3080 if _4:
3081 reduction_shape = [self[0]]
3082 else:
3083 reduction_shape = annotate(List[int], [])
3084 _5 = (reduction_shape, annotate(List[int], []))
3085 return (_5)[0]
3086
3087 )=====")
3088 + std::string(R"=====(def broadcast_three(a: List[int],
3089 b: List[int],
3090 c: List[int]) -> List[int]:
3091 _0 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
3092 _1 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
3093 dimsA = torch.len(a)
3094 dimsB = torch.len(b)
3095 ndim = ops.prim.max(dimsA, dimsB)
3096 expandedSizes = annotate(List[int], [])
3097 for i in range(ndim):
3098 offset = torch.sub(torch.sub(ndim, 1), i)
3099 dimA = torch.sub(torch.sub(dimsA, 1), offset)
3100 dimB = torch.sub(torch.sub(dimsB, 1), offset)
3101 if torch.ge(dimA, 0):
3102 sizeA = a[dimA]
3103 else:
3104 sizeA = 1
3105 if torch.ge(dimB, 0):
3106 sizeB = b[dimB]
3107 else:
3108 sizeB = 1
3109 if torch.ne(sizeA, sizeB):
3110 _2 = torch.ne(sizeA, 1)
3111 else:
3112 _2 = False
3113 if _2:
3114 _3 = torch.ne(sizeB, 1)
3115 else:
3116 _3 = False
3117 if _3:
3118 _4 = torch.add("AssertionError: ", torch.format(_0, sizeA, sizeB, i))
3119 ops.prim.RaiseException(_4)
3120 else:
3121 pass
3122 if torch.eq(sizeA, 1):
3123 _5 = sizeB
3124 else:
3125 _5 = sizeA
3126 _6 = torch.append(expandedSizes, _5)
3127 dimsA0 = torch.len(expandedSizes)
3128 dimsB0 = torch.len(c)
3129 ndim0 = ops.prim.max(dimsA0, dimsB0)
3130 expandedSizes0 = annotate(List[int], [])
3131 for i0 in range(ndim0):
3132 offset0 = torch.sub(torch.sub(ndim0, 1), i0)
3133 dimA0 = torch.sub(torch.sub(dimsA0, 1), offset0)
3134 dimB0 = torch.sub(torch.sub(dimsB0, 1), offset0)
3135 if torch.ge(dimA0, 0):
3136 sizeA0 = expandedSizes[dimA0]
3137 else:
3138 sizeA0 = 1
3139 if torch.ge(dimB0, 0):
3140 sizeB0 = c[dimB0]
3141 else:
3142 sizeB0 = 1
3143 if torch.ne(sizeA0, sizeB0):
3144 _7 = torch.ne(sizeA0, 1)
3145 else:
3146 _7 = False
3147 if _7:
3148 _8 = torch.ne(sizeB0, 1)
3149 else:
3150 _8 = False
3151 if _8:
3152 _9 = torch.format(_1, sizeA0, sizeB0, i0)
3153 ops.prim.RaiseException(torch.add("AssertionError: ", _9))
3154 else:
3155 pass
3156 if torch.eq(sizeA0, 1):
3157 _10 = sizeB0
3158 else:
3159 _10 = sizeA0
3160 _11 = torch.append(expandedSizes0, _10)
3161 return expandedSizes0
3162
3163 )=====")
3164 + std::string(R"=====(def broadcast_one_three(a: List[int],
3165 b: Any,
3166 c: List[int]) -> List[int]:
3167 _0 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
3168 dimsA = torch.len(a)
3169 dimsB = torch.len(c)
3170 ndim = ops.prim.max(dimsA, dimsB)
3171 expandedSizes = annotate(List[int], [])
3172 for i in range(ndim):
3173 offset = torch.sub(torch.sub(ndim, 1), i)
3174 dimA = torch.sub(torch.sub(dimsA, 1), offset)
3175 dimB = torch.sub(torch.sub(dimsB, 1), offset)
3176 if torch.ge(dimA, 0):
3177 sizeA = a[dimA]
3178 else:
3179 sizeA = 1
3180 if torch.ge(dimB, 0):
3181 sizeB = c[dimB]
3182 else:
3183 sizeB = 1
3184 if torch.ne(sizeA, sizeB):
3185 _1 = torch.ne(sizeA, 1)
3186 else:
3187 _1 = False
3188 if _1:
3189 _2 = torch.ne(sizeB, 1)
3190 else:
3191 _2 = False
3192 if _2:
3193 _3 = torch.add("AssertionError: ", torch.format(_0, sizeA, sizeB, i))
3194 ops.prim.RaiseException(_3)
3195 else:
3196 pass
3197 if torch.eq(sizeA, 1):
3198 _4 = sizeB
3199 else:
3200 _4 = sizeA
3201 _5 = torch.append(expandedSizes, _4)
3202 return expandedSizes
3203
3204 )=====")
3205 + std::string(R"=====(def broadcast_inplace(a: List[int],
3206 b: List[int]) -> List[int]:
3207 _0 = "The dims of tensor b ({}) must be less than or equal tothe dims of tensor a ({}) "
3208 _1 = "The size of tensor a {} must match the size of tensor b ({}) at non-singleton dimension {}"
3209 dimsA = torch.len(a)
3210 dimsB = torch.len(b)
3211 if torch.gt(dimsB, dimsA):
3212 _2 = torch.add("AssertionError: ", torch.format(_0, dimsB, dimsA))
3213 ops.prim.RaiseException(_2)
3214 else:
3215 pass
3216 for dimA in range(dimsA):
3217 dimB = torch.add(torch.sub(dimsB, dimsA), dimA)
3218 sizeA = a[dimA]
3219 if torch.ge(dimB, 0):
3220 sizeB = b[dimB]
3221 else:
3222 sizeB = 1
3223 if torch.ne(sizeA, sizeB):
3224 _3 = torch.ne(sizeB, 1)
3225 else:
3226 _3 = False
3227 if _3:
3228 _4 = torch.format(_1, sizeA, sizeB, dimA)
3229 ops.prim.RaiseException(torch.add("AssertionError: ", _4))
3230 else:
3231 pass
3232 out = annotate(List[int], [])
3233 for _5 in range(torch.len(a)):
3234 elem = a[_5]
3235 _6 = torch.append(out, elem)
3236 return out
3237
3238 def nonzero_lower_bound(input: List[int]) -> List[int]:
3239 return [0, torch.len(input)]
3240
3241 def nonzero_upper_bound(input: List[int]) -> List[int]:
3242 numel = 1
3243 for _0 in range(torch.len(input)):
3244 elem = input[_0]
3245 numel = torch.mul(numel, elem)
3246 return [numel, torch.len(input)]
3247
3248 )=====")
3249 ;
3250
3251
GetSerializedShapeFunctions()3252 const std::string& GetSerializedShapeFunctions() {
3253 return shape_funcs;
3254 }
3255
3256
GetShapeFunctionMappings()3257 const OperatorMap<std::string>& GetShapeFunctionMappings() {
3258 static const OperatorMap<std::string> shape_mappings {
3259 {"aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)", "unary"},
3260 {"aten::rsub.Tensor(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary"},
3261 {"aten::dropout(Tensor input, float p, bool train) -> Tensor", "unary"},
3262 {"aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor", "adaptive_avg_pool2d"},
3263 {"prim::NumToTensor.Scalar(Scalar a) -> Tensor", "zero_dim_tensor"},
3264 {"prim::NumToTensor.bool(bool a) -> Tensor", "zero_dim_tensor"},
3265 {"aten::zeros(int[] size, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", "unary"},
3266 {"aten::to.dtype(Tensor(a) self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))", "unary"},
3267 {"aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)", "arange_end"},
3268 {"aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "arange_start"},
3269 {"aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "arange_start_step"},
3270 {"aten::squeeze(Tensor(a) self) -> Tensor(a)", "squeeze_nodim"},
3271 {"aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", "squeeze"},
3272 {"aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)", "squeeze_dims"},
3273 {"aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", "unsqueeze"},
3274 {"aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)", "slice"},
3275 {"aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)", "select"},
3276 {"aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", "index_select"},
3277 {"aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor", "unary"},
3278 {"aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "unary"},
3279 {"aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor", "unary"},
3280 {"aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)", "unary"},
3281 {"aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor", "embedding"},
3282 {"aten::mm(Tensor self, Tensor mat2) -> Tensor", "mm"},
3283 {"aten::dot(Tensor self, Tensor tensor) -> Tensor", "dot"},
3284 {"aten::mv(Tensor self, Tensor vec) -> Tensor", "mv"},
3285 {"aten::matmul(Tensor self, Tensor other) -> Tensor", "matmul"},
3286 {"aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", "linear"},
3287 {"aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "max_pool2d"},
3288 {"aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", "max_pool2d_with_indices"},
3289 {"aten::t(Tensor(a) self) -> Tensor(a)", "t"},
3290 {"aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", "transpose"},
3291 {"aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor", "conv1d"},
3292 {"aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor", "conv2d"},
3293 {"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", "batch_norm"},
3294 {"aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", "conv3d"},
3295 {"aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, int[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "conv_backwards"},
3296 {"aten::convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor", "conv_forwards"},
3297 {"aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor", "_conv_forwards"},
3298 {"aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor", "conv_transpose2d_input"},
3299 {"aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", "flatten"},
3300 {"aten::cat(Tensor[] tensors, int dim=0) -> Tensor", "cat"},
3301 {"aten::stack(Tensor[] tensors, int dim=0) -> Tensor", "stack"},
3302 {"aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", "permute"},
3303 {"aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", "movedim"},
3304 {"aten::view(Tensor(a) self, int[] size) -> Tensor(a)", "view"},
3305 {"aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", "expand"},
3306 {"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", "expand_one_unused"},
3307 {"aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "sum_mean_dim"},
3308 {"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "sum_mean_dim"},
3309 {"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "max_dim"},
3310 {"aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", "zero_dim_tensor"},
3311 {"aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", "zero_dim_tensor"},
3312 {"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "addmm"},
3313 {"aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)", "upsample_nearest2d"},
3314 {"aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor", "unary"},
3315 {"aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor", "unary"},
3316 {"aten::dequantize(Tensor self) -> Tensor", "unary"},
3317 {"quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc", "broadcast"},
3318 {"aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", "argmax"},
3319 {"aten::bmm(Tensor self, Tensor mat2) -> Tensor", "bmm"},
3320 {"aten::_shape_as_tensor(Tensor self) -> Tensor", "_shape_as_tensor"},
3321 {"aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)", "topk"},
3322 {"aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight)", "nll_loss_forward"},
3323 {"aten::native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)", "native_layer_norm"},
3324 {"aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
3325 {"aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
3326 {"aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "native_batch_norm"},
3327 {"aten::_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)", "_batch_norm_with_update"},
3328 {"aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", "cross_entropy_loss"},
3329 {"aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", "broadcast_three"},
3330 {"aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "broadcast_one_three"},
3331 {"aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", "broadcast_inplace"},
3332 };
3333
3334 return shape_mappings;
3335 }
3336
GetBoundedShapeMappings()3337 const OperatorMap<std::pair<std::string, std::string>>& GetBoundedShapeMappings() {
3338 static const OperatorMap<std::pair<std::string, std::string>> shape_mappings {
3339 {"aten::nonzero(Tensor self) -> (Tensor)", {"nonzero_lower_bound", "nonzero_upper_bound"}},
3340 };
3341
3342 return shape_mappings;
3343 }
3344
3345 // clang-format on
3346
3347 } // namespace torch::jit
3348