xref: /aosp_15_r20/external/pytorch/test/test_dynamic_shapes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import contextlib
4import copy
5import itertools
6import math
7import operator
8import unittest
9
10import numpy as np
11import sympy
12
13import torch
14import torch.fx
15import torch.nn.functional as F
16from torch import sym_int, SymBool, SymFloat, SymInt
17from torch._C import _disabled_torch_function_impl
18from torch.fx.experimental import sym_node
19from torch.fx.experimental.proxy_tensor import make_fx
20from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node
21from torch.fx.experimental.symbolic_shapes import (
22    _constrain_range_for_size,
23    DimConstraints,
24    DimDynamic,
25    expect_true,
26    guard_bool,
27    guard_float,
28    guard_int,
29    GuardOnDataDependentSymNode,
30    hint_int,
31    is_symbolic,
32    ShapeEnv,
33    StatelessSymbolicContext,
34    statically_known_true,
35)
36from torch.testing._internal.common_utils import (
37    instantiate_parametrized_tests,
38    parametrize,
39    run_tests,
40    skipIfTorchDynamo,
41    TestCase,
42)
43from torch.utils import _pytree as pytree
44from torch.utils._python_dispatch import TorchDispatchMode
45from torch.utils._sympy.functions import (
46    FloorDiv,
47    IsNonOverlappingAndDenseIndicator,
48    Mod,
49)
50
51
52aten = torch.ops.aten
53
54meta_funcs = {}
55
56
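# register_meta populates the meta_funcs table consumed by
# FakeSymbolicTensor.__torch_dispatch__ below: each registered function only
# propagates (possibly symbolic) shapes and returns a fresh empty tensor.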
57def register_meta(op):
58    def decorator(f):
59        def add_func(op):
60            meta_funcs[op] = f
61
62        pytree.tree_map_(add_func, op)
63        return f
64
65    return decorator
66
67
68@register_meta([aten.add.Tensor, aten.sub.Tensor])
69def binary_meta(a, b):
70    return a.new_empty(a.shape)
71
72
73@register_meta(aten.cat.default)
74def cat_meta(tensors, dim=0):
75    concat_length = 0
76    shape = tensors[0].shape
77    for tensor in tensors:
78        for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
79            if idx == dim:
80                concat_length = concat_length + length
81            else:
82                assert length == common_length
83    new_shape = list(shape)
84    new_shape[dim] = concat_length
85    return tensors[0].new_empty(new_shape)
86
87
88@register_meta([aten.narrow_copy.default])
89def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
90    shape = []
91    for i, x in enumerate(a.shape):
92        if i == dim:
93            shape.append(length)
94        else:
95            shape.append(x)
96    return a.new_empty(tuple(shape))
97
98
99@register_meta([aten.expand.default])
100def expand_symint_meta(a, size, implicit=False):
101    return a.new_empty(size)
102
103
104def create_contiguous(shape):
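    # Row-major (contiguous) strides: the last dim gets stride 1 and each
    # earlier stride is the product of all later sizes,
    # e.g. shape (5, 4, 3) -> [12, 3, 1].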
105    strides = [1]
106    for dim in reversed(shape[1:]):
107        strides.append(dim * strides[-1])
108    return list(reversed(strides))
109
110
111class FakeSymbolicTensor(torch.Tensor):
112    @staticmethod
113    def __new__(
114        cls,
115        sym_shape,
116        sym_strides,
117        dtype,
118        layout,
119        requires_grad,
120        device,
121        storage_offset=0,
122    ):
123        # TODO: sym_strides is ignored and contiguity is assumed, which is wrong in general
124        sym_stride = create_contiguous(sym_shape)
125        r = torch.Tensor._make_wrapper_subclass(
126            cls,
127            sym_shape,
128            sym_stride,
129            storage_offset,
130            dtype=dtype,
131            layout=layout,
132            requires_grad=requires_grad,
133            device=device,
134        )
135        return r
136
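    # Disable __torch_function__ so ops on this subclass fall straight through
    # to __torch_dispatch__, where meta_funcs supplies shape-only kernels.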
137    __torch_function__ = _disabled_torch_function_impl
138
139    def new_empty(self, shape):
140        return FakeSymbolicTensor(
141            shape, None, self.dtype, self.layout, self.requires_grad, self.device
142        )
143
144    @classmethod
145    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
146        if func_overload in meta_funcs:
147            return meta_funcs[func_overload](*args, **kwargs)
148
149        if func_overload == torch.ops.aten.new_empty.default:
150            self = args[0]
151            shape = args[1]
152            return FakeSymbolicTensor(
153                shape,
154                self.stride(),
155                self.dtype,
156                self.layout,
157                self.requires_grad,
158                self.device,
159            )
160
161        raise RuntimeError(f"operator {func_overload} not supported")
162
163
164def create_symbolic_tensor(name, arg, shape_env, source=None, dynamic_dims=None):
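    # Wrap `arg`'s sizes, strides and storage offset as SymInts drawn from
    # shape_env (DimDynamic.DUCK per dim unless dynamic_dims says otherwise)
    # and return a FakeSymbolicTensor carrying them.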
165    from torch._dynamo.source import ConstantSource
166
167    if source is None:
168        source = ConstantSource(name)
169    constraint_dims = [None] * arg.dim()
170    if dynamic_dims is None:
171        dynamic_dims = [DimDynamic.DUCK] * arg.dim()
172    (
173        sym_shapes,
174        sym_strides,
175        sym_storage_offset,
176    ) = shape_env.create_symbolic_sizes_strides_storage_offset(
177        arg,
178        source=source,
179        symbolic_context=StatelessSymbolicContext(
180            dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims
181        ),
182    )
183    return FakeSymbolicTensor(
184        sym_shapes,
185        sym_strides,
186        arg.dtype,
187        arg.layout,
188        arg.requires_grad,
189        arg.device,
190        sym_storage_offset,
191    )
192
193
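# Register a symbol in shape_env with `val` as its hint and wrap it in the
# requested SymInt/SymFloat/SymBool.  With duck=True, duck sizing may reuse an
# existing symbol that has the same hint.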
194def create_symtype(cls, pytype, shape_env, val, duck=True, **kwargs):
195    from torch._dynamo.source import ConstantSource
196
197    symbol = shape_env.create_symbol(
198        val,
199        source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"),
200        dynamic_dim=DimDynamic.DUCK if duck else DimDynamic.DYNAMIC,
201        constraint_dim=None,
202        **kwargs,
203    )
204    return cls(SymNode(symbol, shape_env, pytype, hint=val))
205
206
207# TODO: default duck to False
208def create_symint(shape_env, i: int, duck=True, **kwargs) -> SymInt:
209    return create_symtype(SymInt, int, shape_env, i, duck=duck, **kwargs)
210
211
212def create_symbool(shape_env, b: bool) -> SymBool:
213    return create_symtype(SymBool, bool, shape_env, b)
214
215
216def create_symfloat(shape_env, f: float) -> SymFloat:
217    return create_symtype(SymFloat, float, shape_env, f)
218
219
220@skipIfTorchDynamo(
221    "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
222)
223class TestPySymInt(TestCase):
224    def test_arith_ops(self):
225        shape_env = ShapeEnv()
226        symints = []
227        for i in range(2, 5):
228            symints.append((i, create_symint(shape_env, i)))
229
230        ops = [
231            operator.add,
232            operator.sub,
233            operator.floordiv,
234            operator.mul,
235            operator.mod,
236        ]
237
238        for op in ops:
239            for args in itertools.permutations(symints, 2):
240                if not isinstance(args[0][1], int) and (
241                    op not in (operator.mod, operator.floordiv) or args[1][0] != 0
242                ):
243                    self.assertTrue(
244                        op(args[0][1], args[1][1]) == op(args[0][0], args[1][0])
245                    )
246
247    def test_reverse_arith_ops(self):
248        shape_env = ShapeEnv()
249
250        a = create_symint(shape_env, 2)
251        self.assertTrue(5 // a == 5 // 2)
252
253        a = create_symint(shape_env, 2)
254        self.assertTrue(5 * a == 5 * 2)
255
256    def test_sympify_symint(self):
257        shape_env = ShapeEnv()
258        a = create_symint(shape_env, 2)
259        self.assertIs(sympy.sympify(a), a.node.expr)
260        b = create_symfloat(shape_env, 3.0)
261        self.assertIs(sympy.sympify(b), b.node.expr)
262        c = create_symbool(shape_env, True)
263        self.assertIs(sympy.sympify(c), c.node.expr)
264
265    def test_roundtrip(self):
266        shape_env = ShapeEnv()
267        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
268
269        self.assertTrue(not isinstance(x.shape[0], SymNode))
270        self.assertTrue(isinstance(x.shape[0], SymInt))
271
272        self.assertTrue(x.shape[0] == 5)
273        self.assertTrue(x.shape[1] == 4)
274        self.assertTrue(x.shape[2] == 3)
275
276        self.assertTrue(x.size()[0] == 5)
277        self.assertTrue(x.size()[1] == 4)
278        # Should be simplifiable to an integer.
279        # Ref: https://github.com/pytorch/pytorch/pull/107492
280        self.assertTrue(isinstance(x.size()[1], SymInt))
281        self.assertTrue(
282            isinstance(x.size()[1].node.maybe_as_int(), int)
283        )  # due to guard above
284        self.assertTrue(x.size()[2] == 3)
285
286        self.assertTrue(x.size(0) == 5)
287        self.assertTrue(x.size(1) == 4)
288        self.assertTrue(x.size(2) == 3)
289        self.assertTrue(isinstance(x.size(2), SymInt))
290        self.assertTrue(isinstance(x.size(2).node.maybe_as_int(), int))
291
292        y = create_symbolic_tensor("y", torch.randn(5, 4, 3)[1:], shape_env)
293        self.assertTrue(isinstance(y.storage_offset(), SymInt))
294        self.assertTrue(y.storage_offset() == 12)
295
296    def test_binary(self):
297        shape_env = ShapeEnv()
298        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
299        y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)
300
301        z = x + y
302        self.assertTrue(z.shape[0] == 5)
303        self.assertTrue(z.shape[1] == 4)
304        self.assertTrue(z.shape[2] == 3)
305
306        # broadcasting
307        y = create_symbolic_tensor("y2", torch.randn(1, 4, 1), shape_env)
308        z = x + y
309        self.assertTrue(z.shape[0] == 5)
310        self.assertTrue(z.shape[1] == 4)
311        self.assertTrue(z.shape[2] == 3)
312
313    def test_symint_args(self):
314        shape_env = ShapeEnv()
315        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
316        y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
317        LAST_DIM = 2
318        z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
319        self.assertTrue(z.shape[2] == y.shape[2])
320
321        # arithmetic expr with two symints
322        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
323        self.assertTrue(z.shape[2] == 2)
324
325        # arithmetic expr with a symint and python int
326        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
327        self.assertTrue(z.shape[2] == 2)
328
329    def test_symint_vargs(self):
330        shape_env = ShapeEnv()
331        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
332        y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
333
334        # varargs
335        z = y.expand(x.shape[0], y.shape[1], x.shape[2])
336        self.assertTrue(z.shape[0] == 5)
337        self.assertTrue(z.shape[1] == 4)
338        self.assertTrue(z.shape[2] == 3)
339
340        # shape list
341        z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
342        self.assertTrue(z.shape[0] == 5)
343        self.assertTrue(z.shape[1] == 4)
344        self.assertTrue(z.shape[2] == 3)
345
346        # mixed python symints and ints
347        z = y.expand(x.shape[0], y.shape[1], 3)
348        self.assertTrue(z.shape[0] == 5)
349        self.assertTrue(z.shape[1] == 4)
350        self.assertTrue(z.shape[2] == 3)
351
352        # mixed python symints and ints in a list
353        z = y.expand((x.shape[0], y.shape[1], 3))
354        self.assertTrue(z.shape[0] == 5)
355        self.assertTrue(z.shape[1] == 4)
356        self.assertTrue(z.shape[2] == 3)
357
358        # mixed python symints and ints
359        z = y.expand(5, y.shape[1], x.shape[2])
360        self.assertTrue(z.shape[0] == 5)
361        self.assertTrue(z.shape[1] == 4)
362        self.assertTrue(z.shape[2] == 3)
363
364        # mixed python ints and symints in a list
365        z = y.expand((5, y.shape[1], x.shape[2]))
366        self.assertTrue(z.shape[0] == 5)
367        self.assertTrue(z.shape[1] == 4)
368        self.assertTrue(z.shape[2] == 3)
369
370        z = y.expand((y.shape[1],))
371        z = y.expand(y.shape[1])
372
373    def test_stride(self):
374        shape_env = ShapeEnv()
375        x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
376        self.assertIsInstance(x.stride()[0], SymInt)
377
378    def test_size_expressions(self):
379        shape_env = ShapeEnv()
380        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
381        expand_x = x.expand(x.shape[0], x.shape[0])
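        # Both branches are deliberately identical; the comparison exists only
        # to record a size guard in shape_env, which is inspected below.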
382        if expand_x.shape[0] > 3:
383            result = expand_x + expand_x
384        else:
385            result = expand_x + expand_x
386
387        gt_op, _bt = shape_env.guards[-1]
388        self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
389        self.assertEqual(str(x.shape[0]), str(gt_op.args[0]))
390        self.assertEqual(str(expand_x.shape[1]), str(x.shape[0]))
391        self.assertEqual(str(expand_x.shape[1]), str(result.shape[0]))
392
393    def test_floordiv_static(self):
394        shape_env = ShapeEnv()
395        s0 = create_symint(shape_env, 8)
396        # This was extracted from
397        # python test/inductor/test_cuda_cpp_wrapper.py -k
398        # DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper
399        bool(s0 % 2 == 0)
400        bool(s0 % (s0 // 2) == 0)
401        bool(2 * (s0 // 2) == s0)
402        self.assertTrue(statically_known_true(s0 // (s0 // 2) == 2))
403
404    def test_numel(self):
405        shape_env = ShapeEnv()
406        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
407        self.assertIsInstance(x.numel(), torch.SymInt)
408        self.assertIsInstance(torch.numel(x), torch.SymInt)
409
410        x = torch.rand(3, 3)
411        self.assertIsInstance(x.numel(), int)
412        self.assertIsInstance(torch.numel(x), int)
413
414    def test_int_to_float(self):
415        shape_env = ShapeEnv()
416        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
417        r = torch.sym_float(x.shape[0])
418        self.assertIsInstance(r, torch.SymFloat, msg=type(r))
419
420    def test_aten_ops(self):
421        shape_env = ShapeEnv()
422        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
423        torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])
424
425        shape_env = ShapeEnv()
426        x = create_symbolic_tensor("x2", torch.randn(5, 4, 3), shape_env)
427        torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])
428
429    def test_fx_trace_intlist(self):
430        class CustomModule(torch.nn.Module):
431            def forward(self, x):
432                bs, c, h, w = x.shape
433                return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))
434
435        m = CustomModule()
436        x = torch.rand(1, 3, 4, 4)
437        # should not raise: TypeError: pad(): argument 'pad' (position 2) must be
438        # tuple of ints, not tuple
439        torch.fx.symbolic_trace(m)
440
441    def test_meta_symint(self):
442        shape_env = ShapeEnv()
443        a0 = create_symint(shape_env, 2)
444        r = torch.empty(a0, device="meta")
445        self.assertIsInstance(r.shape[0], SymInt)
446
447    def test_guard_int(self):
448        shape_env = ShapeEnv()
449        a0 = create_symint(shape_env, 2)
450        self.assertEqual(guard_int(a0), 2)
451        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")
452
453    def test_prefer_deferred_runtime_assertions_over_guards(self):
454        shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
455        s0 = create_symint(shape_env, 2)
456        self.assertEqual(guard_int(s0), 2)
457        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")
458
459        shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
460        s0 = create_symint(shape_env, 2)
461        self.assertTrue(expect_true(s0 == 2))
462        self.assertEqual(len(shape_env.guards), 0)
463        self.assertExpectedInline(
464            str([ra.expr for ra in shape_env.deferred_runtime_asserts[None]]),
465            """[Eq(s0, 2)]""",
466        )
467
468    def test_sym_int(self):
469        shape_env = ShapeEnv()
470        a0 = create_symint(shape_env, 5)
471        r = sym_int(a0)
472        self.assertEqual(r, 5)
473        self.assertIsInstance(r, torch.SymInt, msg=type(r))
474        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 5)""")
475
476        a1 = create_symint(shape_env, 7)
477        r = sym_int(a1 / 2)
478        self.assertEqual(guard_int(r), 3)
479        self.assertIsInstance(r, torch.SymInt, msg=type(r))
480        self.assertExpectedInline(
481            str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)"""
482        )
483
484        a3 = create_symint(shape_env, 3)
485        r = sym_int(2.0 * torch.sym_float(a3))
486        self.assertEqual(guard_int(r), 6)
487        self.assertIsInstance(r, torch.SymInt, msg=type(r))
488        self.assertExpectedInline(
489            str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)"""
490        )
491
492    def test_sym_sqrt(self):
493        shape_env = ShapeEnv()
494        a0 = create_symint(shape_env, 4)
495        r = torch._sym_sqrt(a0)
496        self.assertEqual(r, 2)
497        self.assertIsInstance(r, torch.SymFloat, msg=type(r))
498        self.assertExpectedInline(
499            str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2.0)"""
500        )
501
502    def test_sym_floor(self):
503        shape_env = ShapeEnv()
504        a0 = create_symint(shape_env, 5)
505        r = math.floor(a0 / 2)
506        self.assertEqual(r, 2)
507        self.assertIsInstance(r, torch.SymInt, msg=type(r))
508        self.assertExpectedInline(
509            str(shape_env.guards[0][0]),
510            """Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""",
511        )
512        r = math.floor(3.0 * a0)
513        self.assertEqual(r, 15)
514        self.assertIsInstance(r, torch.SymInt, msg=type(r))
515        self.assertExpectedInline(
516            str(shape_env.guards[1][0]),
517            """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
518        )
519
520    def test_sym_trunc(self):
521        shape_env = ShapeEnv()
522        a0 = create_symint(shape_env, 5)
523        r = math.trunc(a0 / 2)
524        self.assertEqual(r, 2)
525        self.assertIsInstance(r, torch.SymInt, msg=type(r))
526        self.assertExpectedInline(
527            str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)"""
528        )
529        r = torch.sym_int(torch.sym_sqrt(a0))
530        self.assertEqual(r, 2)
531        self.assertIsInstance(r, torch.SymInt, msg=type(r))
532        self.assertExpectedInline(
533            str(shape_env.guards[1][0]), """Eq(TruncToInt(OpaqueUnaryFn_sqrt(s0)), 2)"""
534        )
535
536    def test_sym_ceil(self):
537        shape_env = ShapeEnv()
538        a0 = create_symint(shape_env, 5)
539        r = math.ceil(a0 / 2)
540        self.assertEqual(r, 3)
541        self.assertIsInstance(r, torch.SymInt, msg=type(r))
542        self.assertExpectedInline(
543            str(shape_env.guards[0][0]),
544            """Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""",
545        )
546        r1 = 3.0 * a0
547        r = math.floor(r1)
548        self.assertEqual(r, 15)
549        self.assertIsInstance(r, torch.SymInt, msg=type(r))
550        self.assertExpectedInline(
551            str(shape_env.guards[1][0]),
552            """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
553        )
554
555    def test_sym_ite(self):
556        shape_env = ShapeEnv()
557        t = create_symint(shape_env, 5)
558        f = create_symint(shape_env, 4)
559        b1 = True
560        r1 = torch.sym_ite(b1, t, f)
561        self.assertTrue(r1 is t)
562        b2 = False
563        r2 = torch.sym_ite(b2, t, f)
564        self.assertTrue(r2 is f)
565        b3 = t == 5
566        r3 = torch.sym_ite(b3, t, f)
567        self.assertEqual(len(shape_env.guards), 0)
568        self.assertEqual(r3, 5)
569        self.assertEqual(type(t), type(r3))
570        self.assertExpectedInline(
571            str(shape_env.guards[0][0]),
572            """Eq(Piecewise((s0, Eq(s0, 5)), (s1, True)), 5)""",
573        )
574        b4 = f == 5
575        r4 = torch.sym_ite(b4, t, f)
576        self.assertEqual(len(shape_env.guards), 1)
577        self.assertEqual(r4, 4)
578        self.assertEqual(type(f), type(r4))
579        self.assertExpectedInline(
580            str(shape_env.guards[1][0]),
581            """Eq(Piecewise((s0, Eq(s1, 5)), (s1, True)), 4)""",
582        )
583
584    def test_tracing_sym_ite(self):
585        def f(x):
586            b = x.shape[0] == 5
587            ret = torch.sym_ite(b, x.shape[0], x.shape[1])
588            return ret
589
590        gm = make_fx(f, tracing_mode="symbolic")(torch.ones(4, 5))
591        self.assertEqual(len(gm.shape_env.guards), 0)
592        self.assertExpectedInline(
593            gm.code.strip(),
594            """\
595def forward(self, x_1):
596    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
597    eq = sym_size_int == 5
598    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1);  x_1 = None
599    sym_ite = torch.sym_ite(eq, sym_size_int, sym_size_int_1);  eq = sym_size_int = sym_size_int_1 = None
600    return sym_ite""",
601        )
602        r1 = gm(torch.ones(4, 5))
603        self.assertIsInstance(r1, int)
604        self.assertEqual(r1, 5)
605        r2 = gm(torch.ones(5, 4))
606        self.assertIsInstance(r2, int)
607        self.assertEqual(r2, 5)
608
609    def test_int_conversion(self):
610        shape_env = ShapeEnv()
611        a0 = create_symint(shape_env, 2)
612        int(a0)
613        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")
614
615    def test_data_dependent_guard(self):
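        # An unbacked SymInt has no hint, so converting a guard on it to a
        # concrete bool is data-dependent and must raise.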
616        shape_env = ShapeEnv()
617        s0 = shape_env.create_unbacked_symint()
618        self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))
619
620    def test_data_dependent_guard_propagate_real_tensors(self):
621        shape_env = ShapeEnv()
622        s0 = shape_env.create_unbacked_symint()
623        shape_env.set_unbacked_var_to_val(s0.node.expr, 0)
624        self.assertEqual(bool(s0 == 0), True)
625
626    def test_expect_true_basic(self):
627        shape_env = ShapeEnv()
628        i0 = shape_env.create_unbacked_symint()
629        i0_sym = i0.node.expr
630        # This doesn't error
631        self.assertTrue(expect_true(i0 == 0))
632        # This generates a deferred runtime assert via replacement
633        self.assertEqual(shape_env.replacements[i0_sym], 0)
634        # After expecting true, guards now resolve given the runtime assert
635        bool(i0 == 0)
636
637    def test_expect_true_with_s0(self):
638        shape_env = ShapeEnv()
639        s0 = create_symint(shape_env, 5)
640        i0 = shape_env.create_unbacked_symint()
641        self.assertTrue(expect_true(i0 < s0))
642        self.assertExpectedInline(
643            str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]),
644            """[u0 < s0]""",
645        )
646        self.assertTrue(i0 < s0)
647        self.assertTrue(i0 != s0)
648        self.assertFalse(i0 > s0)
649        self.assertFalse(i0 >= s0)
650
651    def test_expect_true_prefer_later(self):
652        shape_env = ShapeEnv()
653        i0 = shape_env.create_unbacked_symint()
654        i1 = shape_env.create_unbacked_symint()
655        i1_sym = i1.node.expr
656        self.assertTrue(expect_true(i0 + i1 == 10))
657        # Importantly, this is put in i1, not i0!
658        self.assertExpectedInline(
659            str([ra.expr for ra in shape_env.deferred_runtime_asserts[i1_sym]]),
660            """[Eq(u0 + u1, 10)]""",
661        )
662        self.assertTrue(i0 + i1 == 10)
663        # NB: We currently don't support deriving that we can substitute
664        # i0 + i1 with 10; maybe we should, but this means our rewriting
665        # system is no longer confluent (it's probably OK though, because
666        # you're unlikely to get other equalities like this on the
667        # unbacked SymInts.)
668
669    def test_unbacked_substitution(self):
670        shape_env = ShapeEnv()
671        i0 = shape_env.create_unbacked_symint()
672        i1 = shape_env.create_unbacked_symint()
673        _constrain_range_for_size(i0)
674        _constrain_range_for_size(i1)
675        self.assertTrue(expect_true(i0 == i1 * 4))
676        self.assertExpectedInline(str(i0), """u0""")
677
678        i2 = shape_env.create_unbacked_symint()
679        i3 = shape_env.create_unbacked_symint()
680        _constrain_range_for_size(i2)
681        _constrain_range_for_size(i3)
682        self.assertTrue(expect_true(i2 * 4 == i3))
683        self.assertExpectedInline(str(i3), """u3""")
684
685    def test_avoid_unbacked_substitution(self):
686        shape_env = ShapeEnv()
687        i0 = shape_env.create_unbacked_symint()
688        _constrain_range_for_size(i0)
689        i1 = shape_env.create_unbacked_symint()
690        _constrain_range_for_size(i1)
691        self.assertTrue(expect_true(i0 == 10 - i1))
692        self.assertExpectedInline(str(i0), """u0""")
693
694    def test_expect_true_double_digits(self):
695        shape_env = ShapeEnv()
696        ia = [shape_env.create_unbacked_symint() for _ in range(11)]  # allocates u0..u10
697        self.assertEqual(str(ia[-1]), "u10")
698        self.assertTrue(expect_true(sum(ia) == 20))
699        self.assertEqual(len(shape_env.deferred_runtime_asserts[ia[-1].node.expr]), 1)
700
701    def test_expect_true_refine_range(self):
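        # expect_true(i0 > 4) (and its equivalents) refines the unbacked
        # symbol's range to [5, oo), so comparisons against 3 and 4 become
        # statically known while 5 and 6 stay undetermined.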
702        shape_env = ShapeEnv()
703        for i, rel in enumerate(
704            [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
705        ):
706            with self.subTest(f"i = {i}"):
707                i0 = shape_env.create_unbacked_symint()
708                self.assertTrue(expect_true(rel(i0)))
709                self.assertTrue(statically_known_true(i0 != 3))
710                self.assertTrue(statically_known_true(i0 != 4))
711                self.assertFalse(statically_known_true(i0 != 5))
712                self.assertFalse(statically_known_true(i0 != 6))
713                self.assertTrue(statically_known_true(i0 > 4))
714                self.assertTrue(statically_known_true(i0 >= 5))
715
716        for i, rel in enumerate(
717            [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
718        ):
719            with self.subTest(f"i = {i}"):
720                i0 = shape_env.create_unbacked_symint()
721                self.assertTrue(expect_true(rel(i0)))
722                self.assertFalse(statically_known_true(i0 != 2))
723                self.assertFalse(statically_known_true(i0 != 3))
724                self.assertTrue(statically_known_true(i0 != 4))
725                self.assertTrue(statically_known_true(i0 != 5))
726                self.assertTrue(statically_known_true(i0 < 4))
727                self.assertTrue(statically_known_true(i0 <= 5))
728
729    def test_guard_refine_range(self):
730        shape_env = ShapeEnv()
731        for i, rel in enumerate(
732            [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
733        ):
734            with self.subTest(f"i = {i}"):
735                i0 = create_symint(shape_env, 10, duck=False)
736                self.assertTrue(bool(rel(i0)))
737                self.assertTrue(statically_known_true(i0 != 3))
738                self.assertTrue(statically_known_true(i0 != 4))
739                self.assertFalse(statically_known_true(i0 != 5))
740                self.assertFalse(statically_known_true(i0 != 6))
741                self.assertTrue(statically_known_true(i0 > 4))
742                self.assertTrue(statically_known_true(i0 >= 5))
743
744        for i, rel in enumerate(
745            [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
746        ):
747            with self.subTest(f"i = {i}"):
748                i0 = create_symint(shape_env, 2, duck=False)
749                self.assertFalse(bool(rel(i0)))
750                self.assertFalse(statically_known_true(i0 != 3))
751                self.assertFalse(statically_known_true(i0 != 4))
752                self.assertTrue(statically_known_true(i0 != 5))
753                self.assertTrue(statically_known_true(i0 != 6))
754                self.assertTrue(statically_known_true(i0 <= 4))
755                self.assertTrue(statically_known_true(i0 < 5))
756
757        for i, rel in enumerate(
758            [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
759        ):
760            with self.subTest(f"i = {i}"):
761                i0 = create_symint(shape_env, 2, duck=False)
762                self.assertTrue(bool(rel(i0)))
763                self.assertFalse(statically_known_true(i0 != 2))
764                self.assertFalse(statically_known_true(i0 != 3))
765                self.assertTrue(statically_known_true(i0 != 4))
766                self.assertTrue(statically_known_true(i0 != 5))
767                self.assertTrue(statically_known_true(i0 < 4))
768                self.assertTrue(statically_known_true(i0 <= 3))
769
770        for i, rel in enumerate(
771            [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
772        ):
773            with self.subTest(f"i = {i}"):
774                i0 = create_symint(shape_env, 10, duck=False)
775                self.assertFalse(bool(rel(i0)))
776                self.assertTrue(statically_known_true(i0 != 2))
777                self.assertTrue(statically_known_true(i0 != 3))
778                self.assertFalse(statically_known_true(i0 != 4))
779                self.assertFalse(statically_known_true(i0 != 5))
780                self.assertTrue(statically_known_true(i0 >= 4))
781                self.assertTrue(statically_known_true(i0 > 3))
782
783    def test_mul_int_oo_nan(self):
784        shape_env = ShapeEnv()
785        s0 = create_symint(shape_env, 5, duck=False)
786        s1 = create_symint(shape_env, 6, duck=False)
787        s2 = create_symint(shape_env, 5, duck=False)
788        bool(s0 * (s1 // s0) == s2)
789
790    def test_non_overlapping_and_dense(self):
791        shape_env = ShapeEnv()
792        a0 = create_symint(shape_env, 5)
793        r = torch.empty_strided((a0, 7), (1, a0), device="meta")
794        self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r))
795
796    def test_non_overlapping_and_dense_unbacked(self):
797        shape_env = ShapeEnv()
798        u0 = shape_env.create_unbacked_symint()
799        torch._check_is_size(u0)
800        cf = torch.ops.aten.is_non_overlapping_and_dense.default
801
802        self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 2, 2, 1), 1)
803        self.assertEqual(IsNonOverlappingAndDenseIndicator(2, u0.node.expr, 1, 2), 1)
804        self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")))
805        self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta")))
806
807        self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 1), 1)
808        self.assertEqual(IsNonOverlappingAndDenseIndicator(1, u0.node.expr), 1)
809        self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta")))
810        self.assertTrue(cf(torch.empty_strided((1,), (u0,), device="meta")))
811
812        Max = torch.sym_max
813        # NB: This only works because we're able to determine this tensor is
814        # contiguous. transpose(0, 1) makes it stop working
815        self.assertTrue(
816            cf(
817                torch.empty_strided(
818                    (2, 3, 1, u0),
819                    (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
820                    device="meta",
821                )
822            )
823        )
824
825    def test_numpy_sym_max(self):
826        self.assertEqual(torch.sym_max(np.int64(10), 12), 12)
827        self.assertEqual(torch.sym_max(np.int64(12), 10), 12)
828        self.assertEqual(torch.sym_max(np.int64(10), 12.5), 12.5)
829        self.assertEqual(torch.sym_max(np.int64(14), 12.5), 14.0)
830        self.assertEqual(torch.sym_max(np.float64(14.0), 12), 14.0)
831        self.assertEqual(torch.sym_max(np.float64(14.0), 16), 16.0)
832
833    def test_numpy_sym_min(self):
834        self.assertEqual(torch.sym_min(np.int64(10), 12), 10)
835        self.assertEqual(torch.sym_min(np.int64(12), 10), 10)
836        self.assertEqual(torch.sym_min(np.int64(10), 12.5), 10.0)
837        self.assertEqual(torch.sym_min(np.int64(14), 12.5), 12.5)
838        self.assertEqual(torch.sym_min(np.float64(14.0), 12), 12.0)
839        self.assertEqual(torch.sym_min(np.float64(14.0), 16), 14.0)
840
841    def test_debug_has_internal_overlap_unbacked(self):
842        shape_env = ShapeEnv()
843        u0 = shape_env.create_unbacked_symint()
844        torch._check_is_size(u0)
845        cf = torch._debug_has_internal_overlap
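        # Per the assertions below, 0 means "no internal overlap" and 2 means
        # "cannot tell statically"; unbacked sizes should still resolve to 0
        # for these dense layouts.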
846        self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0)
847        self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0)
848        self.assertEqual(cf(torch.empty_strided((u0,), (1,), device="meta")), 0)
849        self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 0)
850        Max = torch.sym_max
851        self.assertEqual(
852            cf(
853                torch.empty_strided(
854                    (2, 3, 1, u0),
855                    (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
856                    device="meta",
857                )
858            ),
859            0,
860        )
861
862        # Wobbling these to zero is OK too
863        self.assertEqual(cf(torch.empty_strided((u0, 2), (3, 1), device="meta")), 2)
864        self.assertEqual(cf(torch.empty_strided((2, u0), (1, 3), device="meta")), 2)
865
866    def test_specialize_zero_one(self):
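        # With specialize_zero_one=True, a size hinted at 5 is given a value
        # range that excludes 0 and 1, so `a0 != 1` is statically known; with
        # it disabled, answering the question records a guard.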
867        shape_env = ShapeEnv(specialize_zero_one=True)
868        a0 = create_symint(shape_env, 5)
869        assert a0 != 1
870        self.assertEqual(len(shape_env.guards), 0)
871
872        shape_env = ShapeEnv(specialize_zero_one=False)
873        a0 = create_symint(shape_env, 5)
874        assert a0 != 1
875        self.assertEqual(len(shape_env.guards), 1)
876
877    def test_duck_shape(self):
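        # With duck_shape=True, two sizes with the same hint (5) share one
        # symbol, so a0 == a1 needs no guard; with duck_shape=False each size
        # gets its own symbol and the equality must be guarded.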
878        shape_env = ShapeEnv(duck_shape=True)
879        a0 = create_symint(shape_env, 5)
880        a1 = create_symint(shape_env, 5)
881        assert a0 == a1
882        self.assertEqual(len(shape_env.guards), 0)
883
884        shape_env = ShapeEnv(duck_shape=False)
885        a0 = create_symint(shape_env, 5)
886        a1 = create_symint(shape_env, 5)
887        assert a0 == a1
888        self.assertEqual(len(shape_env.guards), 1)
889
890    def test_int_bool(self):
891        # See https://github.com/pytorch/pytorch/issues/95981
892        shape_env = ShapeEnv(duck_shape=True)
893        a0 = create_symint(shape_env, 5)
894        assert a0
895        self.assertEqual(len(shape_env.guards), 0)
896
897    def test_symint_as_scalar(self):
898        shape_env = ShapeEnv()
899        a0 = create_symint(shape_env, 2)
900
901        sym_int_encountered = False
902
903        class TestSymInt(TorchDispatchMode):
904            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
905                assert func == torch.ops.aten.add.Tensor
906
907                nonlocal sym_int_encountered
908                # WARNING: do not do identity tests on the outer
909                # SymInt/SymFloat, they are NOT STABLE
910                sym_int_encountered = kwargs["alpha"].node is a0.node
911                kwargs["alpha"] = 0
912                return func(*args)
913
914        x = torch.rand([4, 4])
915        with TestSymInt():
916            y = torch.add(x, x, alpha=a0)
917
918        self.assertTrue(sym_int_encountered)
919
920    def test_deepcopy(self):
921        shape_env = ShapeEnv()
922        a0 = create_symint(shape_env, 2)
923        assert a0 < 4
924        new_shape_env = copy.deepcopy(shape_env)
925        self.assertEqual(len(new_shape_env.guards), 1)
926
927    def test_print_readable_with_symints(self):
928        def f(a, b):
929            dim0 = a.shape[0] + b.shape[0]
930            dim1 = a.shape[1] + b.shape[1]
931            d = a.new_empty(dim0, dim1)
932            d = torch.ops.aten.native_dropout(d, 0.5, train=True)
933            return d
934
935        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
936        out = fx_g.print_readable(print_output=False)
937
938        self.assertExpectedInline(
939            out.strip(),
940            """\
941class f(torch.nn.Module):
942    def forward(self, a_1: "f32[s0, s1]", b_1: "f32[s2, s1]"):
943        # No stacktrace found for following nodes
944        sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(a_1, 0)
945        sym_size_int_1: "Sym(s2)" = torch.ops.aten.sym_size.int(b_1, 0)
946        add: "Sym(s0 + s2)" = sym_size_int + sym_size_int_1;  sym_size_int = sym_size_int_1 = None
947        sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(a_1, 1)
948        sym_size_int_3: "Sym(s1)" = torch.ops.aten.sym_size.int(b_1, 1);  b_1 = None
949        add_1: "Sym(2*s1)" = sym_size_int_2 + sym_size_int_3;  sym_size_int_2 = sym_size_int_3 = None
950        new_empty: "f32[s0 + s2, 2*s1]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False);  a_1 = add = add_1 = None
951        native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True);  new_empty = None
952        getitem: "f32[s0 + s2, 2*s1]" = native_dropout[0]
953        getitem_1: "b8[s0 + s2, 2*s1]" = native_dropout[1];  native_dropout = None
954        return (getitem, getitem_1)""",  # noqa: B950
955        )
956
957    def test_statically_known_true(self):
958        shape_env = ShapeEnv()
959        s2, s3, s4 = (create_symint(shape_env, i) for i in range(2, 5))
960
961        # Statically known true
962        self.assertTrue(statically_known_true(True))
963        self.assertTrue(statically_known_true(s2 == s2))
964        self.assertTrue(statically_known_true(s2 * s3 > s3))
965        self.assertTrue(statically_known_true(s3 * s4 > s4))
966        self.assertTrue(statically_known_true((s3 + s3) % 2 == 0))
967
968        # Statically known false
969        self.assertFalse(statically_known_true(False))
970        self.assertFalse(statically_known_true(s3 * s4 <= s4))
971        self.assertFalse(statically_known_true((s3 + s3) % 2 == 1))
972
973        # True for hints, but not known statically
974        self.assertFalse(statically_known_true(s2 + s2 == s4))
975        self.assertFalse(statically_known_true(s4 % s2 == 0))
976        self.assertFalse(statically_known_true(s2 != s3))
977        self.assertFalse(statically_known_true(s3 * s4 > s2))
978
979        # False for hints, but not known statically
980        self.assertFalse(statically_known_true(s2 == s3))
981        self.assertFalse(statically_known_true(s2 > s3))
982        self.assertFalse(statically_known_true(s3 + s3 == s4))
983
984        # No guards should be generated
985        self.assertEqual(len(shape_env.guards), 0)
986
987    def test_ephemeral_source_simplification(self):
988        from torch._dynamo.source import EphemeralSource
989
990        # For full robustness, ensure the ephemeral source symbols are simplified out regardless
991        # of construction order or check order.
992        for construct_ephemeral_first, x_first_in_check in itertools.product(
993            [False, True], [False, True]
994        ):
995            shape_env = ShapeEnv()
996            shape = (5, 10)
997            dynamic_dims = [DimDynamic.DYNAMIC for _ in shape]
998            x = create_symbolic_tensor(
999                "x",
1000                torch.randn(*shape),
1001                shape_env,
1002                source=(EphemeralSource() if construct_ephemeral_first else None),
1003                dynamic_dims=dynamic_dims,
1004            )
1005            y = create_symbolic_tensor(
1006                "y",
1007                torch.randn(*shape),
1008                shape_env,
1009                source=(EphemeralSource() if not construct_ephemeral_first else None),
1010                dynamic_dims=dynamic_dims,
1011            )
1012            t_with_ephemeral = x if construct_ephemeral_first else y
1013
1014            def _get_ephemeral_source_symbols(t):
1015                return [
1016                    s.node.expr
1017                    for s in itertools.chain(t.shape, t.stride(), (t.storage_offset(),))
1018                    if isinstance(s, torch.SymInt)
1019                    and s.node.expr in shape_env.var_to_sources
1020                    and any(
1021                        source.is_ephemeral()
1022                        for source in shape_env.var_to_sources[s.node.expr]
1023                    )
1024                ]
1025
1026            # these checks should simplify out the ephemeral symbols, regardless of the
1027            # ordering x == y or y == x
1028            self.assertTrue(len(_get_ephemeral_source_symbols(t_with_ephemeral)) > 0)
1029            if x_first_in_check:
1030                torch._check(x.size() == y.size())
1031                torch._check(x.stride() == y.stride())
1032                torch._check(x.storage_offset() == y.storage_offset())
1033            else:
1034                torch._check(y.size() == x.size())
1035                torch._check(y.stride() == x.stride())
1036                torch._check(y.storage_offset() == x.storage_offset())
1037            self.assertEqual(len(_get_ephemeral_source_symbols(t_with_ephemeral)), 0)
1038
1039    def test_ephemeral_source_unified_with_non_ephemeral_source(self):
1040        from torch._dynamo.source import EphemeralSource
1041
1042        for construct_ephemeral_first in (False, True):
1043            shape_env = ShapeEnv()
1044            shape = (5, 10)
1045            # use duck sizing here to ensure symbol reuse across x and y
1046            duck_dims = [DimDynamic.DUCK for _ in shape]
1047            x = create_symbolic_tensor(
1048                "x",
1049                torch.randn(*shape),
1050                shape_env,
1051                source=(EphemeralSource() if construct_ephemeral_first else None),
1052                dynamic_dims=duck_dims,
1053            )
1054            y = create_symbolic_tensor(
1055                "y",
1056                torch.randn(*shape),
1057                shape_env,
1058                source=(EphemeralSource() if not construct_ephemeral_first else None),
1059                dynamic_dims=duck_dims,
1060            )
1061
1062            # regardless of construction order, non-ephemeral sources should be preferred
1063            # first in the var_to_sources list for potential guarding later on
1064            for source_list in shape_env.var_to_sources.values():
1065                self.assertFalse(source_list[0].is_ephemeral())
1066
1067            self.assertEqual(x.size(), y.size())
1068            self.assertEqual(x.stride(), y.stride())
1069            self.assertEqual(x.storage_offset(), y.storage_offset())
1070
1071
1072@skipIfTorchDynamo(
1073    "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
1074)
1075class TestSymNumberMagicMethods(TestCase):
1076    def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
1077        with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn):
1078            return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn)
1079
1080    def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn):
1081        # Compute a reference result with plain numbers, then re-check with symbolic args
1082        # NB: don't use one as that will get specialized
1083        # TODO: We don't have to circuitously create the float, can just
1084        # create a symfloat directly
1085        seed_node = (create_symint(shape_env, 2) / 2.0).node
1086        bool_seed_node = (create_symint(shape_env, 2) == 2).node
1087
1088        def get_sym_inp(inp):
1089            # NB: this must come before int
1090            if isinstance(inp, bool):
1091                return torch.SymBool(to_node(bool_seed_node, inp))
1092            elif isinstance(inp, int):
1093                return torch.SymInt(to_node(seed_node, inp))
1094            else:
1095                return torch.SymFloat(to_node(seed_node, inp))
1096
1097        if fn == "float_pow":
1098            if inp1 < 0:
1099                return
1100
1101        if fn == "pow_by_natural":
1102            if isinstance(inp1, float) or isinstance(inp2, float):
1103                return
1104            if inp2 < 0:
1105                return
1106
1107        def maybe_xfail(inp1, inp2):
1108            if fn == "sym_sqrt" and inp1 < 0:
1109                # ValueError: math domain error
1110                return self.assertRaises((ValueError,))
1111            elif (
1112                fn in ("float_truediv", "int_truediv", "int_floordiv", "mod")
1113                and inp2 == 0
1114            ):
1115                # ZeroDivisionError: division by zero
1116                return self.assertRaises((ZeroDivisionError,))
1117            elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0:
1118                # ZeroDivisionError: 0.0 cannot be raised to a negative power
1119                return self.assertRaises((ZeroDivisionError,))
1120            elif (
1121                # TODO: this branch never fires: `type(x) is (A, B)` compares
1122                # against a tuple object, so these checks are always False
1123                fn in ["float_pow", "pow_by_natural"]
1124                and inp1 < 0
1125                and (
1126                    type(inp1) is (SymInt, SymFloat) or type(inp2) is (SymInt, SymFloat)
1127                )
1128                and (type(inp1) is (SymFloat, float) or type(inp2) is (SymFloat, float))
1129            ):
1130                # Complex result, which we do not support:
1131                # TypeError: Cannot convert complex to float
1132                return self.assertRaises((RuntimeError,))
1133            elif fn in ("lshift", "rshift") and not (
1134                isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int))
1135            ):
1136                # TypeError: unsupported operand type(s)
1137                return self.assertRaises((TypeError,))
1138            elif fn in ("lshift", "rshift") and inp2 < 0:
1139                # ValueError: math domain error
1140                return self.assertRaises((ValueError,))
1141            else:
1142                return contextlib.nullcontext()
1143
1144        lambda_apply = method_to_operator(fn)
1145
1146        def guard_fn(v):
1147            if type(v) in (SymBool, bool):
1148                return guard_bool(v)
1149            elif type(v) in (SymFloat, float):
1150                return guard_float(v)
1151            else:  # SymInt, int
1152                return guard_int(v)
1153
1154        # Get reference result
1155        with maybe_xfail(inp1, inp2):
1156            if is_unary_fn:
1157                ref_out = lambda_apply(inp1)
1158            else:
1159                ref_out = lambda_apply(inp1, inp2)
1160
1161        # Symified first arg
1162        sym_inp1 = get_sym_inp(inp1)
1163        with maybe_xfail(sym_inp1, inp2):
1164            if is_unary_fn:
1165                out = lambda_apply(sym_inp1)
1166            else:
1167                out = lambda_apply(sym_inp1, inp2)
1168            if fn not in sym_node.alternate_impl_if_hinted_methods:
1169                self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
1170            out = guard_fn(out)
1171            self.assertEqual(out, ref_out)
1172
1173        if is_unary_fn:
1174            return
1175
1176        # Symified second arg
1177        sym_inp2 = get_sym_inp(inp2)
1178        with maybe_xfail(inp1, sym_inp2):
1179            out = lambda_apply(inp1, sym_inp2)
1180            if fn not in sym_node.alternate_impl_if_hinted_methods:
1181                self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
1182            out = guard_fn(out)
1183            self.assertEqual(out, ref_out)
1184
1185        # Symified both args
1186        with maybe_xfail(sym_inp1, sym_inp2):
1187            out = lambda_apply(sym_inp1, sym_inp2)
1188            if fn not in sym_node.alternate_impl_if_hinted_methods:
1189                self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
1190            out = guard_fn(out)
1191            self.assertEqual(out, ref_out)
1192
1193    @parametrize("fn", list(sym_node.magic_methods.keys()))
1194    def test_bool_method(self, fn):
1195        # sym_ite has its own tests
1196        if fn not in sym_node.bool_magic_methods or fn == "sym_ite":
1197            self.skipTest(f"{fn} is non-bool")
1198
1199        is_unary_fn = fn in sym_node.unary_methods
1200        shape_env = ShapeEnv()
1201        self._do_test(fn, True, False, shape_env, is_unary_fn)
1202
1203    @parametrize("fn", list(sym_node.magic_methods.keys()))
1204    @parametrize("first_type", ["int", "float"])
1205    @parametrize("second_type", ["int", "float"])
1206    def test_method(self, fn, first_type, second_type):
1207        if first_type == "float":
1208            # TODO: Hmm, this looks like we skip all floats
1209            self.skipTest(f"{fn} is not a float magic method")
1210
1211        if (
1212            first_type == "int" or second_type == "int"
1213        ) and fn in sym_node.only_float_magic_methods:
1214            self.skipTest(f"{fn} is not an int method")
1215
1216        if second_type == "float" and fn in ["mod"]:
1217            self.skipTest(f"{fn} only handles int")
1218
1219        is_unary_fn = fn in sym_node.unary_methods or fn == "round"
1220        # The second argument is ignored for unary functions, so only run for one type
1221        if is_unary_fn and second_type == "float":
1222            self.skipTest(f"{fn} is unary and already tested")
1223
1224        if fn in sym_node.bool_magic_methods:
1225            self.skipTest(f"{fn} is bool")
1226
1227        # Only floats here since these will be converted to int if necessary.
1228        # We also ignore complex and bool.
1229        values = (
1230            0.0,
1231            1.0,
1232            0.5 if fn in ("sym_acos", "sym_asin") else 2.5,  # avoid math domain error
1233        )
1234
1235        neg_values = tuple(-x for x in values)
1236
1237        for inp1, inp2 in itertools.chain(
1238            itertools.product(values, values),
1239            itertools.product(values, neg_values),
1240            itertools.product(neg_values, values),
1241            itertools.product(neg_values, neg_values),
1242        ):
1243            if first_type == "int":
1244                inp1 = int(inp1)
1245            if second_type == "int":
1246                inp2 = int(inp2)
1247
1248            shape_env = ShapeEnv()
1249
1250            self._do_test(fn, inp1, inp2, shape_env, is_unary_fn)
1251
1252    def get_constant_bool(self, val):
1253        return SymBool(torch._C._get_constant_bool_symnode(val))
1254
1255    @unittest.expectedFailure
1256    def test_symint_hashing(self):
1257        shape_env = ShapeEnv()
1258        hash(create_symint(shape_env, 3))
1259
1260    def test_symnode_hashing(self):
1261        shape_env = ShapeEnv()
1262
1263        # These all trigger specialization when hashed
1264        hash(create_symbool(shape_env, True))
1265        # We should be passing in float here, but create_symbol currently
1266        # only supports int
1267        hash(create_symfloat(shape_env, 3.0))
1268
1269        # NestedInt (SymInt), constant SymBool, SymNode are hashable
1270        j1 = torch._C._get_nested_int(1, 1)
1271        j1_copy = torch._C._get_nested_int(1, 1)
1272        j2 = torch._C._get_nested_int(2, 1)
1273        t = self.get_constant_bool(True)
1274        t_copy = self.get_constant_bool(True)
1275        f = self.get_constant_bool(False)
1276        n = create_symint(shape_env, 3).node
1277        m = self.get_constant_bool(True).node
1278
1279        self.assertIs(j1 == j1_copy, True)
1280        self.assertEqual(hash(j1), hash(j1_copy))
1281        self.assertIs(j1 == j2, False)
1282        self.assertNotEqual(hash(j1), hash(j2))
1283        self.assertIs(t == t_copy, True)
1284        self.assertEqual(hash(t), hash(t_copy))
1285        self.assertIs(t == f, False)
1286        self.assertNotEqual(hash(t), hash(f))
1287
1288        hash(n)
1289        hash(m)
1290
1291    def test_symint_deepcopy(self):
1292        shape_env = ShapeEnv()
1293
1294        symnodes = (torch._C._get_nested_int(1, 1),)
1295        deepcopied_symnodes = copy.deepcopy(symnodes)
1296        self.assertEqual(symnodes, deepcopied_symnodes)
1297
1298    def test_non_symbolic_symnode(self):
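        # Nested ints are SymInts backed by a NestedIntSymNode rather than a
        # ShapeEnv symbol: equality is decided by the nested int id, arithmetic
        # is unsupported, and ordering against plain ints is indeterminate.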
1299        j1 = torch._C._get_nested_int(1, 1)
1300        j2 = torch._C._get_nested_int(1, 1)
1301        j3 = torch._C._get_nested_int(3, 1)
1302
1303        self.assertIsInstance(j1, torch.SymInt)
1304        self.assertNotIsInstance(j1, int)
1305
1306        with self.assertRaisesRegex(
1307            RuntimeError, "add not supported by NestedIntSymNode"
1308        ):
1309            j1 + 3
1310
1311        self.assertFalse(j1 == 3)
1312        with self.assertRaisesRegex(RuntimeError, "indeterminate"):
1313            self.assertFalse(3 >= j2)
1314
1315        self.assertIs(j1 == j1, True)
1316        self.assertIs(j1 == j2, True)
1317        self.assertIs(j1 == j3, False)
1318        self.assertIs(j1 != j3, True)
1319        self.assertIs(j1 != j2, False)
1320
1321        x = self.get_constant_bool(True)
1322        #
1323        # Unary
1324        #
1325        # op(constant SymBool)
1326        self.assertIs(x.__sym_not__(), False)
1327
1328        #
1329        # Binary
1330        #
1331        # op(constant SymBool, bool)
1332        # op(constant SymBool, constant SymBool)
1333        # op(bool, constant SymBool)
1334        self.assertIs(operator.and_(x, True), True)
1335        self.assertIs(operator.and_(x, x), True)
1336        self.assertIs(operator.and_(True, x), True)
1337
1338        # op(symbolic SymBool, constant Symbool)
1339        # op(constant SymBool, symbolic Symbool)
1340        shape_env = ShapeEnv()
1341        a = create_symint(shape_env, 2)
1342        b = create_symint(shape_env, 2)
1343        c = a == b  # symbolic SymBool
1344        d = self.get_constant_bool(True)
1345        e = operator.and_(c, d)
1346        f = operator.and_(d, c)
1347        self.assertTrue(is_symbolic(e))
1348        self.assertTrue(is_symbolic(f))
1349        self.assertIs(e.node.guard_bool("", 0), True)
1350        self.assertIs(f.node.guard_bool("", 0), True)
1351
1352        # Comparing sizes
1353        sz1 = torch.Size([j1, j1, j1])
1354        sz2 = torch.Size([j1, j1, j1])
1355        self.assertIs(sz1 == sz2, True)
1356
1357        sz1 = torch.Size([3, j1, 4])
1358        sz2 = torch.Size([3, j2, 4])
1359        self.assertIs(sz1 == sz2, True)
1360        self.assertIs(sz1 != sz2, False)
1361
1362    def test_stride_symnode(self):
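        # Exercises dynamic_strides: INFER_STRIDE derives strides from the
        # (possibly symbolic) sizes, while DYNAMIC allocates stride symbols
        # that are independent of the size symbols.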
1363        from torch._subclasses.fake_tensor import FakeTensorMode
1364
1365        shape_env = ShapeEnv()
1366
1367        def _create_symbolic_tensor(x, dynamic_sizes, dynamic_strides):
1368            with FakeTensorMode(shape_env=shape_env) as fake_mode:
1369                return fake_mode.from_tensor(
1370                    x,
1371                    symbolic_context=StatelessSymbolicContext(
1372                        dynamic_sizes=dynamic_sizes,
1373                        dynamic_strides=dynamic_strides,
1374                    ),
1375                )
1376
1377        # check everything static
1378        t = _create_symbolic_tensor(
1379            x=torch.ones(3, 6),
1380            dynamic_sizes=[
1381                DimDynamic.STATIC,
1382                DimDynamic.STATIC,
1383            ],
1384            dynamic_strides=[
1385                DimDynamic.INFER_STRIDE,
1386                DimDynamic.INFER_STRIDE,
1387            ],
1388        )
1389        self.assertTrue(all(isinstance(size, int) for size in t.size()))
1390        self.assertTrue(all(isinstance(stride, int) for stride in t.stride()))
1391
1392        # check dynamic sizes with inferred (contiguous) strides
1393        t = _create_symbolic_tensor(
1394            x=torch.ones(3, 6),
1395            dynamic_sizes=[
1396                DimDynamic.DYNAMIC,
1397                DimDynamic.DYNAMIC,
1398            ],
1399            dynamic_strides=[
1400                DimDynamic.INFER_STRIDE,
1401                DimDynamic.INFER_STRIDE,
1402            ],
1403        )
1404        # Expect strides to be inferred from the symbolic sizes: stride == (size[1], 1)
1405        s0, s1 = t.size()
1406        s2, s3 = t.stride()
1407        self.assertIsInstance(s0, torch.SymInt)
1408        self.assertIsInstance(s1, torch.SymInt)
1409        self.assertIsInstance(s2, torch.SymInt)
1410        self.assertTrue(s1 == s2)
1411        self.assertEqual(s3, 1)
1412
1413        # check dynamic stride but static sizes; only stride[0] becomes symbolic
1414        t = _create_symbolic_tensor(
1415            x=torch.ones(3, 6),
1416            dynamic_sizes=[
1417                DimDynamic.STATIC,
1418                DimDynamic.STATIC,
1419            ],
1420            dynamic_strides=[
1421                DimDynamic.DYNAMIC,
1422                DimDynamic.INFER_STRIDE,
1423            ],
1424        )
1425        s0, s1 = t.size()
1426        s2, s3 = t.stride()
1427        self.assertIsInstance(s0, int)
1428        self.assertIsInstance(s1, int)
1429        self.assertIsInstance(s2, torch.SymInt)
1430        self.assertIsInstance(s3, int)
1431
1432        # check dynamic sizes and strides, and ensure size and stride get different symbols
1433        t = _create_symbolic_tensor(
1434            x=torch.ones(3, 6),
1435            dynamic_sizes=[
1436                DimDynamic.DYNAMIC,
1437                DimDynamic.DYNAMIC,
1438            ],
1439            dynamic_strides=[
1440                DimDynamic.DYNAMIC,
1441                DimDynamic.INFER_STRIDE,
1442            ],
1443        )
1444        s0, s1 = t.size()
1445        s2, s3 = t.stride()
1446        self.assertIsInstance(s0, torch.SymInt)
1447        self.assertIsInstance(s1, torch.SymInt)
1448        self.assertIsInstance(s2, torch.SymInt)
1449        self.assertIsInstance(s3, int)
1450        self.assertTrue(str(s1.node.expr) != str(s2.node.expr))
1451
1452
1453instantiate_parametrized_tests(TestSymNumberMagicMethods)
1454
1455
1456class TestFloorDiv(TestCase):
1457    @staticmethod
1458    def python_floordiv(x, y):
1459        return x // y
1460
1461    @staticmethod
1462    def torch_floordiv(x, y):
1463        # Note: we fully evaluate here since FloorDiv might not always do
1464        # that.
1465        shape_env = ShapeEnv()
1466        return shape_env.evaluate_expr(FloorDiv(x, y))
1467
1468    @staticmethod
1469    def yield_test_cases(values, negate=True):
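            # e.g. yield_test_cases(((7, 2),)) produces (7, 2), (-7, 2), (7, -2), (-7, -2)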
1470        for x, y in values:
1471            yield (x, y)
1472            if negate:
1473                yield (-x, y)
1474                yield (x, -y)
1475                yield (-x, -y)
1476
1477    def test_floordiv_float_int(self):
1478        values = ((7, 2),)
1479
1480        for x, y in TestFloorDiv.yield_test_cases(values):
1481            self.assertEqual(
1482                TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
1483            )
1484
1485    def test_floordiv_div_by_one(self):
1486        values = ((2, 1),)
1487
1488        for x, y in TestFloorDiv.yield_test_cases(values):
1489            self.assertEqual(
1490                TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
1491            )
1492
1493    def test_floordiv_simplify(self):
1494        # Tests how we simplify or evaluate FloorDiv without free variables
1495        shape_env = ShapeEnv()
1496        result = 21
1497        exprs = (7 * FloorDiv(6, 2),)
1498
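            # FloorDiv(6, 2) == 3, so every simplification/evaluation path below
            # should produce 7 * 3 == 21.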
1499        for expr in exprs:
1500            self.assertEqual(expr, result)
1501            self.assertEqual(expr.doit(deep=False), result)
1502            self.assertEqual(expr.doit(deep=True), result)
1503            self.assertEqual(sympy.simplify(expr), result)
1504            self.assertEqual(shape_env.simplify(expr), result)
1505            self.assertEqual(shape_env.evaluate_expr(expr), result)
1506
1507    def test_floordiv_assumptions(self):
1508        cases = (
1509            sympy.Symbol("i1", integer=True),
1510            sympy.Symbol("i2", integer=True),
1511        )
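            # Both sample symbols are integers, so the complex-rejection branch
            # below is not exercised here; it documents the TypeError the test
            # expects FloorDiv to raise for complex operands.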
1512
1513        for base, divisor in itertools.product(cases, repeat=2):
1514
1515            def op():
1516                return FloorDiv(base, divisor)
1517
1518            def is_complex(x):
1519                return x.is_integer is False and x.is_real is False and x.is_complex
1520
1521            if is_complex(base) or is_complex(divisor):
1522                self.assertRaisesRegex(
1523                    TypeError,
1524                    (
1525                        r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol',"
1526                        r" expected integer or real"
1527                    ),
1528                    op,
1529                )
1530                continue
1531
1532            op = op()
1533
1534            # In regular Python, x//x == 1.0 if x is a float, but FloorDiv
1535            # always returns an integer 1 when both args are the same object.
1536            # This even works for Symbols with no assumptions specified.
1537            if base is divisor:
1538                self.assertTrue(op.is_integer)
1539                self.assertTrue(op.is_real)
1540            elif base.is_integer and divisor.is_integer:
1541                self.assertTrue(op.is_integer)
1542                self.assertTrue(op.is_real)
1543            else:
1544                self.assertEqual(op.is_integer, None)
1545                self.assertTrue(op.is_real)
1546
1547
1548class TestDimConstraints(TestCase):
1549    def test_dim_constraints_reduce_congruences_simple(self):
1550        from sympy import Symbol
1551
1552        s = Symbol("s", positive=True, integer=True)
1553        dim_constraints = DimConstraints({}, {}, set(), {})
1554        dim_constraints._congruences[s] = {
1555            (s / 2) % 2,
1556            (s / 2) % 8,
1557            (s / 2) % 4,
1558            s % 2,
1559            ((s / 16) + 2) % 4,
1560        }
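            # _reduce_congruences should fold the individual modular conditions
            # into a single congruence: (s + 32) % 64 == 0, i.e. s = 32, 96, 160, ...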
1561        congruences = dim_constraints._reduce_congruences()
1562        self.assertEqual(congruences[s], {(s + 32) % 64})
1563
1564    def test_dim_constraints_reduce_inequalities_simple(self):
1565        from sympy import Eq, Interval, Ne, Symbol
1566        from sympy.solvers.inequalities import reduce_inequalities
1567
1568        s = Symbol("s", positive=True, integer=True)
1569        exprs = {
1570            s >= 2,
1571            Ne(8 * s, 16),
1572            Ne(s / 2, 1),
1573            Ne(16 * s, 32),
1574            s < 16,
1575            Ne(s, 2),
1576            s / 2 < 16,
1577            s / 2 > 1,
1578            s / 2 >= 2,
1579            Ne(3 * s / 2, 3),
1580        }
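            # Taken together these inequalities should reduce to 4 <= s < 16;
            # adding Eq(s / 2, 4) below then pins s to 8.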
1581        solution = reduce_inequalities(exprs, s).as_set()
1582        self.assertEqual(solution, Interval.Ropen(4, 16))
1583
1584        exprs.add(Eq(s / 2, 4))
1585        solution = reduce_inequalities(exprs, s).as_set()
1586        self.assertEqual(solution, {8})
1587
1588    def test_dim_constraints_reduce_inequalities_error(self):
1589        from collections import defaultdict
1590
1591        from sympy import Symbol
1592        from sympy.solvers.inequalities import reduce_inequalities
1593
1594        from torch._dynamo.source import (
1595            LocalSource,
1596            TensorProperty,
1597            TensorPropertySource,
1598        )
1599        from torch.fx.experimental.symbolic_shapes import DynamicDimConstraintPrinter
1600
1601        s0 = Symbol("s0", positive=True, integer=True)
1602        exprs = {
1603            4 * s0**3 - 4 * s0**2 + s0 <= 2147483647,
1604            s0 >= 2,
1605            s0**3 <= 2147483647,
1606            s0 <= 2147483647,
1607        }
1608        answer = reduce_inequalities(exprs, s0)
1609
1610        symbol_to_source = defaultdict(list)
1611        symbol_to_source[s0].append(
1612            TensorPropertySource(
1613                base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
1614            )
1615        )
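            # reduce_inequalities on the cubic is expected to introduce auxiliary
            # symbols that have no mapping back to a tensor source, so the printer
            # below should reject the answer.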
1616        dcp = DynamicDimConstraintPrinter(symbol_to_source, {})
1617        with self.assertRaisesRegex(
1618            AssertionError,
1619            "Unknown symbol.*created by constraints solver",
1620        ):
1621            dcp.doprint(answer)
1622
1623    def test_dim_constraints_solve_full(self):
1624        from sympy import Eq, Integer, Ne, Symbol
1625
1626        from torch._dynamo.source import (
1627            LocalSource,
1628            TensorProperty,
1629            TensorPropertySource,
1630        )
1631
1632        src0 = TensorPropertySource(
1633            base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
1634        )
1635        src2 = TensorPropertySource(
1636            base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=0
1637        )
1638        src3 = TensorPropertySource(
1639            base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=0
1640        )
1641        src4 = TensorPropertySource(
1642            base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=0
1643        )
1644
1645        src1 = TensorPropertySource(
1646            base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=2
1647        )
1648        src7 = TensorPropertySource(
1649            base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=3
1650        )
1651
1652        src5 = TensorPropertySource(
1653            base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=1
1654        )
1655        src8 = TensorPropertySource(
1656            base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=1
1657        )
1658
1659        src6 = TensorPropertySource(
1660            base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=1
1661        )
1662        src9 = TensorPropertySource(
1663            base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=1
1664        )
1665        src10 = TensorPropertySource(
1666            base=LocalSource(local_name="e"), prop=TensorProperty.SIZE, idx=1
1667        )
1668
1669        src11 = TensorPropertySource(
1670            base=LocalSource(local_name="f"), prop=TensorProperty.SIZE, idx=1
1671        )
1672        src12 = TensorPropertySource(
1673            base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=2
1674        )
1675
1676        s0 = Symbol("s0", positive=True, integer=True)
1677        s1 = Symbol("s1", positive=True, integer=True)
1678        s5 = Symbol("s5", positive=True, integer=True)
1679        s6 = Symbol("s6", positive=True, integer=True)
1680        symbol_to_source = {
1681            s0: [src0, src2, src3, src4],
1682            s1: [src1, src7],
1683            s5: [src5, src8],
1684            s6: [src6, src9, src10],
1685        }
1686        var_to_val = {s0: 8, s1: 96, s5: 22, s6: 21}
1687        marked_dynamic = {s0, s1, s5, s6}
1688        dim_constraints = DimConstraints(
1689            symbol_to_source, var_to_val, marked_dynamic, {}
1690        )
1691        dim_constraints.add_equality(src2, s0)
1692        dim_constraints.add_equality(src3, s0)
1693        dim_constraints.add_equality(src4, s0)
1694        dim_constraints.add_equality(src7, s1)
1695        dim_constraints.add_equality(src8, s5)
1696        dim_constraints.add_equality(src9, s6)
1697        dim_constraints.add_equality(src10, s6)
1698        dim_constraints.add_equality(src11, Integer(1))
1699        dim_constraints.add_equality(src12, Integer(3))
1700
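            # Record a large batch of guards over s0, s1, s5 and s6. The equalities
            # mixed in below (e.g. Eq(FloorDiv(s0, 2), 4) and Eq(6 * s5, 132)) end up
            # forcing s0, s1 and s5 to their hinted values, while s6 only picks up a
            # lower bound; solve() reports exactly that split at the end of the test.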
1701        dim_constraints.add(s1**2 <= 2147483647)
1702        dim_constraints.add(32 * s1**2 <= 2147483647)
1703        dim_constraints.add(s0 < 16)
1704        dim_constraints.add(Eq(Mod(s1, 2), 0))
1705        dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
1706        dim_constraints.add(Ne((FloorDiv(s1, 2)) ** 2, 1))
1707        dim_constraints.add(32 * (FloorDiv(s1, 2)) ** 2 <= 2147483647)
1708        dim_constraints.add((FloorDiv(s1, 2)) ** 2 > 1)
1709        dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
1710        dim_constraints.add(
1711            64 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
1712            + 128 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
1713            + 64
1714            <= 2147483647
1715        )
1716        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
1717        dim_constraints.add(
1718            Ne(
1719                (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
1720                + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
1721                + 1,
1722                1,
1723            )
1724        )
1725        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
1726        dim_constraints.add(
1727            (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
1728            + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
1729            + 1
1730            > 1
1731        )
1732        dim_constraints.add(
1733            128 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
1734            + 256 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
1735            + 128
1736            <= 2147483647
1737        )
1738        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
1739        dim_constraints.add(
1740            Ne(
1741                (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
1742                + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
1743                + 1,
1744                1,
1745            )
1746        )
1747        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
1748        dim_constraints.add(
1749            (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
1750            + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
1751            + 1
1752            > 1
1753        )
1754        dim_constraints.add(
1755            256 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1756            + 512 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
1757            + 256
1758            <= 2147483647
1759        )
1760        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
1761        dim_constraints.add(
1762            Ne(
1763                (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1764                + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
1765                + 1,
1766                1,
1767            )
1768        )
1769        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
1770        dim_constraints.add(
1771            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1772            + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
1773            + 1
1774            > 1
1775        )
1776        dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1 >= 3)
1777        dim_constraints.add(
1778            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1779            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1780            + 60
1781            <= 2147483647
1782        )
1783        dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 0)
1784        dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 1)
1785        dim_constraints.add(
1786            Ne(
1787                60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1788                - 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1789                + 60 * s0,
1790                0,
1791            )
1792        )
1793        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
1794        dim_constraints.add(
1795            Ne(
1796                (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1797                - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1798                + 1,
1799                1,
1800            )
1801        )
1802        dim_constraints.add(
1803            Ne(
1804                (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1805                - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1806                + 1,
1807                0,
1808            )
1809        )
1810        dim_constraints.add(
1811            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1812            - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1813            + 1
1814            >= 0
1815        )
1816        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 0))
1817        dim_constraints.add(
1818            1
1819            < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1820            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1821            + 60
1822        )
1823        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, -1))
1824        dim_constraints.add(
1825            Ne(
1826                60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1827                - 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1828                + 60 * s0,
1829                120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1830                - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1831                + 120,
1832            )
1833        )
1834        dim_constraints.add(
1835            120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1836            - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1837            + 120
1838            > 0
1839        )
1840        dim_constraints.add(
1841            Eq(
1842                60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 * (Mod(s0, 2))
1843                - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) * Mod(s0, 2)
1844                + 60 * (Mod(s0, 2)),
1845                0,
1846            )
1847        )
1848        dim_constraints.add(
1849            Ne(
1850                120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1851                - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1852                + 120,
1853                0,
1854            )
1855        )
1856        dim_constraints.add(
1857            Ne(
1858                60
1859                * (FloorDiv(s0, 2))
1860                * (FloorDiv(s0, (FloorDiv(s0, 2))))
1861                * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1862                - 120
1863                * FloorDiv(s0, 2)
1864                * FloorDiv(s0, (FloorDiv(s0, 2)))
1865                * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1866                + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
1867                0,
1868            )
1869        )
1870        dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
1871        dim_constraints.add(
1872            Ne(
1873                60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1874                - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1875                + 60,
1876                0,
1877            )
1878        )
1879        dim_constraints.add(
1880            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1881            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1882            + 60
1883            >= 0
1884        )
1885        dim_constraints.add(
1886            1
1887            < 60
1888            * (FloorDiv(s0, (FloorDiv(s0, 2))))
1889            * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1890            - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1891            + 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
1892        )
1893        dim_constraints.add(Ne(16 * s0, 32))
1894        dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
1895        dim_constraints.add(Ne(16 * s0, 32))
1896        dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
1897        dim_constraints.add(FloorDiv(s0, 2) >= 2)
1898        dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
1899        dim_constraints.add(1 < FloorDiv(s0, 2))
1900        dim_constraints.add(Ne(s0, 2))
1901        dim_constraints.add(
1902            60
1903            * (FloorDiv(s0, (FloorDiv(s0, 2))))
1904            * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1905            - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1906            + 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
1907            >= 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1908            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1909            + 60
1910        )
1911        dim_constraints.add(
1912            60
1913            * (FloorDiv(s0, 2))
1914            * (FloorDiv(s0, (FloorDiv(s0, 2))))
1915            * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1916            - 120
1917            * FloorDiv(s0, 2)
1918            * FloorDiv(s0, (FloorDiv(s0, 2)))
1919            * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1920            + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2))))
1921            > 0
1922        )
1923        dim_constraints.add(
1924            Ne(
1925                60
1926                * (FloorDiv(s0, 2))
1927                * (FloorDiv(s0, (FloorDiv(s0, 2))))
1928                * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1929                - 120
1930                * FloorDiv(s0, 2)
1931                * FloorDiv(s0, (FloorDiv(s0, 2)))
1932                * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1933                + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
1934                3 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
1935            )
1936        )
1937        dim_constraints.add(
1938            Ne(
1939                20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1940                - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1941                + 20,
1942                0,
1943            )
1944        )
1945        dim_constraints.add(
1946            20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1947            - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1948            + 20
1949            >= 0
1950        )
1951        dim_constraints.add(
1952            Ne(
1953                20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1954                - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1955                + 20,
1956                20,
1957            )
1958        )
1959        dim_constraints.add(
1960            Ne(
1961                20
1962                * (
1963                    Mod(
1964                        1,
1965                        (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1966                        - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1967                        + 1,
1968                    )
1969                ),
1970                0,
1971            )
1972        )
1973        dim_constraints.add(
1974            Ne(
1975                20
1976                * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
1977                * (
1978                    Mod(
1979                        1,
1980                        (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1981                        / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
1982                        - 2
1983                        * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1984                        / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
1985                        + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
1986                    )
1987                )
1988                - 20
1989                * Mod(
1990                    1,
1991                    (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1992                    / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
1993                    - 2
1994                    * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1995                    / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
1996                    + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
1997                ),
1998                0,
1999            )
2000        )
2001        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
2002        dim_constraints.add(
2003            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2004            - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2005            + 1
2006            >= 1
2007        )
2008        dim_constraints.add(
2009            20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2010            - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2011            + 20
2012            >= 0
2013        )
2014        dim_constraints.add(
2015            20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2016            - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2017            + 20
2018            >= 1
2019        )
2020        dim_constraints.add(
2021            20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2022            - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2023            + 20
2024            >= 2
2025        )
2026        dim_constraints.add(
2027            20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2028            - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2029            + 20
2030            > 1
2031        )
2032        dim_constraints.add(
2033            20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2034            - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2035            + 20
2036            < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2037            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2038            + 60
2039        )
2040        dim_constraints.add(
2041            Ne(
2042                60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2043                - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2044                + 60,
2045                60,
2046            )
2047        )
2048        dim_constraints.add(
2049            Ne(
2050                FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
2051                (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2052                - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2053                + 1,
2054            )
2055        )
2056        dim_constraints.add(
2057            Eq(
2058                (FloorDiv((FloorDiv(s1, 2) - 1), 8))
2059                * (
2060                    Mod(
2061                        (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2062                        / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
2063                        - 2
2064                        * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2065                        / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
2066                        + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
2067                        1,
2068                    )
2069                )
2070                - Mod(
2071                    (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2072                    / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
2073                    - 2
2074                    * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2075                    / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
2076                    + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
2077                    1,
2078                ),
2079                0,
2080            )
2081        )
2082        dim_constraints.add(
2083            Ne(
2084                (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2085                - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2086                + 1,
2087                FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
2088            )
2089        )
2090        dim_constraints.add(Ne(8 * s0, 16))
2091        dim_constraints.add(
2092            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2093            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2094            + 60
2095            >= (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2096            - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2097            + 1
2098        )
2099        dim_constraints.add(
2100            60
2101            * (FloorDiv(s0, (FloorDiv(s0, 2))))
2102            * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2103            - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2104            + 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
2105            <= 2147483647
2106        )
2107        dim_constraints.add(
2108            90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2109            - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2110            + 90
2111            <= 2147483647
2112        )
2113        dim_constraints.add(FloorDiv(s0, 2) < 16)
2114        dim_constraints.add(FloorDiv(s0, 2) > 1)
2115        dim_constraints.add(
2116            Ne(
2117                90 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2118                - 180 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2119                + 90 * (FloorDiv(s0, 2)),
2120                0,
2121            )
2122        )
2123        dim_constraints.add(
2124            1
2125            < 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2126            - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2127            + 90
2128        )
2129        dim_constraints.add(
2130            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2131            - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2132            + 1
2133            > 1
2134        )
2135        dim_constraints.add(
2136            60
2137            * (FloorDiv(s0, (FloorDiv(s0, 2))))
2138            * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2139            - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2140            + 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
2141            > 1
2142        )
2143        dim_constraints.add(
2144            Ne(
2145                60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2146                - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2147                + 60 * (FloorDiv(s0, 2)),
2148                0,
2149            )
2150        )
2151        dim_constraints.add(
2152            90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2153            - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2154            + 90
2155            > 1
2156        )
2157        dim_constraints.add(
2158            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2159            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2160            + 60
2161            > 1
2162        )
2163        dim_constraints.add(
2164            Ne(
2165                60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2166                - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2167                + 60 * (FloorDiv(s0, 2)),
2168                3 * (FloorDiv(s0, 2)),
2169            )
2170        )
2171        dim_constraints.add(
2172            60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2173            - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2174            + 60 * (FloorDiv(s0, 2))
2175            > 0
2176        )
2177        dim_constraints.add(
2178            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2179            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2180            + 60
2181            > 0
2182        )
2183        dim_constraints.add(
2184            Ne(
2185                120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2186                - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2187                + 120,
2188                0,
2189            )
2190        )
2191        dim_constraints.add(
2192            1
2193            < 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2194            - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2195            + 120
2196        )
2197        dim_constraints.add(
2198            Ne(
2199                120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2200                - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2201                + 120,
2202                6,
2203            )
2204        )
2205        dim_constraints.add(
2206            120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2207            - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2208            + 120
2209            > 0
2210        )
2211        dim_constraints.add(
2212            Ne(
2213                120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2214                - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2215                + 120,
2216                0,
2217            )
2218        )
2219        dim_constraints.add(
2220            120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2221            - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2222            + 120
2223            <= 2147483647
2224        )
2225        dim_constraints.add(
2226            120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2227            - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2228            + 120
2229            <= 20480
2230        )
2231        dim_constraints.add(
2232            Ne(
2233                90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2234                - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2235                + 90,
2236                0,
2237            )
2238        )
2239        dim_constraints.add(
2240            120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2241            - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2242            + 120
2243            > 1
2244        )
2245        dim_constraints.add(
2246            90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2247            - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2248            + 90
2249            <= 20480
2250        )
2251        dim_constraints.add(
2252            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2253            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2254            + 60
2255            <= 20480
2256        )
2257        dim_constraints.add(
2258            Ne(
2259                240 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2260                - 480 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2261                + 240,
2262                0,
2263            )
2264        )
2265        dim_constraints.add(Eq(6 * s5, 132))
2266        dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
2267        dim_constraints.add(Eq(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 4))
2268        dim_constraints.add(
2269            Ne(
2270                64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2271                - 128 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2272                + 64 * (FloorDiv(s0, 2)),
2273                0,
2274            )
2275        )
2276        dim_constraints.add(
2277            1
2278            < 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2279            - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2280            + 64
2281        )
2282        dim_constraints.add(
2283            64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2284            - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2285            + 64
2286            <= 2147483647
2287        )
2288        dim_constraints.add(
2289            64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2290            - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2291            + 64
2292            > 1
2293        )
2294        dim_constraints.add(
2295            62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2296            - 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2297            + 62
2298            <= 2147483647
2299        )
2300        dim_constraints.add(
2301            Ne(
2302                62 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2303                - 124 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2304                + 62 * (FloorDiv(s0, 2)),
2305                0,
2306            )
2307        )
2308        dim_constraints.add(
2309            1
2310            < 62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2311            - 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2312            + 62
2313        )
2314        dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
2315        dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
2316        dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
2317        dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
2318        dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
2319        dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 3)
2320        dim_constraints.add(
2321            64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2322            - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2323            + 576
2324            <= 2147483647
2325        )
2326        dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 0)
2327        dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 1)
2328        dim_constraints.add(
2329            Ne(
2330                64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2331                - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2332                + 576 * (FloorDiv(s0, 2)),
2333                0,
2334            )
2335        )
2336        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
2337        dim_constraints.add(
2338            Ne(
2339                (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2340                - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2341                + 9,
2342                1,
2343            )
2344        )
2345        dim_constraints.add(
2346            Ne(
2347                (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2348                - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2349                + 9,
2350                0,
2351            )
2352        )
2353        dim_constraints.add(
2354            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2355            - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2356            + 9
2357            >= 0
2358        )
2359        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 0))
2360        dim_constraints.add(
2361            1
2362            < 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2363            - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2364            + 576
2365        )
2366        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
2367        dim_constraints.add(
2368            Ne(
2369                64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2370                - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2371                + 576 * (FloorDiv(s0, 2)),
2372                256,
2373            )
2374        )
2375        dim_constraints.add(
2376            Eq(
2377                64
2378                * (
2379                    Mod(
2380                        (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2381                        - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2382                        + 9 * (FloorDiv(s0, 2)),
2383                        4,
2384                    )
2385                ),
2386                0,
2387            )
2388        )
2389        dim_constraints.add(
2390            Eq(
2391                FloorDiv(s0, 2),
2392                FloorDiv(
2393                    (
2394                        (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2395                        - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2396                        + 9 * (FloorDiv(s0, 2))
2397                    ),
2398                    4,
2399                ),
2400            )
2401        )
2402        dim_constraints.add(
2403            Eq(
2404                FloorDiv(
2405                    (
2406                        (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2407                        - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2408                        + 9 * (FloorDiv(s0, 2))
2409                    ),
2410                    4,
2411                ),
2412                FloorDiv(s0, 2),
2413            )
2414        )
2415        dim_constraints.add(
2416            Ne(64 * (Mod(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 4)), 0)
2417        )
2418        dim_constraints.add(
2419            Eq(
2420                64
2421                * (
2422                    Mod(
2423                        (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2424                        - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2425                        + 1,
2426                        4,
2427                    )
2428                ),
2429                0,
2430            )
2431        )
2432        dim_constraints.add(
2433            64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2434            - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2435            + 576 * (FloorDiv(s0, 2))
2436            > 0
2437        )
2438        dim_constraints.add(
2439            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2440            - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2441            + 9
2442            >= 1
2443        )
2444        dim_constraints.add(
2445            Eq(
2446                64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2447                - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2448                + 576,
2449                256,
2450            )
2451        )
2452        dim_constraints.add(
2453            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2454            - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2455            + 540
2456            <= 2147483647
2457        )
2458        dim_constraints.add(
2459            Ne(
2460                60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2461                - 360 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2462                + 540 * (FloorDiv(s0, 2)),
2463                0,
2464            )
2465        )
2466        dim_constraints.add(
2467            1
2468            < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2469            - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2470            + 540
2471        )
2472        dim_constraints.add(
2473            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2474            - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2475            + 9
2476            <= 2147483647
2477        )
2478        dim_constraints.add(
2479            Ne(
2480                (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2481                - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2482                + 9 * (FloorDiv(s0, 2)),
2483                0,
2484            )
2485        )
2486        dim_constraints.add(
2487            1
2488            < (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2489            - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2490            + 9
2491        )
2492        dim_constraints.add(
2493            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2494            - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2495            + 9
2496            > 1
2497        )
2498        dim_constraints.add(
2499            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2500            - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2501            + 540
2502            > 1
2503        )
2504        dim_constraints.add(s0 >= 2)
2505        dim_constraints.add(s1 >= 2)
2506        dim_constraints.add(s6 >= 2)
2507        dim_constraints.add(s5 >= 2)
2508
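            # solve() partitions the accumulated constraints into _static_results
            # (dims that must be specialized, rendered against their sources) and
            # _dynamic_results (the constraints left on genuinely dynamic dims).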
2509        dim_constraints.solve()
2510        self.assertEqual(
2511            dim_constraints._static_results,
2512            {
2513                "L['c'].size()[0] == 8",
2514                "L['d'].size()[0] == 8",
2515                "L['a'].size()[2] == 96",
2516                "L['f'].size()[1] == 1",
2517                "L['a'].size()[3] == 96",
2518                "L['b'].size()[2] == 3",
2519                "L['b'].size()[1] == 22",
2520                "L['b'].size()[0] == 8",
2521                "L['a'].size()[1] == 22",
2522                "L['a'].size()[0] == 8",
2523            },
2524        )
2525        self.assertEqual(
2526            dim_constraints._dynamic_results,
2527            {
2528                "2 <= L['c'].size()[1]",
2529                "L['d'].size()[1] == L['c'].size()[1]",
2530                "L['e'].size()[1] == L['c'].size()[1]",
2531            },
2532        )
2533
2534
2535class TestGuardsExpressions(TestCase):
2536    """
2537    Tests the guards-related methods used by the inductor FX graph cache.
2538    """
2539
2540    def test_guards_gt_lt(self):
2541        shape_env = ShapeEnv()
2542        s0 = create_symint(shape_env, 6)
2543        s1 = create_symint(shape_env, 7)
2544        s2 = create_symint(shape_env, 5)
2545
2546        guard_int(sym_int(s0 > 5))
2547        guard_int(sym_int(s0 < 7))
2548
2549        guards = shape_env.produce_guards_expression([s0])
2550
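            # The two guards above pin 5 < s0 < 7, so only the original hint (6)
            # satisfies the serialized guard expression; 7 and 5 must fail.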
2551        self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
2552        self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
2553        self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s2)]))
2554
2555    def test_guards_float_print(self):
2556        shape_env = ShapeEnv()
2557        s0 = create_symint(shape_env, 3)
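            # The guard below compares a symbolic expression against the float
            # constant 2/3; re-evaluating the printed guards with the original
            # hint checks that float constants round-trip through the printer.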
2558        guard_bool(2 / s0 == 2 / 3)
2559        guards = shape_env.produce_guards_expression([s0])
2560        self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
2561
2562    def test_guards_float_div(self):
2563        shape_env = ShapeEnv()
2564        s0 = create_symint(shape_env, 8)
2565        s1 = create_symint(shape_env, 7)
2566
2567        guard_int(sym_int(s0 / 2.0))
2568        guards = shape_env.produce_guards_expression([s0])
2569
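            # Dividing a SymInt by a Python float goes through float true division,
            # so the printed guards should contain ToFloat and FloatTrueDiv nodes.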
2570        self.assertIn("ToFloat", guards)
2571        self.assertIn("FloatTrueDiv", guards)
2572        self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
2573        self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
2574
2575
2576if __name__ == "__main__":
2577    run_tests()
2578