xref: /aosp_15_r20/external/pytorch/test/test_dynamic_shapes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport contextlib
4*da0073e9SAndroid Build Coastguard Workerimport copy
5*da0073e9SAndroid Build Coastguard Workerimport itertools
6*da0073e9SAndroid Build Coastguard Workerimport math
7*da0073e9SAndroid Build Coastguard Workerimport operator
8*da0073e9SAndroid Build Coastguard Workerimport unittest
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerimport numpy as np
11*da0073e9SAndroid Build Coastguard Workerimport sympy
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Workerimport torch
14*da0073e9SAndroid Build Coastguard Workerimport torch.fx
15*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F
16*da0073e9SAndroid Build Coastguard Workerfrom torch import sym_int, SymBool, SymFloat, SymInt
17*da0073e9SAndroid Build Coastguard Workerfrom torch._C import _disabled_torch_function_impl
18*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental import sym_node
19*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.proxy_tensor import make_fx
20*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node
21*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.symbolic_shapes import (
22*da0073e9SAndroid Build Coastguard Worker    _constrain_range_for_size,
23*da0073e9SAndroid Build Coastguard Worker    DimConstraints,
24*da0073e9SAndroid Build Coastguard Worker    DimDynamic,
25*da0073e9SAndroid Build Coastguard Worker    expect_true,
26*da0073e9SAndroid Build Coastguard Worker    guard_bool,
27*da0073e9SAndroid Build Coastguard Worker    guard_float,
28*da0073e9SAndroid Build Coastguard Worker    guard_int,
29*da0073e9SAndroid Build Coastguard Worker    GuardOnDataDependentSymNode,
30*da0073e9SAndroid Build Coastguard Worker    hint_int,
31*da0073e9SAndroid Build Coastguard Worker    is_symbolic,
32*da0073e9SAndroid Build Coastguard Worker    ShapeEnv,
33*da0073e9SAndroid Build Coastguard Worker    StatelessSymbolicContext,
34*da0073e9SAndroid Build Coastguard Worker    statically_known_true,
35*da0073e9SAndroid Build Coastguard Worker)
36*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
37*da0073e9SAndroid Build Coastguard Worker    instantiate_parametrized_tests,
38*da0073e9SAndroid Build Coastguard Worker    parametrize,
39*da0073e9SAndroid Build Coastguard Worker    run_tests,
40*da0073e9SAndroid Build Coastguard Worker    skipIfTorchDynamo,
41*da0073e9SAndroid Build Coastguard Worker    TestCase,
42*da0073e9SAndroid Build Coastguard Worker)
43*da0073e9SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree
44*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._python_dispatch import TorchDispatchMode
45*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._sympy.functions import (
46*da0073e9SAndroid Build Coastguard Worker    FloorDiv,
47*da0073e9SAndroid Build Coastguard Worker    IsNonOverlappingAndDenseIndicator,
48*da0073e9SAndroid Build Coastguard Worker    Mod,
49*da0073e9SAndroid Build Coastguard Worker)
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker
# Short alias for the ATen operator namespace used by the registrations below.
aten = torch.ops.aten

# Maps an ATen OpOverload to the Python function FakeSymbolicTensor dispatches
# it to; entries are added by ``register_meta``.
meta_funcs = dict()
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker
def register_meta(op):
    """Return a decorator that registers the decorated function in ``meta_funcs``.

    ``op`` may be a single operator overload or a pytree (e.g. a list) of them;
    every leaf of ``op`` is mapped to the decorated function. The decorated
    function itself is returned unchanged.
    """

    def decorator(f):
        for overload in pytree.tree_leaves(op):
            meta_funcs[overload] = f
        return f

    return decorator
66*da0073e9SAndroid Build Coastguard Worker
67*da0073e9SAndroid Build Coastguard Worker
@register_meta([aten.add.Tensor, aten.sub.Tensor])
def binary_meta(a, b):
    """Meta function for ``aten.add.Tensor`` / ``aten.sub.Tensor``.

    The result simply takes ``a``'s shape; ``b`` is never inspected, so any
    broadcasting of ``b`` against ``a`` is assumed rather than checked.
    """
    out_shape = a.shape
    return a.new_empty(out_shape)
71*da0073e9SAndroid Build Coastguard Worker
72*da0073e9SAndroid Build Coastguard Worker
@register_meta(aten.cat.default)
def cat_meta(tensors, dim=0):
    """Meta function for ``aten.cat``: concatenate shapes along ``dim``.

    All non-concatenated dimensions must match those of ``tensors[0]``
    (checked with ``assert``); the concatenated dimension is the sum of the
    inputs' lengths along ``dim``. As with ``torch.cat``, ``dim`` may be
    negative.

    Args:
        tensors: Non-empty sequence of tensors of equal rank.
        dim: Dimension to concatenate along.

    Returns:
        ``tensors[0].new_empty(...)`` with the concatenated shape.
    """
    shape = tensors[0].shape
    ndim = len(shape)
    # Without wrapping, a negative dim never equals an enumerate() index below,
    # so the concatenated length would silently stay 0.
    if dim < 0:
        dim += ndim
    assert 0 <= dim < ndim, f"cat: dim out of range for {ndim}-d tensors"
    concat_length = 0
    for tensor in tensors:
        # zip() below would silently truncate on a rank mismatch.
        assert len(tensor.shape) == ndim, "cat: all tensors must have the same rank"
        for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
            if idx == dim:
                concat_length = concat_length + length
            else:
                assert length == common_length
    new_shape = list(shape)
    new_shape[dim] = concat_length
    return tensors[0].new_empty(new_shape)
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker
@register_meta([aten.narrow_copy.default])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
    """Meta function for ``aten.narrow_copy``.

    Returns ``a.new_empty`` with ``a``'s shape, except that dimension ``dim``
    becomes ``length``. ``start`` and any extra keyword arguments do not affect
    the output shape and are ignored. As with ``Tensor.narrow_copy``, ``dim``
    may be negative.
    """
    ndim = len(a.shape)
    # A negative dim would never equal an enumerate() index and would be
    # silently ignored, so wrap it first.
    if dim < 0:
        dim += ndim
    assert 0 <= dim < ndim, f"narrow_copy: dim out of range for {ndim}-d tensor"
    shape = tuple(length if i == dim else x for i, x in enumerate(a.shape))
    return a.new_empty(shape)
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker
@register_meta([aten.expand.default])
def expand_symint_meta(a, size, implicit=False):
    """Meta function for ``aten.expand``: a new tensor of shape ``size``.

    As with ``Tensor.expand``, a plain ``-1`` entry keeps the size of the
    corresponding (right-aligned) dimension of ``a``. ``implicit`` does not
    affect the output shape and is ignored.
    """
    new_size = list(size)
    # New dimensions are prepended, so size[i] lines up with a.shape[i - offset].
    offset = len(new_size) - len(a.shape)
    for i, s in enumerate(new_size):
        # Only plain ints are compared, so SymInt entries never install a guard.
        if isinstance(s, int) and s == -1:
            assert i >= offset, "expand: -1 is not allowed for a new leading dimension"
            new_size[i] = a.shape[i - offset]
    return a.new_empty(new_size)
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker
def create_contiguous(shape):
    """Return row-major (C-contiguous) strides for ``shape`` as a list.

    The last dimension has stride 1 and each earlier stride is the product of
    all *later* sizes, e.g. ``[2, 3, 4] -> [12, 4, 1]``. A 0-d ``shape``
    returns ``[]``, so there is always exactly one stride per dimension.
    Entries may be ints or SymInts.
    """
    if len(shape) == 0:
        return []
    strides = [1]
    # stride[i] = shape[i + 1] * stride[i + 1]; walk the trailing sizes
    # (shape[1:]) from the innermost outwards.
    for dim in reversed(shape[1:]):
        strides.append(dim * strides[-1])
    return list(reversed(strides))
109*da0073e9SAndroid Build Coastguard Worker
110*da0073e9SAndroid Build Coastguard Worker
class FakeSymbolicTensor(torch.Tensor):
    """Wrapper tensor subclass whose (possibly symbolic) sizes are set at construction.

    Operators reaching ``__torch_dispatch__`` are routed to the Python meta
    functions registered in ``meta_funcs`` (plus a built-in handler for
    ``aten.new_empty``); any other operator raises ``RuntimeError``.
    """

    @staticmethod
    def __new__(
        cls,
        sym_shape,
        sym_strides,
        dtype,
        layout,
        requires_grad,
        device,
        storage_offset=0,
    ):
        # TODO: this is wrong in general -- ``sym_strides`` is ignored and the
        # strides are always recomputed as contiguous strides of ``sym_shape``.
        sym_stride = create_contiguous(sym_shape)
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            sym_shape,
            sym_stride,
            storage_offset,
            dtype=dtype,
            layout=layout,
            requires_grad=requires_grad,
            device=device,
        )
        return r

    __torch_function__ = _disabled_torch_function_impl

    def new_empty(self, shape):
        """Return a FakeSymbolicTensor of ``shape`` with ``self``'s dtype/layout/device."""
        return FakeSymbolicTensor(
            shape, None, self.dtype, self.layout, self.requires_grad, self.device
        )

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        # ``kwargs`` defaults to None, and ``**None`` raises TypeError.
        kwargs = {} if kwargs is None else kwargs

        if func_overload in meta_funcs:
            return meta_funcs[func_overload](*args, **kwargs)

        if func_overload == torch.ops.aten.new_empty.default:
            self = args[0]
            shape = args[1]
            return FakeSymbolicTensor(
                shape,
                self.stride(),  # ignored by __new__, which recomputes strides
                self.dtype,
                self.layout,
                self.requires_grad,
                self.device,
            )

        raise RuntimeError(f"operator {func_overload} not supported")
162*da0073e9SAndroid Build Coastguard Worker
163*da0073e9SAndroid Build Coastguard Worker
def create_symbolic_tensor(name, arg, shape_env, source=None, dynamic_dims=None):
    """Build a FakeSymbolicTensor for ``arg`` with symbols allocated in ``shape_env``.

    Sizes, strides and storage offset come from
    ``shape_env.create_symbolic_sizes_strides_storage_offset``. dtype, layout,
    requires_grad and device are copied from ``arg``.

    Args:
        name: Used to build a ``ConstantSource`` when ``source`` is None.
        arg: The real tensor to symbolize.
        shape_env: The ``ShapeEnv`` that owns the new symbols.
        source: Optional source the symbols are attributed to.
        dynamic_dims: Per-dimension ``DimDynamic``; defaults to DUCK everywhere.
    """
    from torch._dynamo.source import ConstantSource

    ndim = arg.dim()
    if source is None:
        source = ConstantSource(name)
    sizes_kind = (
        dynamic_dims if dynamic_dims is not None else [DimDynamic.DUCK] * ndim
    )
    # No dimension carries a constraint.
    context = StatelessSymbolicContext(
        dynamic_sizes=sizes_kind, constraint_sizes=[None] * ndim
    )
    sizes, strides, offset = shape_env.create_symbolic_sizes_strides_storage_offset(
        arg, source=source, symbolic_context=context
    )
    return FakeSymbolicTensor(
        sizes,
        strides,
        arg.dtype,
        arg.layout,
        arg.requires_grad,
        arg.device,
        storage_offset=offset,
    )
192*da0073e9SAndroid Build Coastguard Worker
193*da0073e9SAndroid Build Coastguard Worker
194*da0073e9SAndroid Build Coastguard Workerdef create_symtype(cls, pytype, shape_env, val, duck=True, **kwargs):
195*da0073e9SAndroid Build Coastguard Worker    from torch._dynamo.source import ConstantSource
196*da0073e9SAndroid Build Coastguard Worker
197*da0073e9SAndroid Build Coastguard Worker    symbol = shape_env.create_symbol(
198*da0073e9SAndroid Build Coastguard Worker        val,
199*da0073e9SAndroid Build Coastguard Worker        source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"),
200*da0073e9SAndroid Build Coastguard Worker        dynamic_dim=DimDynamic.DUCK if duck else DimDynamic.DYNAMIC,
201*da0073e9SAndroid Build Coastguard Worker        constraint_dim=None,
202*da0073e9SAndroid Build Coastguard Worker        **kwargs,
203*da0073e9SAndroid Build Coastguard Worker    )
204*da0073e9SAndroid Build Coastguard Worker    return cls(SymNode(symbol, shape_env, pytype, hint=val))
205*da0073e9SAndroid Build Coastguard Worker
206*da0073e9SAndroid Build Coastguard Worker
# TODO: default duck to False
def create_symint(shape_env, i: int, duck=True, **kwargs) -> SymInt:
    """Create a SymInt in ``shape_env`` whose hint is ``i``.

    ``duck`` and any extra keyword arguments are forwarded to ``create_symtype``.
    """
    kwargs["duck"] = duck
    return create_symtype(SymInt, int, shape_env, i, **kwargs)
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker
212*da0073e9SAndroid Build Coastguard Workerdef create_symbool(shape_env, b: bool) -> SymBool:
213*da0073e9SAndroid Build Coastguard Worker    return create_symtype(SymBool, bool, shape_env, b)
214*da0073e9SAndroid Build Coastguard Worker
215*da0073e9SAndroid Build Coastguard Worker
216*da0073e9SAndroid Build Coastguard Workerdef create_symfloat(shape_env, f: float) -> SymFloat:
217*da0073e9SAndroid Build Coastguard Worker    return create_symtype(SymFloat, float, shape_env, f)
218*da0073e9SAndroid Build Coastguard Worker
219*da0073e9SAndroid Build Coastguard Worker
220*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo(
221*da0073e9SAndroid Build Coastguard Worker    "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
222*da0073e9SAndroid Build Coastguard Worker)
223*da0073e9SAndroid Build Coastguard Workerclass TestPySymInt(TestCase):
224*da0073e9SAndroid Build Coastguard Worker    def test_arith_ops(self):
225*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
226*da0073e9SAndroid Build Coastguard Worker        symints = []
227*da0073e9SAndroid Build Coastguard Worker        for i in range(2, 5):
228*da0073e9SAndroid Build Coastguard Worker            symints.append((i, create_symint(shape_env, i)))
229*da0073e9SAndroid Build Coastguard Worker
230*da0073e9SAndroid Build Coastguard Worker        ops = [
231*da0073e9SAndroid Build Coastguard Worker            operator.add,
232*da0073e9SAndroid Build Coastguard Worker            operator.sub,
233*da0073e9SAndroid Build Coastguard Worker            operator.floordiv,
234*da0073e9SAndroid Build Coastguard Worker            operator.mul,
235*da0073e9SAndroid Build Coastguard Worker            operator.mod,
236*da0073e9SAndroid Build Coastguard Worker        ]
237*da0073e9SAndroid Build Coastguard Worker
238*da0073e9SAndroid Build Coastguard Worker        for op in ops:
239*da0073e9SAndroid Build Coastguard Worker            for args in itertools.permutations(symints, 2):
240*da0073e9SAndroid Build Coastguard Worker                if not isinstance(args[0][1], int) and (
241*da0073e9SAndroid Build Coastguard Worker                    (op != operator.mod or op != operator.floordiv) and args[1][0] != 0
242*da0073e9SAndroid Build Coastguard Worker                ):
243*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(
244*da0073e9SAndroid Build Coastguard Worker                        op(args[0][1], args[1][1]) == op(args[0][0], args[1][0])
245*da0073e9SAndroid Build Coastguard Worker                    )
246*da0073e9SAndroid Build Coastguard Worker
247*da0073e9SAndroid Build Coastguard Worker    def test_reverse_arith_ops(self):
248*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
249*da0073e9SAndroid Build Coastguard Worker
250*da0073e9SAndroid Build Coastguard Worker        a = create_symint(shape_env, 2)
251*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(5 // a == 5 // 2)
252*da0073e9SAndroid Build Coastguard Worker
253*da0073e9SAndroid Build Coastguard Worker        a = create_symint(shape_env, 2)
254*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(5 * a == 5 * 2)
255*da0073e9SAndroid Build Coastguard Worker
256*da0073e9SAndroid Build Coastguard Worker    def test_sympify_symint(self):
257*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
258*da0073e9SAndroid Build Coastguard Worker        a = create_symint(shape_env, 2)
259*da0073e9SAndroid Build Coastguard Worker        self.assertIs(sympy.sympify(a), a.node.expr)
260*da0073e9SAndroid Build Coastguard Worker        b = create_symfloat(shape_env, 3.0)
261*da0073e9SAndroid Build Coastguard Worker        self.assertIs(sympy.sympify(b), b.node.expr)
262*da0073e9SAndroid Build Coastguard Worker        c = create_symbool(shape_env, True)
263*da0073e9SAndroid Build Coastguard Worker        self.assertIs(sympy.sympify(c), c.node.expr)
264*da0073e9SAndroid Build Coastguard Worker
265*da0073e9SAndroid Build Coastguard Worker    def test_roundtrip(self):
266*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
267*da0073e9SAndroid Build Coastguard Worker        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
268*da0073e9SAndroid Build Coastguard Worker
269*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(not isinstance(x.shape[0], SymNode))
270*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(x.shape[0], SymInt))
271*da0073e9SAndroid Build Coastguard Worker
272*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.shape[0] == 5)
273*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.shape[1] == 4)
274*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.shape[2], 3)
275*da0073e9SAndroid Build Coastguard Worker
276*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.size()[0], 5)
277*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.size()[1], 4)
278*da0073e9SAndroid Build Coastguard Worker        # Should be simplifiable to an integer.
279*da0073e9SAndroid Build Coastguard Worker        # Ref: https://github.com/pytorch/pytorch/pull/107492
280*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(x.size()[1], SymInt))
281*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
282*da0073e9SAndroid Build Coastguard Worker            isinstance(x.size()[1].node.maybe_as_int(), int)
283*da0073e9SAndroid Build Coastguard Worker        )  # due to guard above
284*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.size()[2] == 3)
285*da0073e9SAndroid Build Coastguard Worker
286*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.size(0) == 5)
287*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.size(1) == 4)
288*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.size(2) == 3)
289*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(x.size(2), SymInt))
290*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(x.size(2).node.maybe_as_int(), int))
291*da0073e9SAndroid Build Coastguard Worker
292*da0073e9SAndroid Build Coastguard Worker        y = create_symbolic_tensor("y", torch.randn(5, 4, 3)[1:], shape_env)
293*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(y.storage_offset(), SymInt))
294*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(y.storage_offset() == 12)
295*da0073e9SAndroid Build Coastguard Worker
296*da0073e9SAndroid Build Coastguard Worker    def test_binary(self):
297*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
298*da0073e9SAndroid Build Coastguard Worker        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
299*da0073e9SAndroid Build Coastguard Worker        y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)
300*da0073e9SAndroid Build Coastguard Worker
301*da0073e9SAndroid Build Coastguard Worker        z = x + y
302*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[0] == 5)
303*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[1] == 4)
304*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[2] == 3)
305*da0073e9SAndroid Build Coastguard Worker
306*da0073e9SAndroid Build Coastguard Worker        # broadcasting
307*da0073e9SAndroid Build Coastguard Worker        y = create_symbolic_tensor("y2", torch.randn(1, 4, 1), shape_env)
308*da0073e9SAndroid Build Coastguard Worker        z = x + y
309*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[0] == 5)
310*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[1] == 4)
311*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[2] == 3)
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker    def test_symint_args(self):
314*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
315*da0073e9SAndroid Build Coastguard Worker        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
316*da0073e9SAndroid Build Coastguard Worker        y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
317*da0073e9SAndroid Build Coastguard Worker        LAST_DIM = 2
318*da0073e9SAndroid Build Coastguard Worker        z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
319*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[2] == y.shape[2])
320*da0073e9SAndroid Build Coastguard Worker
321*da0073e9SAndroid Build Coastguard Worker        # arithmetic expr with two symints
322*da0073e9SAndroid Build Coastguard Worker        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
323*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[2] == 2)
324*da0073e9SAndroid Build Coastguard Worker
325*da0073e9SAndroid Build Coastguard Worker        # arithmetic expr with a symint and python int
326*da0073e9SAndroid Build Coastguard Worker        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
327*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[2] == 2)
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker    def test_symint_vargs(self):
330*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
331*da0073e9SAndroid Build Coastguard Worker        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
332*da0073e9SAndroid Build Coastguard Worker        y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
333*da0073e9SAndroid Build Coastguard Worker
334*da0073e9SAndroid Build Coastguard Worker        # varargs
335*da0073e9SAndroid Build Coastguard Worker        z = y.expand(x.shape[0], y.shape[1], x.shape[2])
336*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[0] == 5)
337*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[1] == 4)
338*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[2] == 3)
339*da0073e9SAndroid Build Coastguard Worker
340*da0073e9SAndroid Build Coastguard Worker        # shape list
341*da0073e9SAndroid Build Coastguard Worker        z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
342*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[0] == 5)
343*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[1] == 4)
344*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[2] == 3)
345*da0073e9SAndroid Build Coastguard Worker
346*da0073e9SAndroid Build Coastguard Worker        # mixed python symints and ints
347*da0073e9SAndroid Build Coastguard Worker        z = y.expand(x.shape[0], y.shape[1], 3)
348*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[0] == 5)
349*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[1] == 4)
350*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[2] == 3)
351*da0073e9SAndroid Build Coastguard Worker
352*da0073e9SAndroid Build Coastguard Worker        # mixed python symints and ints in a list
353*da0073e9SAndroid Build Coastguard Worker        z = y.expand((x.shape[0], y.shape[1], 3))
354*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[0] == 5)
355*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[1] == 4)
356*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[2] == 3)
357*da0073e9SAndroid Build Coastguard Worker
358*da0073e9SAndroid Build Coastguard Worker        # mixed python symints and ints
359*da0073e9SAndroid Build Coastguard Worker        z = y.expand(5, y.shape[1], x.shape[2])
360*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[0] == 5)
361*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[1] == 4)
362*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[2] == 3)
363*da0073e9SAndroid Build Coastguard Worker
364*da0073e9SAndroid Build Coastguard Worker        # mixed python ints and symints in a list
365*da0073e9SAndroid Build Coastguard Worker        z = y.expand((5, y.shape[1], x.shape[2]))
366*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[0] == 5)
367*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[1] == 4)
368*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.shape[2] == 3)
369*da0073e9SAndroid Build Coastguard Worker
370*da0073e9SAndroid Build Coastguard Worker        z = y.expand((y.shape[1],))
371*da0073e9SAndroid Build Coastguard Worker        z = y.expand(y.shape[1])
372*da0073e9SAndroid Build Coastguard Worker
373*da0073e9SAndroid Build Coastguard Worker    def test_stride(self):
374*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
375*da0073e9SAndroid Build Coastguard Worker        x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
376*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(x.stride()[0], SymInt)
377*da0073e9SAndroid Build Coastguard Worker
378*da0073e9SAndroid Build Coastguard Worker    def test_size_expressions(self):
379*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
380*da0073e9SAndroid Build Coastguard Worker        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
381*da0073e9SAndroid Build Coastguard Worker        expand_x = x.expand(x.shape[0], x.shape[0])
382*da0073e9SAndroid Build Coastguard Worker        if expand_x.shape[0] > 3:
383*da0073e9SAndroid Build Coastguard Worker            result = expand_x + expand_x
384*da0073e9SAndroid Build Coastguard Worker        else:
385*da0073e9SAndroid Build Coastguard Worker            result = expand_x + expand_x
386*da0073e9SAndroid Build Coastguard Worker
387*da0073e9SAndroid Build Coastguard Worker        gt_op, _bt = shape_env.guards[-1]
388*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
389*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(str(x.shape[0]), str(gt_op.args[0]))
390*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
391*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
392*da0073e9SAndroid Build Coastguard Worker
393*da0073e9SAndroid Build Coastguard Worker    def test_floordiv_static(self):
394*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
395*da0073e9SAndroid Build Coastguard Worker        s0 = create_symint(shape_env, 8)
396*da0073e9SAndroid Build Coastguard Worker        # This was extracted from
397*da0073e9SAndroid Build Coastguard Worker        # python test/inductor/test_cuda_cpp_wrapper.py -k
398*da0073e9SAndroid Build Coastguard Worker        # DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper
399*da0073e9SAndroid Build Coastguard Worker        bool(s0 % 2 == 0)
400*da0073e9SAndroid Build Coastguard Worker        bool(s0 % (s0 // 2) == 0)
401*da0073e9SAndroid Build Coastguard Worker        bool(2 * (s0 // 2) == s0)
402*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(statically_known_true(s0 // (s0 // 2) == 2))
403*da0073e9SAndroid Build Coastguard Worker
404*da0073e9SAndroid Build Coastguard Worker    def test_numel(self):
405*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
406*da0073e9SAndroid Build Coastguard Worker        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
407*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(x.numel(), torch.SymInt)
408*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(torch.numel(x), torch.SymInt)
409*da0073e9SAndroid Build Coastguard Worker
410*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 3)
411*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(x.numel(), int)
412*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(torch.numel(x), int)
413*da0073e9SAndroid Build Coastguard Worker
414*da0073e9SAndroid Build Coastguard Worker    def test_int_to_float(self):
415*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
416*da0073e9SAndroid Build Coastguard Worker        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
417*da0073e9SAndroid Build Coastguard Worker        r = torch.sym_float(x.shape[0])
418*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(r, torch.SymFloat, msg=type(r))
419*da0073e9SAndroid Build Coastguard Worker
420*da0073e9SAndroid Build Coastguard Worker    def test_aten_ops(self):
421*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
422*da0073e9SAndroid Build Coastguard Worker        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
423*da0073e9SAndroid Build Coastguard Worker        torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])
424*da0073e9SAndroid Build Coastguard Worker
425*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
426*da0073e9SAndroid Build Coastguard Worker        x = create_symbolic_tensor("x2", torch.randn(5, 4, 3), shape_env)
427*da0073e9SAndroid Build Coastguard Worker        torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])
428*da0073e9SAndroid Build Coastguard Worker
429*da0073e9SAndroid Build Coastguard Worker    def test_fx_trace_intlist(self):
430*da0073e9SAndroid Build Coastguard Worker        class CustomModule(torch.nn.Module):
431*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
432*da0073e9SAndroid Build Coastguard Worker                bs, c, h, w = x.shape
433*da0073e9SAndroid Build Coastguard Worker                return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))
434*da0073e9SAndroid Build Coastguard Worker
435*da0073e9SAndroid Build Coastguard Worker        m = CustomModule()
436*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1, 3, 4, 4)
437*da0073e9SAndroid Build Coastguard Worker        # should not TypeError: pad(): argument 'pad' (position 2) must be
438*da0073e9SAndroid Build Coastguard Worker        # tuple of ints, not tuple
439*da0073e9SAndroid Build Coastguard Worker        torch.fx.symbolic_trace(m)
440*da0073e9SAndroid Build Coastguard Worker
441*da0073e9SAndroid Build Coastguard Worker    def test_meta_symint(self):
442*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
443*da0073e9SAndroid Build Coastguard Worker        a0 = create_symint(shape_env, 2)
444*da0073e9SAndroid Build Coastguard Worker        r = torch.empty(a0, device="meta")
445*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(r.shape[0], SymInt)
446*da0073e9SAndroid Build Coastguard Worker
447*da0073e9SAndroid Build Coastguard Worker    def test_guard_int(self):
448*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
449*da0073e9SAndroid Build Coastguard Worker        a0 = create_symint(shape_env, 2)
450*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(guard_int(a0), 2)
451*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")
452*da0073e9SAndroid Build Coastguard Worker
453*da0073e9SAndroid Build Coastguard Worker    def test_prefer_deferred_runtime_assertions_over_guards(self):
454*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
455*da0073e9SAndroid Build Coastguard Worker        s0 = create_symint(shape_env, 2)
456*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(guard_int(s0), 2)
457*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")
458*da0073e9SAndroid Build Coastguard Worker
459*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
460*da0073e9SAndroid Build Coastguard Worker        s0 = create_symint(shape_env, 2)
461*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(expect_true(s0 == 2))
462*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(shape_env.guards), 0)
463*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
464*da0073e9SAndroid Build Coastguard Worker            str([ra.expr for ra in shape_env.deferred_runtime_asserts[None]]),
465*da0073e9SAndroid Build Coastguard Worker            """[Eq(s0, 2)]""",
466*da0073e9SAndroid Build Coastguard Worker        )
467*da0073e9SAndroid Build Coastguard Worker
468*da0073e9SAndroid Build Coastguard Worker    def test_sym_int(self):
469*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
470*da0073e9SAndroid Build Coastguard Worker        a0 = create_symint(shape_env, 5)
471*da0073e9SAndroid Build Coastguard Worker        r = sym_int(a0)
472*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, 5)
473*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(r, torch.SymInt, msg=type(r))
474*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 5)""")
475*da0073e9SAndroid Build Coastguard Worker
476*da0073e9SAndroid Build Coastguard Worker        a1 = create_symint(shape_env, 7)
477*da0073e9SAndroid Build Coastguard Worker        r = sym_int(a1 / 2)
478*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(guard_int(r), 3)
479*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(r, torch.SymInt, msg=type(r))
480*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
481*da0073e9SAndroid Build Coastguard Worker            str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)"""
482*da0073e9SAndroid Build Coastguard Worker        )
483*da0073e9SAndroid Build Coastguard Worker
484*da0073e9SAndroid Build Coastguard Worker        a3 = create_symint(shape_env, 3)
485*da0073e9SAndroid Build Coastguard Worker        r = sym_int(2.0 * torch.sym_float(a3))
486*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(guard_int(r), 6)
487*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(r, torch.SymInt, msg=type(r))
488*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
489*da0073e9SAndroid Build Coastguard Worker            str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)"""
490*da0073e9SAndroid Build Coastguard Worker        )
491*da0073e9SAndroid Build Coastguard Worker
492*da0073e9SAndroid Build Coastguard Worker    def test_sym_sqrt(self):
493*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
494*da0073e9SAndroid Build Coastguard Worker        a0 = create_symint(shape_env, 4)
495*da0073e9SAndroid Build Coastguard Worker        r = torch._sym_sqrt(a0)
496*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, 2)
497*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(r, torch.SymFloat, msg=type(r))
498*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
499*da0073e9SAndroid Build Coastguard Worker            str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2.0)"""
500*da0073e9SAndroid Build Coastguard Worker        )
501*da0073e9SAndroid Build Coastguard Worker
502*da0073e9SAndroid Build Coastguard Worker    def test_sym_floor(self):
503*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
504*da0073e9SAndroid Build Coastguard Worker        a0 = create_symint(shape_env, 5)
505*da0073e9SAndroid Build Coastguard Worker        r = math.floor(a0 / 2)
506*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, 2)
507*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(r, torch.SymInt, msg=type(r))
508*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
509*da0073e9SAndroid Build Coastguard Worker            str(shape_env.guards[0][0]),
510*da0073e9SAndroid Build Coastguard Worker            """Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""",
511*da0073e9SAndroid Build Coastguard Worker        )
512*da0073e9SAndroid Build Coastguard Worker        r = math.floor(3.0 * a0)
513*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, 15)
514*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(r, torch.SymInt, msg=type(r))
515*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
516*da0073e9SAndroid Build Coastguard Worker            str(shape_env.guards[1][0]),
517*da0073e9SAndroid Build Coastguard Worker            """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
518*da0073e9SAndroid Build Coastguard Worker        )
519*da0073e9SAndroid Build Coastguard Worker
520*da0073e9SAndroid Build Coastguard Worker    def test_sym_trunc(self):
521*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
522*da0073e9SAndroid Build Coastguard Worker        a0 = create_symint(shape_env, 5)
523*da0073e9SAndroid Build Coastguard Worker        r = math.trunc(a0 / 2)
524*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, 2)
525*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(r, torch.SymInt, msg=type(r))
526*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
527*da0073e9SAndroid Build Coastguard Worker            str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)"""
528*da0073e9SAndroid Build Coastguard Worker        )
529*da0073e9SAndroid Build Coastguard Worker        r = torch.sym_int(torch.sym_sqrt(a0))
530*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, 2)
531*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(r, torch.SymInt, msg=type(r))
532*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
533*da0073e9SAndroid Build Coastguard Worker            str(shape_env.guards[1][0]), """Eq(TruncToInt(OpaqueUnaryFn_sqrt(s0)), 2)"""
534*da0073e9SAndroid Build Coastguard Worker        )
535*da0073e9SAndroid Build Coastguard Worker
536*da0073e9SAndroid Build Coastguard Worker    def test_sym_ceil(self):
537*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
538*da0073e9SAndroid Build Coastguard Worker        a0 = create_symint(shape_env, 5)
539*da0073e9SAndroid Build Coastguard Worker        r = math.ceil(a0 / 2)
540*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, 3)
541*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(r, torch.SymInt, msg=type(r))
542*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
543*da0073e9SAndroid Build Coastguard Worker            str(shape_env.guards[0][0]),
544*da0073e9SAndroid Build Coastguard Worker            """Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""",
545*da0073e9SAndroid Build Coastguard Worker        )
546*da0073e9SAndroid Build Coastguard Worker        r1 = 3.0 * a0
547*da0073e9SAndroid Build Coastguard Worker        r = math.floor(r1)
548*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, 15)
549*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(r, torch.SymInt, msg=type(r))
550*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
551*da0073e9SAndroid Build Coastguard Worker            str(shape_env.guards[1][0]),
552*da0073e9SAndroid Build Coastguard Worker            """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
553*da0073e9SAndroid Build Coastguard Worker        )
554*da0073e9SAndroid Build Coastguard Worker
555*da0073e9SAndroid Build Coastguard Worker    def test_sym_ite(self):
556*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
557*da0073e9SAndroid Build Coastguard Worker        t = create_symint(shape_env, 5)
558*da0073e9SAndroid Build Coastguard Worker        f = create_symint(shape_env, 4)
559*da0073e9SAndroid Build Coastguard Worker        b1 = True
560*da0073e9SAndroid Build Coastguard Worker        r1 = torch.sym_ite(b1, t, f)
561*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(r1 is t)
562*da0073e9SAndroid Build Coastguard Worker        b2 = False
563*da0073e9SAndroid Build Coastguard Worker        r2 = torch.sym_ite(b2, t, f)
564*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(r2 is f)
565*da0073e9SAndroid Build Coastguard Worker        b3 = t == 5
566*da0073e9SAndroid Build Coastguard Worker        r3 = torch.sym_ite(b3, t, f)
567*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(shape_env.guards), 0)
568*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r3, 5)
569*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(type(t), type(r3))
570*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
571*da0073e9SAndroid Build Coastguard Worker            str(shape_env.guards[0][0]),
572*da0073e9SAndroid Build Coastguard Worker            """Eq(Piecewise((s0, Eq(s0, 5)), (s1, True)), 5)""",
573*da0073e9SAndroid Build Coastguard Worker        )
574*da0073e9SAndroid Build Coastguard Worker        b4 = f == 5
575*da0073e9SAndroid Build Coastguard Worker        r4 = torch.sym_ite(b4, t, f)
576*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(shape_env.guards), 1)
577*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r4, 4)
578*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(type(f), type(r4))
579*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
580*da0073e9SAndroid Build Coastguard Worker            str(shape_env.guards[1][0]),
581*da0073e9SAndroid Build Coastguard Worker            """Eq(Piecewise((s0, Eq(s1, 5)), (s1, True)), 4)""",
582*da0073e9SAndroid Build Coastguard Worker        )
583*da0073e9SAndroid Build Coastguard Worker
584*da0073e9SAndroid Build Coastguard Worker    def test_tracing_sym_ite(self):
585*da0073e9SAndroid Build Coastguard Worker        def f(x):
586*da0073e9SAndroid Build Coastguard Worker            b = x.shape[0] == 5
587*da0073e9SAndroid Build Coastguard Worker            ret = torch.sym_ite(b, x.shape[0], x.shape[1])
588*da0073e9SAndroid Build Coastguard Worker            return ret
589*da0073e9SAndroid Build Coastguard Worker
590*da0073e9SAndroid Build Coastguard Worker        gm = make_fx(f, tracing_mode="symbolic")(torch.ones(4, 5))
591*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(gm.shape_env.guards), 0)
592*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
593*da0073e9SAndroid Build Coastguard Worker            gm.code.strip(),
594*da0073e9SAndroid Build Coastguard Worker            """\
595*da0073e9SAndroid Build Coastguard Workerdef forward(self, x_1):
596*da0073e9SAndroid Build Coastguard Worker    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
597*da0073e9SAndroid Build Coastguard Worker    eq = sym_size_int == 5
598*da0073e9SAndroid Build Coastguard Worker    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1);  x_1 = None
599*da0073e9SAndroid Build Coastguard Worker    sym_ite = torch.sym_ite(eq, sym_size_int, sym_size_int_1);  eq = sym_size_int = sym_size_int_1 = None
600*da0073e9SAndroid Build Coastguard Worker    return sym_ite""",
601*da0073e9SAndroid Build Coastguard Worker        )
602*da0073e9SAndroid Build Coastguard Worker        r1 = gm(torch.ones(4, 5))
603*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(r1, int)
604*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r1, 5)
605*da0073e9SAndroid Build Coastguard Worker        r2 = gm(torch.ones(5, 4))
606*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(r2, int)
607*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r2, 5)
608*da0073e9SAndroid Build Coastguard Worker
609*da0073e9SAndroid Build Coastguard Worker    def test_int_conversion(self):
610*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
611*da0073e9SAndroid Build Coastguard Worker        a0 = create_symint(shape_env, 2)
612*da0073e9SAndroid Build Coastguard Worker        int(a0)
613*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")
614*da0073e9SAndroid Build Coastguard Worker
615*da0073e9SAndroid Build Coastguard Worker    def test_data_dependent_guard(self):
616*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
617*da0073e9SAndroid Build Coastguard Worker        s0 = shape_env.create_unbacked_symint()
618*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))
619*da0073e9SAndroid Build Coastguard Worker
620*da0073e9SAndroid Build Coastguard Worker    def test_data_dependent_guard_propagate_real_tensors(self):
621*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
622*da0073e9SAndroid Build Coastguard Worker        s0 = shape_env.create_unbacked_symint()
623*da0073e9SAndroid Build Coastguard Worker        shape_env.set_unbacked_var_to_val(s0.node.expr, 0)
624*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bool(s0 == 0), True)
625*da0073e9SAndroid Build Coastguard Worker
626*da0073e9SAndroid Build Coastguard Worker    def test_expect_true_basic(self):
627*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
628*da0073e9SAndroid Build Coastguard Worker        i0 = shape_env.create_unbacked_symint()
629*da0073e9SAndroid Build Coastguard Worker        i0_sym = i0.node.expr
630*da0073e9SAndroid Build Coastguard Worker        # This doesn't error
631*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(expect_true(i0 == 0))
632*da0073e9SAndroid Build Coastguard Worker        # This generates a deferred runtime assert via replacement
633*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shape_env.replacements[i0_sym], 0)
634*da0073e9SAndroid Build Coastguard Worker        # After expecting true, guards now resolve given the runtime assert
635*da0073e9SAndroid Build Coastguard Worker        bool(i0 == 0)
636*da0073e9SAndroid Build Coastguard Worker
637*da0073e9SAndroid Build Coastguard Worker    def test_expect_true_with_s0(self):
638*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
639*da0073e9SAndroid Build Coastguard Worker        s0 = create_symint(shape_env, 5)
640*da0073e9SAndroid Build Coastguard Worker        i0 = shape_env.create_unbacked_symint()
641*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(expect_true(i0 < s0))
642*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
643*da0073e9SAndroid Build Coastguard Worker            str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]),
644*da0073e9SAndroid Build Coastguard Worker            """[u0 < s0]""",
645*da0073e9SAndroid Build Coastguard Worker        )
646*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(i0 < s0)
647*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(i0 != s0)
648*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(i0 > s0)
649*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(i0 >= s0)
650*da0073e9SAndroid Build Coastguard Worker
651*da0073e9SAndroid Build Coastguard Worker    def test_expect_true_prefer_later(self):
652*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
653*da0073e9SAndroid Build Coastguard Worker        i0 = shape_env.create_unbacked_symint()
654*da0073e9SAndroid Build Coastguard Worker        i1 = shape_env.create_unbacked_symint()
655*da0073e9SAndroid Build Coastguard Worker        i1_sym = i1.node.expr
656*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(expect_true(i0 + i1 == 10))
657*da0073e9SAndroid Build Coastguard Worker        # Importantly, this is put in i1, not i0!
658*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
659*da0073e9SAndroid Build Coastguard Worker            str([ra.expr for ra in shape_env.deferred_runtime_asserts[i1_sym]]),
660*da0073e9SAndroid Build Coastguard Worker            """[Eq(u0 + u1, 10)]""",
661*da0073e9SAndroid Build Coastguard Worker        )
662*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(i0 + i1 == 10)
663*da0073e9SAndroid Build Coastguard Worker        # NB: We currently don't support deriving that we can substitute
664*da0073e9SAndroid Build Coastguard Worker        # i0 + i1 with 10; maybe we should, but this means our rewriting
665*da0073e9SAndroid Build Coastguard Worker        # system is no longer confluent (it's probably OK though, because
666*da0073e9SAndroid Build Coastguard Worker        # you're unlikely to get other equalities like this on the
667*da0073e9SAndroid Build Coastguard Worker        # unbacked SymInts.)
668*da0073e9SAndroid Build Coastguard Worker
669*da0073e9SAndroid Build Coastguard Worker    def test_unbacked_substitution(self):
670*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
671*da0073e9SAndroid Build Coastguard Worker        i0 = shape_env.create_unbacked_symint()
672*da0073e9SAndroid Build Coastguard Worker        i1 = shape_env.create_unbacked_symint()
673*da0073e9SAndroid Build Coastguard Worker        _constrain_range_for_size(i0)
674*da0073e9SAndroid Build Coastguard Worker        _constrain_range_for_size(i1)
675*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(expect_true(i0 == i1 * 4))
676*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(i0), """u0""")
677*da0073e9SAndroid Build Coastguard Worker
678*da0073e9SAndroid Build Coastguard Worker        i2 = shape_env.create_unbacked_symint()
679*da0073e9SAndroid Build Coastguard Worker        i3 = shape_env.create_unbacked_symint()
680*da0073e9SAndroid Build Coastguard Worker        _constrain_range_for_size(i2)
681*da0073e9SAndroid Build Coastguard Worker        _constrain_range_for_size(i3)
682*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(expect_true(i2 * 4 == i3))
683*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(i3), """u3""")
684*da0073e9SAndroid Build Coastguard Worker
685*da0073e9SAndroid Build Coastguard Worker    def test_avoid_unbacked_substitution(self):
686*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
687*da0073e9SAndroid Build Coastguard Worker        i0 = shape_env.create_unbacked_symint()
688*da0073e9SAndroid Build Coastguard Worker        _constrain_range_for_size(i0)
689*da0073e9SAndroid Build Coastguard Worker        i1 = shape_env.create_unbacked_symint()
690*da0073e9SAndroid Build Coastguard Worker        _constrain_range_for_size(i1)
691*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(expect_true(i0 == 10 - i1))
692*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(str(i0), """u0""")
693*da0073e9SAndroid Build Coastguard Worker
694*da0073e9SAndroid Build Coastguard Worker    def test_expect_true_double_digits(self):
695*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
696*da0073e9SAndroid Build Coastguard Worker        ia = [shape_env.create_unbacked_symint() for _ in range(11)]  # allocate 10
697*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(ia[-1]), "u10")
698*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(expect_true(sum(ia) == 20))
699*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(shape_env.deferred_runtime_asserts[ia[-1].node.expr]), 1)
700*da0073e9SAndroid Build Coastguard Worker
701*da0073e9SAndroid Build Coastguard Worker    def test_expect_true_refine_range(self):
702*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
703*da0073e9SAndroid Build Coastguard Worker        for i, rel in enumerate(
704*da0073e9SAndroid Build Coastguard Worker            [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
705*da0073e9SAndroid Build Coastguard Worker        ):
706*da0073e9SAndroid Build Coastguard Worker            with self.subTest(f"i = {i}"):
707*da0073e9SAndroid Build Coastguard Worker                i0 = shape_env.create_unbacked_symint()
708*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(expect_true(rel(i0)))
709*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 != 3))
710*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 != 4))
711*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(statically_known_true(i0 != 5))
712*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(statically_known_true(i0 != 6))
713*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 > 4))
714*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 >= 5))
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker        for i, rel in enumerate(
717*da0073e9SAndroid Build Coastguard Worker            [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
718*da0073e9SAndroid Build Coastguard Worker        ):
719*da0073e9SAndroid Build Coastguard Worker            with self.subTest(f"i = {i}"):
720*da0073e9SAndroid Build Coastguard Worker                i0 = shape_env.create_unbacked_symint()
721*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(expect_true(rel(i0)))
722*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(statically_known_true(i0 != 2))
723*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(statically_known_true(i0 != 3))
724*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 != 4))
725*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 != 5))
726*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 < 4))
727*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 <= 5))
728*da0073e9SAndroid Build Coastguard Worker
729*da0073e9SAndroid Build Coastguard Worker    def test_guard_refine_range(self):
730*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
731*da0073e9SAndroid Build Coastguard Worker        for i, rel in enumerate(
732*da0073e9SAndroid Build Coastguard Worker            [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
733*da0073e9SAndroid Build Coastguard Worker        ):
734*da0073e9SAndroid Build Coastguard Worker            with self.subTest(f"i = {i}"):
735*da0073e9SAndroid Build Coastguard Worker                i0 = create_symint(shape_env, 10, duck=False)
736*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(bool(rel(i0)))
737*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 != 3))
738*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 != 4))
739*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(statically_known_true(i0 != 5))
740*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(statically_known_true(i0 != 6))
741*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 > 4))
742*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 >= 5))
743*da0073e9SAndroid Build Coastguard Worker
744*da0073e9SAndroid Build Coastguard Worker        for i, rel in enumerate(
745*da0073e9SAndroid Build Coastguard Worker            [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
746*da0073e9SAndroid Build Coastguard Worker        ):
747*da0073e9SAndroid Build Coastguard Worker            with self.subTest(f"i = {i}"):
748*da0073e9SAndroid Build Coastguard Worker                i0 = create_symint(shape_env, 2, duck=False)
749*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(bool(rel(i0)))
750*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(statically_known_true(i0 != 3))
751*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(statically_known_true(i0 != 4))
752*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 != 5))
753*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 != 6))
754*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 <= 4))
755*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 < 5))
756*da0073e9SAndroid Build Coastguard Worker
757*da0073e9SAndroid Build Coastguard Worker        for i, rel in enumerate(
758*da0073e9SAndroid Build Coastguard Worker            [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
759*da0073e9SAndroid Build Coastguard Worker        ):
760*da0073e9SAndroid Build Coastguard Worker            with self.subTest(f"i = {i}"):
761*da0073e9SAndroid Build Coastguard Worker                i0 = create_symint(shape_env, 2, duck=False)
762*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(bool(rel(i0)))
763*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(statically_known_true(i0 != 2))
764*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(statically_known_true(i0 != 3))
765*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 != 4))
766*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 != 5))
767*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 < 4))
768*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 <= 3))
769*da0073e9SAndroid Build Coastguard Worker
770*da0073e9SAndroid Build Coastguard Worker        for i, rel in enumerate(
771*da0073e9SAndroid Build Coastguard Worker            [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
772*da0073e9SAndroid Build Coastguard Worker        ):
773*da0073e9SAndroid Build Coastguard Worker            with self.subTest(f"i = {i}"):
774*da0073e9SAndroid Build Coastguard Worker                i0 = create_symint(shape_env, 10, duck=False)
775*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(bool(rel(i0)))
776*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 != 2))
777*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 != 3))
778*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(statically_known_true(i0 != 4))
779*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(statically_known_true(i0 != 5))
780*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 >= 4))
781*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(statically_known_true(i0 > 3))
782*da0073e9SAndroid Build Coastguard Worker
783*da0073e9SAndroid Build Coastguard Worker    def test_mul_int_oo_nan(self):
784*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
785*da0073e9SAndroid Build Coastguard Worker        s0 = create_symint(shape_env, 5, duck=False)
786*da0073e9SAndroid Build Coastguard Worker        s1 = create_symint(shape_env, 6, duck=False)
787*da0073e9SAndroid Build Coastguard Worker        s2 = create_symint(shape_env, 5, duck=False)
788*da0073e9SAndroid Build Coastguard Worker        bool(s0 * (s1 // s0) == s2)
789*da0073e9SAndroid Build Coastguard Worker
790*da0073e9SAndroid Build Coastguard Worker    def test_non_overlapping_and_dense(self):
791*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
792*da0073e9SAndroid Build Coastguard Worker        a0 = create_symint(shape_env, 5)
793*da0073e9SAndroid Build Coastguard Worker        r = torch.empty_strided((a0, 7), (1, a0), device="meta")
794*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r))
795*da0073e9SAndroid Build Coastguard Worker
796*da0073e9SAndroid Build Coastguard Worker    def test_non_overlapping_and_dense_unbacked(self):
797*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
798*da0073e9SAndroid Build Coastguard Worker        u0 = shape_env.create_unbacked_symint()
799*da0073e9SAndroid Build Coastguard Worker        torch._check_is_size(u0)
800*da0073e9SAndroid Build Coastguard Worker        cf = torch.ops.aten.is_non_overlapping_and_dense.default
801*da0073e9SAndroid Build Coastguard Worker
802*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 2, 2, 1), 1)
803*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(IsNonOverlappingAndDenseIndicator(2, u0.node.expr, 1, 2), 1)
804*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")))
805*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta")))
806*da0073e9SAndroid Build Coastguard Worker
807*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 1), 1)
808*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(IsNonOverlappingAndDenseIndicator(1, u0.node.expr), 1)
809*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta")))
810*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(cf(torch.empty_strided((1,), (u0,), device="meta")))
811*da0073e9SAndroid Build Coastguard Worker
812*da0073e9SAndroid Build Coastguard Worker        Max = torch.sym_max
813*da0073e9SAndroid Build Coastguard Worker        # NB: This only works because we're able to determine this tensor is
814*da0073e9SAndroid Build Coastguard Worker        # contiguous. transpose(0, 1) makes it stop working
815*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
816*da0073e9SAndroid Build Coastguard Worker            cf(
817*da0073e9SAndroid Build Coastguard Worker                torch.empty_strided(
818*da0073e9SAndroid Build Coastguard Worker                    (2, 3, 1, u0),
819*da0073e9SAndroid Build Coastguard Worker                    (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
820*da0073e9SAndroid Build Coastguard Worker                    device="meta",
821*da0073e9SAndroid Build Coastguard Worker                )
822*da0073e9SAndroid Build Coastguard Worker            )
823*da0073e9SAndroid Build Coastguard Worker        )
824*da0073e9SAndroid Build Coastguard Worker
825*da0073e9SAndroid Build Coastguard Worker    def test_numpy_sym_max(self):
826*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.sym_max(np.int64(10), 12), 12)
827*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.sym_max(np.int64(12), 10), 12)
828*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.sym_max(np.int64(10), 12.5), 12.5)
829*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.sym_max(np.int64(14), 12.5), 14.0)
830*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.sym_max(np.float64(14.0), 12), 14.0)
831*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.sym_max(np.float64(14.0), 16), 16.0)
832*da0073e9SAndroid Build Coastguard Worker
833*da0073e9SAndroid Build Coastguard Worker    def test_numpy_sym_min(self):
834*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.sym_min(np.int64(10), 12), 10)
835*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.sym_min(np.int64(12), 10), 10)
836*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.sym_min(np.int64(10), 12.5), 10.0)
837*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.sym_min(np.int64(14), 12.5), 12.5)
838*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.sym_min(np.float64(14.0), 12), 12.0)
839*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.sym_min(np.float64(14.0), 16), 14.0)
840*da0073e9SAndroid Build Coastguard Worker
841*da0073e9SAndroid Build Coastguard Worker    def test_debug_has_internal_overlap_unbacked(self):
842*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
843*da0073e9SAndroid Build Coastguard Worker        u0 = shape_env.create_unbacked_symint()
844*da0073e9SAndroid Build Coastguard Worker        torch._check_is_size(u0)
845*da0073e9SAndroid Build Coastguard Worker        cf = torch._debug_has_internal_overlap
846*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0)
847*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0)
848*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cf(torch.empty_strided((u0,), (1,), device="meta")), 0)
849*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 0)
850*da0073e9SAndroid Build Coastguard Worker        Max = torch.sym_max
851*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
852*da0073e9SAndroid Build Coastguard Worker            cf(
853*da0073e9SAndroid Build Coastguard Worker                torch.empty_strided(
854*da0073e9SAndroid Build Coastguard Worker                    (2, 3, 1, u0),
855*da0073e9SAndroid Build Coastguard Worker                    (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
856*da0073e9SAndroid Build Coastguard Worker                    device="meta",
857*da0073e9SAndroid Build Coastguard Worker                )
858*da0073e9SAndroid Build Coastguard Worker            ),
859*da0073e9SAndroid Build Coastguard Worker            0,
860*da0073e9SAndroid Build Coastguard Worker        )
861*da0073e9SAndroid Build Coastguard Worker
862*da0073e9SAndroid Build Coastguard Worker        # Wobbling these to zero is OK too
863*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cf(torch.empty_strided((u0, 2), (3, 1), device="meta")), 2)
864*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cf(torch.empty_strided((2, u0), (1, 3), device="meta")), 2)
865*da0073e9SAndroid Build Coastguard Worker
866*da0073e9SAndroid Build Coastguard Worker    def test_specialize_zero_one(self):
867*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv(specialize_zero_one=True)
868*da0073e9SAndroid Build Coastguard Worker        a0 = create_symint(shape_env, 5)
869*da0073e9SAndroid Build Coastguard Worker        assert a0 != 1
870*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(shape_env.guards), 0)
871*da0073e9SAndroid Build Coastguard Worker
872*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv(specialize_zero_one=False)
873*da0073e9SAndroid Build Coastguard Worker        a0 = create_symint(shape_env, 5)
874*da0073e9SAndroid Build Coastguard Worker        assert a0 != 1
875*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(shape_env.guards), 1)
876*da0073e9SAndroid Build Coastguard Worker
877*da0073e9SAndroid Build Coastguard Worker    def test_duck_shape(self):
878*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv(duck_shape=True)
879*da0073e9SAndroid Build Coastguard Worker        a0 = create_symint(shape_env, 5)
880*da0073e9SAndroid Build Coastguard Worker        a1 = create_symint(shape_env, 5)
881*da0073e9SAndroid Build Coastguard Worker        assert a0 == a1
882*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(shape_env.guards), 0)
883*da0073e9SAndroid Build Coastguard Worker
884*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv(duck_shape=False)
885*da0073e9SAndroid Build Coastguard Worker        a0 = create_symint(shape_env, 5)
886*da0073e9SAndroid Build Coastguard Worker        a1 = create_symint(shape_env, 5)
887*da0073e9SAndroid Build Coastguard Worker        assert a0 == a1
888*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(shape_env.guards), 1)
889*da0073e9SAndroid Build Coastguard Worker
890*da0073e9SAndroid Build Coastguard Worker    def test_int_bool(self):
891*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/95981
892*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv(duck_shape=True)
893*da0073e9SAndroid Build Coastguard Worker        a0 = create_symint(shape_env, 5)
894*da0073e9SAndroid Build Coastguard Worker        assert a0
895*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(shape_env.guards), 0)
896*da0073e9SAndroid Build Coastguard Worker
897*da0073e9SAndroid Build Coastguard Worker    def test_symint_as_scalar(self):
898*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
899*da0073e9SAndroid Build Coastguard Worker        a0 = create_symint(shape_env, 2)
900*da0073e9SAndroid Build Coastguard Worker
901*da0073e9SAndroid Build Coastguard Worker        sym_int_encountered = False
902*da0073e9SAndroid Build Coastguard Worker
903*da0073e9SAndroid Build Coastguard Worker        class TestSymInt(TorchDispatchMode):
904*da0073e9SAndroid Build Coastguard Worker            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
905*da0073e9SAndroid Build Coastguard Worker                assert func == torch.ops.aten.add.Tensor
906*da0073e9SAndroid Build Coastguard Worker
907*da0073e9SAndroid Build Coastguard Worker                nonlocal sym_int_encountered
908*da0073e9SAndroid Build Coastguard Worker                # WARNING: do not do identity tests on the outer
909*da0073e9SAndroid Build Coastguard Worker                # SymInt/SymFloat, they are NOT STABLE
910*da0073e9SAndroid Build Coastguard Worker                sym_int_encountered = kwargs["alpha"].node is a0.node
911*da0073e9SAndroid Build Coastguard Worker                kwargs["alpha"] = 0
912*da0073e9SAndroid Build Coastguard Worker                return func(*args)
913*da0073e9SAndroid Build Coastguard Worker
914*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([4, 4])
915*da0073e9SAndroid Build Coastguard Worker        with TestSymInt():
916*da0073e9SAndroid Build Coastguard Worker            y = torch.add(x, x, alpha=a0)
917*da0073e9SAndroid Build Coastguard Worker
918*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(sym_int_encountered)
919*da0073e9SAndroid Build Coastguard Worker
920*da0073e9SAndroid Build Coastguard Worker    def test_deepcopy(self):
921*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
922*da0073e9SAndroid Build Coastguard Worker        a0 = create_symint(shape_env, 2)
923*da0073e9SAndroid Build Coastguard Worker        assert a0 < 4
924*da0073e9SAndroid Build Coastguard Worker        new_shape_env = copy.deepcopy(shape_env)
925*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(new_shape_env.guards), 1)
926*da0073e9SAndroid Build Coastguard Worker
    def test_print_readable_with_symints(self):
        """``print_readable`` on a symbolically traced graph annotates nodes with
        their symbolic types, e.g. ``"Sym(s0 + s2)"`` or ``"f32[s0 + s2, 2*s1]"``.

        NB: the expected text depends on the function name ``f``, its argument
        names ``a``/``b`` (printed as ``a_1``/``b_1``) and the exact sequence of
        ops in ``f``. Renaming or reordering them changes the printed graph.
        """
        def f(a, b):
            dim0 = a.shape[0] + b.shape[0]
            dim1 = a.shape[1] + b.shape[1]
            d = a.new_empty(dim0, dim1)
            # native_dropout yields a pair: output (f32) and mask (b8), which
            # show up as getitem / getitem_1 in the expected graph.
            d = torch.ops.aten.native_dropout(d, 0.5, train=True)
            return d

        # Inputs (5, 3) and (4, 3): both second dims are 3, and the trace gives
        # them the same symbol s1 (presumably via duck sizing), hence "2*s1".
        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
        out = fx_g.print_readable(print_output=False)

        self.assertExpectedInline(
            out.strip(),
            """\
class f(torch.nn.Module):
    def forward(self, a_1: "f32[s0, s1]", b_1: "f32[s2, s1]"):
        # No stacktrace found for following nodes
        sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(a_1, 0)
        sym_size_int_1: "Sym(s2)" = torch.ops.aten.sym_size.int(b_1, 0)
        add: "Sym(s0 + s2)" = sym_size_int + sym_size_int_1;  sym_size_int = sym_size_int_1 = None
        sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(a_1, 1)
        sym_size_int_3: "Sym(s1)" = torch.ops.aten.sym_size.int(b_1, 1);  b_1 = None
        add_1: "Sym(2*s1)" = sym_size_int_2 + sym_size_int_3;  sym_size_int_2 = sym_size_int_3 = None
        new_empty: "f32[s0 + s2, 2*s1]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False);  a_1 = add = add_1 = None
        native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True);  new_empty = None
        getitem: "f32[s0 + s2, 2*s1]" = native_dropout[0]
        getitem_1: "b8[s0 + s2, 2*s1]" = native_dropout[1];  native_dropout = None
        return (getitem, getitem_1)""",  # noqa: B950
        )
956*da0073e9SAndroid Build Coastguard Worker
957*da0073e9SAndroid Build Coastguard Worker    def test_statically_known_true(self):
958*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
959*da0073e9SAndroid Build Coastguard Worker        s2, s3, s4 = (create_symint(shape_env, i) for i in range(2, 5))
960*da0073e9SAndroid Build Coastguard Worker
961*da0073e9SAndroid Build Coastguard Worker        # Statically known true
962*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(statically_known_true(True))
963*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(statically_known_true(s2 == s2))
964*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(statically_known_true(s2 * s3 > s3))
965*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(statically_known_true(s3 * s4 > s4))
966*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(statically_known_true((s3 + s3) % 2 == 0))
967*da0073e9SAndroid Build Coastguard Worker
968*da0073e9SAndroid Build Coastguard Worker        # Statically known false
969*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(statically_known_true(False))
970*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(statically_known_true(s3 * s4 <= s4))
971*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(statically_known_true((s3 + s3) % 2 == 1))
972*da0073e9SAndroid Build Coastguard Worker
973*da0073e9SAndroid Build Coastguard Worker        # True for hints, but not known statically
974*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(statically_known_true(s2 + s2 == s4))
975*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(statically_known_true(s4 % s2 == 0))
976*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(statically_known_true(s2 != s3))
977*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(statically_known_true(s3 * s4 > s2))
978*da0073e9SAndroid Build Coastguard Worker
979*da0073e9SAndroid Build Coastguard Worker        # False for hints, but not known statically
980*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(statically_known_true(s2 == s3))
981*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(statically_known_true(s2 > s3))
982*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(statically_known_true(s3 + s3 == s4))
983*da0073e9SAndroid Build Coastguard Worker
984*da0073e9SAndroid Build Coastguard Worker        # No guards should be generated
985*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(shape_env.guards), 0)
986*da0073e9SAndroid Build Coastguard Worker
987*da0073e9SAndroid Build Coastguard Worker    def test_ephemeral_source_simplification(self):
988*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.source import EphemeralSource
989*da0073e9SAndroid Build Coastguard Worker
990*da0073e9SAndroid Build Coastguard Worker        # For full robustness, ensure the ephemeral source symbols are simplified out regardless
991*da0073e9SAndroid Build Coastguard Worker        # of construction order or check order.
992*da0073e9SAndroid Build Coastguard Worker        for construct_ephemeral_first, x_first_in_check in itertools.product(
993*da0073e9SAndroid Build Coastguard Worker            [False, True], [False, True]
994*da0073e9SAndroid Build Coastguard Worker        ):
995*da0073e9SAndroid Build Coastguard Worker            shape_env = ShapeEnv()
996*da0073e9SAndroid Build Coastguard Worker            shape = (5, 10)
997*da0073e9SAndroid Build Coastguard Worker            dynamic_dims = [DimDynamic.DYNAMIC for _ in shape]
998*da0073e9SAndroid Build Coastguard Worker            x = create_symbolic_tensor(
999*da0073e9SAndroid Build Coastguard Worker                "x",
1000*da0073e9SAndroid Build Coastguard Worker                torch.randn(*shape),
1001*da0073e9SAndroid Build Coastguard Worker                shape_env,
1002*da0073e9SAndroid Build Coastguard Worker                source=(EphemeralSource() if construct_ephemeral_first else None),
1003*da0073e9SAndroid Build Coastguard Worker                dynamic_dims=dynamic_dims,
1004*da0073e9SAndroid Build Coastguard Worker            )
1005*da0073e9SAndroid Build Coastguard Worker            y = create_symbolic_tensor(
1006*da0073e9SAndroid Build Coastguard Worker                "y",
1007*da0073e9SAndroid Build Coastguard Worker                torch.randn(*shape),
1008*da0073e9SAndroid Build Coastguard Worker                shape_env,
1009*da0073e9SAndroid Build Coastguard Worker                source=(EphemeralSource() if not construct_ephemeral_first else None),
1010*da0073e9SAndroid Build Coastguard Worker                dynamic_dims=dynamic_dims,
1011*da0073e9SAndroid Build Coastguard Worker            )
1012*da0073e9SAndroid Build Coastguard Worker            t_with_ephemeral = x if construct_ephemeral_first else y
1013*da0073e9SAndroid Build Coastguard Worker
1014*da0073e9SAndroid Build Coastguard Worker            def _get_ephemeral_source_symbols(t):
1015*da0073e9SAndroid Build Coastguard Worker                return [
1016*da0073e9SAndroid Build Coastguard Worker                    s.node.expr
1017*da0073e9SAndroid Build Coastguard Worker                    for s in itertools.chain(t.shape, t.stride(), (t.storage_offset(),))
1018*da0073e9SAndroid Build Coastguard Worker                    if isinstance(s, torch.SymInt)
1019*da0073e9SAndroid Build Coastguard Worker                    and s.node.expr in shape_env.var_to_sources
1020*da0073e9SAndroid Build Coastguard Worker                    and any(
1021*da0073e9SAndroid Build Coastguard Worker                        source.is_ephemeral()
1022*da0073e9SAndroid Build Coastguard Worker                        for source in shape_env.var_to_sources[s.node.expr]
1023*da0073e9SAndroid Build Coastguard Worker                    )
1024*da0073e9SAndroid Build Coastguard Worker                ]
1025*da0073e9SAndroid Build Coastguard Worker
1026*da0073e9SAndroid Build Coastguard Worker            # these checks should simplify out the ephemeral symbols, regardless of the
1027*da0073e9SAndroid Build Coastguard Worker            # ordering x == y or y == x
1028*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(len(_get_ephemeral_source_symbols(t_with_ephemeral)) > 0)
1029*da0073e9SAndroid Build Coastguard Worker            if x_first_in_check:
1030*da0073e9SAndroid Build Coastguard Worker                torch._check(x.size() == y.size())
1031*da0073e9SAndroid Build Coastguard Worker                torch._check(x.stride() == y.stride())
1032*da0073e9SAndroid Build Coastguard Worker                torch._check(x.storage_offset() == y.storage_offset())
1033*da0073e9SAndroid Build Coastguard Worker            else:
1034*da0073e9SAndroid Build Coastguard Worker                torch._check(y.size() == x.size())
1035*da0073e9SAndroid Build Coastguard Worker                torch._check(y.stride() == x.stride())
1036*da0073e9SAndroid Build Coastguard Worker                torch._check(y.storage_offset() == x.storage_offset())
1037*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(_get_ephemeral_source_symbols(t_with_ephemeral)), 0)
1038*da0073e9SAndroid Build Coastguard Worker
1039*da0073e9SAndroid Build Coastguard Worker    def test_ephemeral_source_unified_with_non_ephemeral_source(self):
1040*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.source import EphemeralSource
1041*da0073e9SAndroid Build Coastguard Worker
1042*da0073e9SAndroid Build Coastguard Worker        for construct_ephemeral_first in (False, True):
1043*da0073e9SAndroid Build Coastguard Worker            shape_env = ShapeEnv()
1044*da0073e9SAndroid Build Coastguard Worker            shape = (5, 10)
1045*da0073e9SAndroid Build Coastguard Worker            # use duck sizing here to ensure symbol reuse across x and y
1046*da0073e9SAndroid Build Coastguard Worker            duck_dims = [DimDynamic.DUCK for _ in shape]
1047*da0073e9SAndroid Build Coastguard Worker            x = create_symbolic_tensor(
1048*da0073e9SAndroid Build Coastguard Worker                "x",
1049*da0073e9SAndroid Build Coastguard Worker                torch.randn(*shape),
1050*da0073e9SAndroid Build Coastguard Worker                shape_env,
1051*da0073e9SAndroid Build Coastguard Worker                source=(EphemeralSource() if construct_ephemeral_first else None),
1052*da0073e9SAndroid Build Coastguard Worker                dynamic_dims=duck_dims,
1053*da0073e9SAndroid Build Coastguard Worker            )
1054*da0073e9SAndroid Build Coastguard Worker            y = create_symbolic_tensor(
1055*da0073e9SAndroid Build Coastguard Worker                "y",
1056*da0073e9SAndroid Build Coastguard Worker                torch.randn(*shape),
1057*da0073e9SAndroid Build Coastguard Worker                shape_env,
1058*da0073e9SAndroid Build Coastguard Worker                source=(EphemeralSource() if not construct_ephemeral_first else None),
1059*da0073e9SAndroid Build Coastguard Worker                dynamic_dims=duck_dims,
1060*da0073e9SAndroid Build Coastguard Worker            )
1061*da0073e9SAndroid Build Coastguard Worker
1062*da0073e9SAndroid Build Coastguard Worker            # regardless of construction order, non-ephemeral sources should be preferred
1063*da0073e9SAndroid Build Coastguard Worker            # first in the var_to_sources list for potential guarding later on
1064*da0073e9SAndroid Build Coastguard Worker            for source_list in shape_env.var_to_sources.values():
1065*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(source_list[0].is_ephemeral())
1066*da0073e9SAndroid Build Coastguard Worker
1067*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.size(), y.size())
1068*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.stride(), y.stride())
1069*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.storage_offset(), y.storage_offset())
1070*da0073e9SAndroid Build Coastguard Worker
1071*da0073e9SAndroid Build Coastguard Worker
1072*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo(
1073*da0073e9SAndroid Build Coastguard Worker    "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
1074*da0073e9SAndroid Build Coastguard Worker)
1075*da0073e9SAndroid Build Coastguard Workerclass TestSymNumberMagicMethods(TestCase):
1076*da0073e9SAndroid Build Coastguard Worker    def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
1077*da0073e9SAndroid Build Coastguard Worker        with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn):
1078*da0073e9SAndroid Build Coastguard Worker            return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn)
1079*da0073e9SAndroid Build Coastguard Worker
1080*da0073e9SAndroid Build Coastguard Worker    def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn):
1081*da0073e9SAndroid Build Coastguard Worker        # Helper function
1082*da0073e9SAndroid Build Coastguard Worker        # NB: don't use one as that will get specialized
1083*da0073e9SAndroid Build Coastguard Worker        # TODO: We don't have to circuitously create the float, can just
1084*da0073e9SAndroid Build Coastguard Worker        # create a symfloat directly
1085*da0073e9SAndroid Build Coastguard Worker        seed_node = (create_symint(shape_env, 2) / 2.0).node
1086*da0073e9SAndroid Build Coastguard Worker        bool_seed_node = (create_symint(shape_env, 2) == 2).node
1087*da0073e9SAndroid Build Coastguard Worker
1088*da0073e9SAndroid Build Coastguard Worker        def get_sym_inp(inp):
1089*da0073e9SAndroid Build Coastguard Worker            # NB: this must come before int
1090*da0073e9SAndroid Build Coastguard Worker            if isinstance(inp, bool):
1091*da0073e9SAndroid Build Coastguard Worker                return torch.SymBool(to_node(bool_seed_node, inp))
1092*da0073e9SAndroid Build Coastguard Worker            elif isinstance(inp, int):
1093*da0073e9SAndroid Build Coastguard Worker                return torch.SymInt(to_node(seed_node, inp))
1094*da0073e9SAndroid Build Coastguard Worker            else:
1095*da0073e9SAndroid Build Coastguard Worker                return torch.SymFloat(to_node(seed_node, inp))
1096*da0073e9SAndroid Build Coastguard Worker
1097*da0073e9SAndroid Build Coastguard Worker        if fn == "float_pow":
1098*da0073e9SAndroid Build Coastguard Worker            if inp1 < 0:
1099*da0073e9SAndroid Build Coastguard Worker                return
1100*da0073e9SAndroid Build Coastguard Worker
1101*da0073e9SAndroid Build Coastguard Worker        if fn == "pow_by_natural":
1102*da0073e9SAndroid Build Coastguard Worker            if isinstance(inp1, float) or isinstance(inp2, float):
1103*da0073e9SAndroid Build Coastguard Worker                return
1104*da0073e9SAndroid Build Coastguard Worker            if inp2 < 0:
1105*da0073e9SAndroid Build Coastguard Worker                return
1106*da0073e9SAndroid Build Coastguard Worker
1107*da0073e9SAndroid Build Coastguard Worker        def maybe_xfail(inp1, inp2):
1108*da0073e9SAndroid Build Coastguard Worker            if fn == "sym_sqrt" and inp1 < 0:
1109*da0073e9SAndroid Build Coastguard Worker                # ValueError: math domain error
1110*da0073e9SAndroid Build Coastguard Worker                return self.assertRaises((ValueError,))
1111*da0073e9SAndroid Build Coastguard Worker            elif (
1112*da0073e9SAndroid Build Coastguard Worker                fn in ("float_truediv", "int_truediv", "int_floordiv", "mod")
1113*da0073e9SAndroid Build Coastguard Worker                and inp2 == 0
1114*da0073e9SAndroid Build Coastguard Worker            ):
1115*da0073e9SAndroid Build Coastguard Worker                # ZeroDivisionError: division by zero
1116*da0073e9SAndroid Build Coastguard Worker                return self.assertRaises((ZeroDivisionError,))
1117*da0073e9SAndroid Build Coastguard Worker            elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0:
1118*da0073e9SAndroid Build Coastguard Worker                # ZeroDivisionError: 0.0 cannot be raised to a negative power
1119*da0073e9SAndroid Build Coastguard Worker                return self.assertRaises((ZeroDivisionError,))
1120*da0073e9SAndroid Build Coastguard Worker            elif (
1121*da0073e9SAndroid Build Coastguard Worker                # TODO: dear catastrophe waitress,
1122*da0073e9SAndroid Build Coastguard Worker                # this doesn't work
1123*da0073e9SAndroid Build Coastguard Worker                fn in ["float_pow", "pow_by_natural"]
1124*da0073e9SAndroid Build Coastguard Worker                and inp1 < 0
1125*da0073e9SAndroid Build Coastguard Worker                and (
1126*da0073e9SAndroid Build Coastguard Worker                    type(inp1) is (SymInt, SymFloat) or type(inp2) is (SymInt, SymFloat)
1127*da0073e9SAndroid Build Coastguard Worker                )
1128*da0073e9SAndroid Build Coastguard Worker                and (type(inp1) is (SymFloat, float) or type(inp2) is (SymFloat, float))
1129*da0073e9SAndroid Build Coastguard Worker            ):
1130*da0073e9SAndroid Build Coastguard Worker                # Complex result, which we do not support:
1131*da0073e9SAndroid Build Coastguard Worker                # TypeError: Cannot convert complex to float
1132*da0073e9SAndroid Build Coastguard Worker                return self.assertRaises((RuntimeError,))
1133*da0073e9SAndroid Build Coastguard Worker            elif fn in ("lshift", "rshift") and not (
1134*da0073e9SAndroid Build Coastguard Worker                isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int))
1135*da0073e9SAndroid Build Coastguard Worker            ):
1136*da0073e9SAndroid Build Coastguard Worker                # TypeError: unsupported operand type(s)
1137*da0073e9SAndroid Build Coastguard Worker                return self.assertRaises((TypeError,))
1138*da0073e9SAndroid Build Coastguard Worker            elif fn in ("lshift", "rshift") and inp2 < 0:
1139*da0073e9SAndroid Build Coastguard Worker                # ValueError: math domain error
1140*da0073e9SAndroid Build Coastguard Worker                return self.assertRaises((ValueError,))
1141*da0073e9SAndroid Build Coastguard Worker            else:
1142*da0073e9SAndroid Build Coastguard Worker                return contextlib.nullcontext()
1143*da0073e9SAndroid Build Coastguard Worker
1144*da0073e9SAndroid Build Coastguard Worker        lambda_apply = method_to_operator(fn)
1145*da0073e9SAndroid Build Coastguard Worker
1146*da0073e9SAndroid Build Coastguard Worker        def guard_fn(v):
1147*da0073e9SAndroid Build Coastguard Worker            if type(v) in (SymBool, bool):
1148*da0073e9SAndroid Build Coastguard Worker                return guard_bool(v)
1149*da0073e9SAndroid Build Coastguard Worker            elif type(v) in (SymFloat, float):
1150*da0073e9SAndroid Build Coastguard Worker                return guard_float(v)
1151*da0073e9SAndroid Build Coastguard Worker            else:  # SymInt, int
1152*da0073e9SAndroid Build Coastguard Worker                return guard_int(v)
1153*da0073e9SAndroid Build Coastguard Worker
1154*da0073e9SAndroid Build Coastguard Worker        # Get reference result
1155*da0073e9SAndroid Build Coastguard Worker        with maybe_xfail(inp1, inp2):
1156*da0073e9SAndroid Build Coastguard Worker            if is_unary_fn:
1157*da0073e9SAndroid Build Coastguard Worker                ref_out = lambda_apply(inp1)
1158*da0073e9SAndroid Build Coastguard Worker            else:
1159*da0073e9SAndroid Build Coastguard Worker                ref_out = lambda_apply(inp1, inp2)
1160*da0073e9SAndroid Build Coastguard Worker
1161*da0073e9SAndroid Build Coastguard Worker        # Symified first arg
1162*da0073e9SAndroid Build Coastguard Worker        sym_inp1 = get_sym_inp(inp1)
1163*da0073e9SAndroid Build Coastguard Worker        with maybe_xfail(sym_inp1, inp2):
1164*da0073e9SAndroid Build Coastguard Worker            if is_unary_fn:
1165*da0073e9SAndroid Build Coastguard Worker                out = lambda_apply(sym_inp1)
1166*da0073e9SAndroid Build Coastguard Worker            else:
1167*da0073e9SAndroid Build Coastguard Worker                out = lambda_apply(sym_inp1, inp2)
1168*da0073e9SAndroid Build Coastguard Worker            if fn not in sym_node.alternate_impl_if_hinted_methods:
1169*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
1170*da0073e9SAndroid Build Coastguard Worker            out = guard_fn(out)
1171*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, ref_out)
1172*da0073e9SAndroid Build Coastguard Worker
1173*da0073e9SAndroid Build Coastguard Worker        if is_unary_fn:
1174*da0073e9SAndroid Build Coastguard Worker            return
1175*da0073e9SAndroid Build Coastguard Worker
1176*da0073e9SAndroid Build Coastguard Worker        # Symified second arg
1177*da0073e9SAndroid Build Coastguard Worker        sym_inp2 = get_sym_inp(inp2)
1178*da0073e9SAndroid Build Coastguard Worker        with maybe_xfail(inp1, sym_inp2):
1179*da0073e9SAndroid Build Coastguard Worker            out = lambda_apply(inp1, sym_inp2)
1180*da0073e9SAndroid Build Coastguard Worker            if fn not in sym_node.alternate_impl_if_hinted_methods:
1181*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
1182*da0073e9SAndroid Build Coastguard Worker            out = guard_fn(out)
1183*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, ref_out)
1184*da0073e9SAndroid Build Coastguard Worker
1185*da0073e9SAndroid Build Coastguard Worker        # Symified both args
1186*da0073e9SAndroid Build Coastguard Worker        with maybe_xfail(sym_inp1, sym_inp2):
1187*da0073e9SAndroid Build Coastguard Worker            out = lambda_apply(sym_inp1, sym_inp2)
1188*da0073e9SAndroid Build Coastguard Worker            if fn not in sym_node.alternate_impl_if_hinted_methods:
1189*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
1190*da0073e9SAndroid Build Coastguard Worker            out = guard_fn(out)
1191*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, ref_out)
1192*da0073e9SAndroid Build Coastguard Worker
1193*da0073e9SAndroid Build Coastguard Worker    @parametrize("fn", list(sym_node.magic_methods.keys()))
1194*da0073e9SAndroid Build Coastguard Worker    def test_bool_method(self, fn):
1195*da0073e9SAndroid Build Coastguard Worker        # sym_ite has its own tests
1196*da0073e9SAndroid Build Coastguard Worker        if fn not in sym_node.bool_magic_methods or fn == "sym_ite":
1197*da0073e9SAndroid Build Coastguard Worker            self.skipTest(f"{fn} is non-bool")
1198*da0073e9SAndroid Build Coastguard Worker
1199*da0073e9SAndroid Build Coastguard Worker        is_unary_fn = fn in sym_node.unary_methods
1200*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
1201*da0073e9SAndroid Build Coastguard Worker        self._do_test(fn, True, False, shape_env, is_unary_fn)
1202*da0073e9SAndroid Build Coastguard Worker
1203*da0073e9SAndroid Build Coastguard Worker    @parametrize("fn", list(sym_node.magic_methods.keys()))
1204*da0073e9SAndroid Build Coastguard Worker    @parametrize("first_type", ["int", "float"])
1205*da0073e9SAndroid Build Coastguard Worker    @parametrize("second_type", ["int", "float"])
1206*da0073e9SAndroid Build Coastguard Worker    def test_method(self, fn, first_type, second_type):
1207*da0073e9SAndroid Build Coastguard Worker        if first_type == "float":
1208*da0073e9SAndroid Build Coastguard Worker            # TODO: Hmm, this looks like we skip all floats
1209*da0073e9SAndroid Build Coastguard Worker            self.skipTest(f"{fn} is not a float magic method")
1210*da0073e9SAndroid Build Coastguard Worker
1211*da0073e9SAndroid Build Coastguard Worker        if (
1212*da0073e9SAndroid Build Coastguard Worker            first_type == "int" or second_type == "int"
1213*da0073e9SAndroid Build Coastguard Worker        ) and fn in sym_node.only_float_magic_methods:
1214*da0073e9SAndroid Build Coastguard Worker            self.skipTest(f"{fn} is not an int method")
1215*da0073e9SAndroid Build Coastguard Worker
1216*da0073e9SAndroid Build Coastguard Worker        if second_type == "float" and fn in ["mod"]:
1217*da0073e9SAndroid Build Coastguard Worker            self.skipTest(f"{fn} only handles int")
1218*da0073e9SAndroid Build Coastguard Worker
1219*da0073e9SAndroid Build Coastguard Worker        is_unary_fn = fn in sym_node.unary_methods or fn == "round"
1220*da0073e9SAndroid Build Coastguard Worker        # Second argument is ignored for unary function. So only run for one type
1221*da0073e9SAndroid Build Coastguard Worker        if is_unary_fn and second_type == "float":
1222*da0073e9SAndroid Build Coastguard Worker            self.skipTest(f"{fn} is unary and already tested")
1223*da0073e9SAndroid Build Coastguard Worker
1224*da0073e9SAndroid Build Coastguard Worker        if fn in sym_node.bool_magic_methods:
1225*da0073e9SAndroid Build Coastguard Worker            self.skipTest(f"{fn} is bool")
1226*da0073e9SAndroid Build Coastguard Worker
1227*da0073e9SAndroid Build Coastguard Worker        # Only floats here since these will be converted to int if necessary.
1228*da0073e9SAndroid Build Coastguard Worker        # We also ignore complex and bool.
1229*da0073e9SAndroid Build Coastguard Worker        values = (
1230*da0073e9SAndroid Build Coastguard Worker            0.0,
1231*da0073e9SAndroid Build Coastguard Worker            1.0,
1232*da0073e9SAndroid Build Coastguard Worker            0.5 if fn in ("sym_acos", "sym_asin") else 2.5,  # avoid math domain error
1233*da0073e9SAndroid Build Coastguard Worker        )
1234*da0073e9SAndroid Build Coastguard Worker
1235*da0073e9SAndroid Build Coastguard Worker        neg_values = tuple(-x for x in values)
1236*da0073e9SAndroid Build Coastguard Worker
1237*da0073e9SAndroid Build Coastguard Worker        for inp1, inp2 in itertools.chain(
1238*da0073e9SAndroid Build Coastguard Worker            itertools.product(values, values),
1239*da0073e9SAndroid Build Coastguard Worker            itertools.product(values, neg_values),
1240*da0073e9SAndroid Build Coastguard Worker            itertools.product(neg_values, values),
1241*da0073e9SAndroid Build Coastguard Worker            itertools.product(neg_values, neg_values),
1242*da0073e9SAndroid Build Coastguard Worker        ):
1243*da0073e9SAndroid Build Coastguard Worker            if first_type == "int":
1244*da0073e9SAndroid Build Coastguard Worker                inp1 = int(inp1)
1245*da0073e9SAndroid Build Coastguard Worker            if second_type == "int":
1246*da0073e9SAndroid Build Coastguard Worker                inp2 = int(inp2)
1247*da0073e9SAndroid Build Coastguard Worker
1248*da0073e9SAndroid Build Coastguard Worker            shape_env = ShapeEnv()
1249*da0073e9SAndroid Build Coastguard Worker
1250*da0073e9SAndroid Build Coastguard Worker            self._do_test(fn, inp1, inp2, shape_env, is_unary_fn)
1251*da0073e9SAndroid Build Coastguard Worker
1252*da0073e9SAndroid Build Coastguard Worker    def get_constant_bool(self, val):
1253*da0073e9SAndroid Build Coastguard Worker        return SymBool(torch._C._get_constant_bool_symnode(val))
1254*da0073e9SAndroid Build Coastguard Worker
1255*da0073e9SAndroid Build Coastguard Worker    @unittest.expectedFailure
1256*da0073e9SAndroid Build Coastguard Worker    def test_symint_hashing(self):
1257*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
1258*da0073e9SAndroid Build Coastguard Worker        hash(create_symint(shape_env, 3))
1259*da0073e9SAndroid Build Coastguard Worker
1260*da0073e9SAndroid Build Coastguard Worker    def test_symnode_hashing(self):
1261*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
1262*da0073e9SAndroid Build Coastguard Worker
1263*da0073e9SAndroid Build Coastguard Worker        # These all trigger specialization when hashed
1264*da0073e9SAndroid Build Coastguard Worker        hash(create_symbool(shape_env, True))
1265*da0073e9SAndroid Build Coastguard Worker        # We should be passing in float here, but create_symbol currently
1266*da0073e9SAndroid Build Coastguard Worker        # only supports int
1267*da0073e9SAndroid Build Coastguard Worker        hash(create_symfloat(shape_env, 3.0))
1268*da0073e9SAndroid Build Coastguard Worker
1269*da0073e9SAndroid Build Coastguard Worker        # NestedInt (SymInt), constant SymBool, SymNode are hashable
1270*da0073e9SAndroid Build Coastguard Worker        j1 = torch._C._get_nested_int(1, 1)
1271*da0073e9SAndroid Build Coastguard Worker        j1_copy = torch._C._get_nested_int(1, 1)
1272*da0073e9SAndroid Build Coastguard Worker        j2 = torch._C._get_nested_int(2, 1)
1273*da0073e9SAndroid Build Coastguard Worker        t = self.get_constant_bool(True)
1274*da0073e9SAndroid Build Coastguard Worker        t_copy = self.get_constant_bool(True)
1275*da0073e9SAndroid Build Coastguard Worker        f = self.get_constant_bool(False)
1276*da0073e9SAndroid Build Coastguard Worker        n = create_symint(shape_env, 3).node
1277*da0073e9SAndroid Build Coastguard Worker        m = self.get_constant_bool(True).node
1278*da0073e9SAndroid Build Coastguard Worker
1279*da0073e9SAndroid Build Coastguard Worker        self.assertIs(j1 == j1_copy, True)
1280*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(hash(j1), hash(j1_copy))
1281*da0073e9SAndroid Build Coastguard Worker        self.assertIs(j1 == j2, False)
1282*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(hash(j1), hash(j2))
1283*da0073e9SAndroid Build Coastguard Worker        self.assertIs(t == t_copy, True)
1284*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(hash(t), hash(t_copy))
1285*da0073e9SAndroid Build Coastguard Worker        self.assertIs(t == f, False)
1286*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(hash(t), hash(f))
1287*da0073e9SAndroid Build Coastguard Worker
1288*da0073e9SAndroid Build Coastguard Worker        hash(n)
1289*da0073e9SAndroid Build Coastguard Worker        hash(m)
1290*da0073e9SAndroid Build Coastguard Worker
1291*da0073e9SAndroid Build Coastguard Worker    def test_symint_deepcopy(self):
1292*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
1293*da0073e9SAndroid Build Coastguard Worker
1294*da0073e9SAndroid Build Coastguard Worker        symnodes = (torch._C._get_nested_int(1, 1),)
1295*da0073e9SAndroid Build Coastguard Worker        deepcopied_symnodes = copy.deepcopy(symnodes)
1296*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(symnodes, deepcopied_symnodes)
1297*da0073e9SAndroid Build Coastguard Worker
1298*da0073e9SAndroid Build Coastguard Worker    def test_non_symbolic_symnode(self):
1299*da0073e9SAndroid Build Coastguard Worker        j1 = torch._C._get_nested_int(1, 1)
1300*da0073e9SAndroid Build Coastguard Worker        j2 = torch._C._get_nested_int(1, 1)
1301*da0073e9SAndroid Build Coastguard Worker        j3 = torch._C._get_nested_int(3, 1)
1302*da0073e9SAndroid Build Coastguard Worker
1303*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(j1, torch.SymInt)
1304*da0073e9SAndroid Build Coastguard Worker        self.assertNotIsInstance(j1, int)
1305*da0073e9SAndroid Build Coastguard Worker
1306*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1307*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "add not supported by NestedIntSymNode"
1308*da0073e9SAndroid Build Coastguard Worker        ):
1309*da0073e9SAndroid Build Coastguard Worker            j1 + 3
1310*da0073e9SAndroid Build Coastguard Worker
1311*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(j1 == 3)
1312*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "indeterminate"):
1313*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(3 >= j2)
1314*da0073e9SAndroid Build Coastguard Worker
1315*da0073e9SAndroid Build Coastguard Worker        self.assertIs(j1 == j1, True)
1316*da0073e9SAndroid Build Coastguard Worker        self.assertIs(j1 == j2, True)
1317*da0073e9SAndroid Build Coastguard Worker        self.assertIs(j1 == j3, False)
1318*da0073e9SAndroid Build Coastguard Worker        self.assertIs(j1 != j3, True)
1319*da0073e9SAndroid Build Coastguard Worker        self.assertIs(j1 != j2, False)
1320*da0073e9SAndroid Build Coastguard Worker
1321*da0073e9SAndroid Build Coastguard Worker        x = self.get_constant_bool(True)
1322*da0073e9SAndroid Build Coastguard Worker        #
1323*da0073e9SAndroid Build Coastguard Worker        # Unary
1324*da0073e9SAndroid Build Coastguard Worker        #
1325*da0073e9SAndroid Build Coastguard Worker        # op(constant SymBool)
1326*da0073e9SAndroid Build Coastguard Worker        self.assertIs(x.__sym_not__(), False)
1327*da0073e9SAndroid Build Coastguard Worker
1328*da0073e9SAndroid Build Coastguard Worker        #
1329*da0073e9SAndroid Build Coastguard Worker        # Binary
1330*da0073e9SAndroid Build Coastguard Worker        #
1331*da0073e9SAndroid Build Coastguard Worker        # op(constant SymBool, bool)
1332*da0073e9SAndroid Build Coastguard Worker        # op(constant SymBool, constant SymBool)
1333*da0073e9SAndroid Build Coastguard Worker        # op(bool, constant SymBool)
1334*da0073e9SAndroid Build Coastguard Worker        self.assertIs(operator.and_(x, True), True)
1335*da0073e9SAndroid Build Coastguard Worker        self.assertIs(operator.and_(x, x), True)
1336*da0073e9SAndroid Build Coastguard Worker        self.assertIs(operator.and_(True, x), True)
1337*da0073e9SAndroid Build Coastguard Worker
1338*da0073e9SAndroid Build Coastguard Worker        # op(symbolic SymBool, constant Symbool)
1339*da0073e9SAndroid Build Coastguard Worker        # op(constant SymBool, symbolic Symbool)
1340*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
1341*da0073e9SAndroid Build Coastguard Worker        a = create_symint(shape_env, 2)
1342*da0073e9SAndroid Build Coastguard Worker        b = create_symint(shape_env, 2)
1343*da0073e9SAndroid Build Coastguard Worker        c = a == b  # symbolic SymBool
1344*da0073e9SAndroid Build Coastguard Worker        d = self.get_constant_bool(True)
1345*da0073e9SAndroid Build Coastguard Worker        e = operator.and_(c, d)
1346*da0073e9SAndroid Build Coastguard Worker        f = operator.and_(d, c)
1347*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(is_symbolic(e))
1348*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(is_symbolic(f))
1349*da0073e9SAndroid Build Coastguard Worker        self.assertIs(e.node.guard_bool("", 0), True)
1350*da0073e9SAndroid Build Coastguard Worker        self.assertIs(f.node.guard_bool("", 0), True)
1351*da0073e9SAndroid Build Coastguard Worker
1352*da0073e9SAndroid Build Coastguard Worker        # Comparing sizes
1353*da0073e9SAndroid Build Coastguard Worker        sz1 = torch.Size([j1, j1, j1])
1354*da0073e9SAndroid Build Coastguard Worker        sz2 = torch.Size([j1, j1, j1])
1355*da0073e9SAndroid Build Coastguard Worker        self.assertIs(sz1 == sz2, True)
1356*da0073e9SAndroid Build Coastguard Worker
1357*da0073e9SAndroid Build Coastguard Worker        sz1 = torch.Size([3, j1, 4])
1358*da0073e9SAndroid Build Coastguard Worker        sz2 = torch.Size([3, j2, 4])
1359*da0073e9SAndroid Build Coastguard Worker        self.assertIs(sz1 == sz2, True)
1360*da0073e9SAndroid Build Coastguard Worker        self.assertIs(sz1 != sz2, False)
1361*da0073e9SAndroid Build Coastguard Worker
1362*da0073e9SAndroid Build Coastguard Worker    def test_stride_symnode(self):
1363*da0073e9SAndroid Build Coastguard Worker        from torch._subclasses.fake_tensor import FakeTensorMode
1364*da0073e9SAndroid Build Coastguard Worker
1365*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
1366*da0073e9SAndroid Build Coastguard Worker
1367*da0073e9SAndroid Build Coastguard Worker        def _create_symbolic_tensor(x, dynamic_sizes, dynamic_strides):
1368*da0073e9SAndroid Build Coastguard Worker            with FakeTensorMode(shape_env=shape_env) as fake_mode:
1369*da0073e9SAndroid Build Coastguard Worker                return fake_mode.from_tensor(
1370*da0073e9SAndroid Build Coastguard Worker                    x,
1371*da0073e9SAndroid Build Coastguard Worker                    symbolic_context=StatelessSymbolicContext(
1372*da0073e9SAndroid Build Coastguard Worker                        dynamic_sizes=dynamic_sizes,
1373*da0073e9SAndroid Build Coastguard Worker                        dynamic_strides=dynamic_strides,
1374*da0073e9SAndroid Build Coastguard Worker                    ),
1375*da0073e9SAndroid Build Coastguard Worker                )
1376*da0073e9SAndroid Build Coastguard Worker
1377*da0073e9SAndroid Build Coastguard Worker        # check everything static
1378*da0073e9SAndroid Build Coastguard Worker        t = _create_symbolic_tensor(
1379*da0073e9SAndroid Build Coastguard Worker            x=torch.ones(3, 6),
1380*da0073e9SAndroid Build Coastguard Worker            dynamic_sizes=[
1381*da0073e9SAndroid Build Coastguard Worker                DimDynamic.STATIC,
1382*da0073e9SAndroid Build Coastguard Worker                DimDynamic.STATIC,
1383*da0073e9SAndroid Build Coastguard Worker            ],
1384*da0073e9SAndroid Build Coastguard Worker            dynamic_strides=[
1385*da0073e9SAndroid Build Coastguard Worker                DimDynamic.INFER_STRIDE,
1386*da0073e9SAndroid Build Coastguard Worker                DimDynamic.INFER_STRIDE,
1387*da0073e9SAndroid Build Coastguard Worker            ],
1388*da0073e9SAndroid Build Coastguard Worker        )
1389*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(isinstance(size, int) for size in t.size()))
1390*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(isinstance(stride, int) for stride in t.stride()))
1391*da0073e9SAndroid Build Coastguard Worker
1392*da0073e9SAndroid Build Coastguard Worker        # check dynamic size but static dims
1393*da0073e9SAndroid Build Coastguard Worker        t = _create_symbolic_tensor(
1394*da0073e9SAndroid Build Coastguard Worker            x=torch.ones(3, 6),
1395*da0073e9SAndroid Build Coastguard Worker            dynamic_sizes=[
1396*da0073e9SAndroid Build Coastguard Worker                DimDynamic.DYNAMIC,
1397*da0073e9SAndroid Build Coastguard Worker                DimDynamic.DYNAMIC,
1398*da0073e9SAndroid Build Coastguard Worker            ],
1399*da0073e9SAndroid Build Coastguard Worker            dynamic_strides=[
1400*da0073e9SAndroid Build Coastguard Worker                DimDynamic.INFER_STRIDE,
1401*da0073e9SAndroid Build Coastguard Worker                DimDynamic.INFER_STRIDE,
1402*da0073e9SAndroid Build Coastguard Worker            ],
1403*da0073e9SAndroid Build Coastguard Worker        )
1404*da0073e9SAndroid Build Coastguard Worker        # Expect stride to be inferred
1405*da0073e9SAndroid Build Coastguard Worker        s0, s1 = t.size()
1406*da0073e9SAndroid Build Coastguard Worker        s2, s3 = t.stride()
1407*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s0, torch.SymInt))
1408*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s1, torch.SymInt))
1409*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s2, torch.SymInt))
1410*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(s1 == s2)
1411*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s3, 1)
1412*da0073e9SAndroid Build Coastguard Worker
1413*da0073e9SAndroid Build Coastguard Worker        # Check dynamic stride but static dims
1414*da0073e9SAndroid Build Coastguard Worker        t = _create_symbolic_tensor(
1415*da0073e9SAndroid Build Coastguard Worker            x=torch.ones(3, 6),
1416*da0073e9SAndroid Build Coastguard Worker            dynamic_sizes=[
1417*da0073e9SAndroid Build Coastguard Worker                DimDynamic.STATIC,
1418*da0073e9SAndroid Build Coastguard Worker                DimDynamic.STATIC,
1419*da0073e9SAndroid Build Coastguard Worker            ],
1420*da0073e9SAndroid Build Coastguard Worker            dynamic_strides=[
1421*da0073e9SAndroid Build Coastguard Worker                DimDynamic.DYNAMIC,
1422*da0073e9SAndroid Build Coastguard Worker                DimDynamic.INFER_STRIDE,
1423*da0073e9SAndroid Build Coastguard Worker            ],
1424*da0073e9SAndroid Build Coastguard Worker        )
1425*da0073e9SAndroid Build Coastguard Worker        s0, s1 = t.size()
1426*da0073e9SAndroid Build Coastguard Worker        s2, s3 = t.stride()
1427*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s0, int))
1428*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s1, int))
1429*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s2, torch.SymInt))
1430*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s3, int))
1431*da0073e9SAndroid Build Coastguard Worker
1432*da0073e9SAndroid Build Coastguard Worker        # Check dynamic sizes and dims, and ensure different symbol
1433*da0073e9SAndroid Build Coastguard Worker        t = _create_symbolic_tensor(
1434*da0073e9SAndroid Build Coastguard Worker            x=torch.ones(3, 6),
1435*da0073e9SAndroid Build Coastguard Worker            dynamic_sizes=[
1436*da0073e9SAndroid Build Coastguard Worker                DimDynamic.DYNAMIC,
1437*da0073e9SAndroid Build Coastguard Worker                DimDynamic.DYNAMIC,
1438*da0073e9SAndroid Build Coastguard Worker            ],
1439*da0073e9SAndroid Build Coastguard Worker            dynamic_strides=[
1440*da0073e9SAndroid Build Coastguard Worker                DimDynamic.DYNAMIC,
1441*da0073e9SAndroid Build Coastguard Worker                DimDynamic.INFER_STRIDE,
1442*da0073e9SAndroid Build Coastguard Worker            ],
1443*da0073e9SAndroid Build Coastguard Worker        )
1444*da0073e9SAndroid Build Coastguard Worker        s0, s1 = t.size()
1445*da0073e9SAndroid Build Coastguard Worker        s2, s3 = t.stride()
1446*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s0, torch.SymInt))
1447*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s1, torch.SymInt))
1448*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s2, torch.SymInt))
1449*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s3, int))
1450*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(str(s1.node.expr) != str(s2.node.expr))
1451*da0073e9SAndroid Build Coastguard Worker
1452*da0073e9SAndroid Build Coastguard Worker
1453*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestSymNumberMagicMethods)
1454*da0073e9SAndroid Build Coastguard Worker
1455*da0073e9SAndroid Build Coastguard Worker
1456*da0073e9SAndroid Build Coastguard Workerclass TestFloorDiv(TestCase):
1457*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1458*da0073e9SAndroid Build Coastguard Worker    def python_floordiv(x, y):
1459*da0073e9SAndroid Build Coastguard Worker        return x // y
1460*da0073e9SAndroid Build Coastguard Worker
1461*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1462*da0073e9SAndroid Build Coastguard Worker    def torch_floordiv(x, y):
1463*da0073e9SAndroid Build Coastguard Worker        # Note: we fully evaluate here since FloorDiv might not always do
1464*da0073e9SAndroid Build Coastguard Worker        # that.
1465*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
1466*da0073e9SAndroid Build Coastguard Worker        return shape_env.evaluate_expr(FloorDiv(x, y))
1467*da0073e9SAndroid Build Coastguard Worker
1468*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1469*da0073e9SAndroid Build Coastguard Worker    def yield_test_cases(values, negate=True):
1470*da0073e9SAndroid Build Coastguard Worker        for x, y in values:
1471*da0073e9SAndroid Build Coastguard Worker            yield (x, y)
1472*da0073e9SAndroid Build Coastguard Worker            if negate:
1473*da0073e9SAndroid Build Coastguard Worker                yield (-x, y)
1474*da0073e9SAndroid Build Coastguard Worker                yield (x, -y)
1475*da0073e9SAndroid Build Coastguard Worker                yield (-x, -y)
1476*da0073e9SAndroid Build Coastguard Worker
1477*da0073e9SAndroid Build Coastguard Worker    def test_floordiv_float_int(self):
1478*da0073e9SAndroid Build Coastguard Worker        values = ((7, 2),)
1479*da0073e9SAndroid Build Coastguard Worker
1480*da0073e9SAndroid Build Coastguard Worker        for x, y in TestFloorDiv.yield_test_cases(values):
1481*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1482*da0073e9SAndroid Build Coastguard Worker                TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
1483*da0073e9SAndroid Build Coastguard Worker            )
1484*da0073e9SAndroid Build Coastguard Worker
1485*da0073e9SAndroid Build Coastguard Worker    def test_floordiv_div_by_one(self):
1486*da0073e9SAndroid Build Coastguard Worker        values = ((2, 1),)
1487*da0073e9SAndroid Build Coastguard Worker
1488*da0073e9SAndroid Build Coastguard Worker        for x, y in TestFloorDiv.yield_test_cases(values):
1489*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1490*da0073e9SAndroid Build Coastguard Worker                TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
1491*da0073e9SAndroid Build Coastguard Worker            )
1492*da0073e9SAndroid Build Coastguard Worker
1493*da0073e9SAndroid Build Coastguard Worker    def test_floordiv_simplify(self):
1494*da0073e9SAndroid Build Coastguard Worker        # Tests how we simplify or evaluate FloorDiv without free variables
1495*da0073e9SAndroid Build Coastguard Worker        shape_env = ShapeEnv()
1496*da0073e9SAndroid Build Coastguard Worker        result = 21
1497*da0073e9SAndroid Build Coastguard Worker        exprs = (7 * FloorDiv(6, 2),)
1498*da0073e9SAndroid Build Coastguard Worker
1499*da0073e9SAndroid Build Coastguard Worker        for expr in exprs:
1500*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expr, result)
1501*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expr.doit(deep=False), result)
1502*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expr.doit(deep=True), result)
1503*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(sympy.simplify(expr), result)
1504*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(shape_env.simplify(expr), result)
1505*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(shape_env.evaluate_expr(expr), result)
1506*da0073e9SAndroid Build Coastguard Worker
1507*da0073e9SAndroid Build Coastguard Worker    def test_floordiv_assumptions(self):
1508*da0073e9SAndroid Build Coastguard Worker        cases = (
1509*da0073e9SAndroid Build Coastguard Worker            sympy.Symbol("i1", integer=True),
1510*da0073e9SAndroid Build Coastguard Worker            sympy.Symbol("i2", integer=True),
1511*da0073e9SAndroid Build Coastguard Worker        )
1512*da0073e9SAndroid Build Coastguard Worker
1513*da0073e9SAndroid Build Coastguard Worker        for base, divisor in itertools.product(cases, repeat=2):
1514*da0073e9SAndroid Build Coastguard Worker
1515*da0073e9SAndroid Build Coastguard Worker            def op():
1516*da0073e9SAndroid Build Coastguard Worker                return FloorDiv(base, divisor)
1517*da0073e9SAndroid Build Coastguard Worker
1518*da0073e9SAndroid Build Coastguard Worker            def is_complex(x):
1519*da0073e9SAndroid Build Coastguard Worker                return x.is_integer is False and x.is_real is False and x.is_complex
1520*da0073e9SAndroid Build Coastguard Worker
1521*da0073e9SAndroid Build Coastguard Worker            if is_complex(base) or is_complex(divisor):
1522*da0073e9SAndroid Build Coastguard Worker                self.assertRaisesRegex(
1523*da0073e9SAndroid Build Coastguard Worker                    TypeError,
1524*da0073e9SAndroid Build Coastguard Worker                    (
1525*da0073e9SAndroid Build Coastguard Worker                        r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol',"
1526*da0073e9SAndroid Build Coastguard Worker                        r" expected integer or real"
1527*da0073e9SAndroid Build Coastguard Worker                    ),
1528*da0073e9SAndroid Build Coastguard Worker                    op,
1529*da0073e9SAndroid Build Coastguard Worker                )
1530*da0073e9SAndroid Build Coastguard Worker                continue
1531*da0073e9SAndroid Build Coastguard Worker
1532*da0073e9SAndroid Build Coastguard Worker            op = op()
1533*da0073e9SAndroid Build Coastguard Worker
1534*da0073e9SAndroid Build Coastguard Worker            # In regular Python, x//x == 1.0 if x is a float, but FloorDiv
1535*da0073e9SAndroid Build Coastguard Worker            # always returns an integer 1 when both args are the same object.
1536*da0073e9SAndroid Build Coastguard Worker            # This even works for Symbols with no assumptions specified.
1537*da0073e9SAndroid Build Coastguard Worker            if base is divisor:
1538*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(op.is_integer)
1539*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(op.is_real)
1540*da0073e9SAndroid Build Coastguard Worker            elif base.is_integer and divisor.is_integer:
1541*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(op.is_integer)
1542*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(op.is_real)
1543*da0073e9SAndroid Build Coastguard Worker            else:
1544*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op.is_integer, None)
1545*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(op.is_real)
1546*da0073e9SAndroid Build Coastguard Worker
1547*da0073e9SAndroid Build Coastguard Worker
1548*da0073e9SAndroid Build Coastguard Workerclass TestDimConstraints(TestCase):
1549*da0073e9SAndroid Build Coastguard Worker    def test_dim_constraints_reduce_congruences_simple(self):
1550*da0073e9SAndroid Build Coastguard Worker        from sympy import Symbol
1551*da0073e9SAndroid Build Coastguard Worker
1552*da0073e9SAndroid Build Coastguard Worker        s = Symbol("s", positive=True, integer=True)
1553*da0073e9SAndroid Build Coastguard Worker        dim_constraints = DimConstraints({}, {}, set(), {})
1554*da0073e9SAndroid Build Coastguard Worker        dim_constraints._congruences[s] = {
1555*da0073e9SAndroid Build Coastguard Worker            (s / 2) % 2,
1556*da0073e9SAndroid Build Coastguard Worker            (s / 2) % 8,
1557*da0073e9SAndroid Build Coastguard Worker            (s / 2) % 4,
1558*da0073e9SAndroid Build Coastguard Worker            s % 2,
1559*da0073e9SAndroid Build Coastguard Worker            ((s / 16) + 2) % 4,
1560*da0073e9SAndroid Build Coastguard Worker        }
1561*da0073e9SAndroid Build Coastguard Worker        congruences = dim_constraints._reduce_congruences()
1562*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(congruences[s], {(s + 32) % 64})
1563*da0073e9SAndroid Build Coastguard Worker
1564*da0073e9SAndroid Build Coastguard Worker    def test_dim_constraints_reduce_inequalities_simple(self):
1565*da0073e9SAndroid Build Coastguard Worker        from sympy import Eq, Interval, Ne, Symbol
1566*da0073e9SAndroid Build Coastguard Worker        from sympy.solvers.inequalities import reduce_inequalities
1567*da0073e9SAndroid Build Coastguard Worker
1568*da0073e9SAndroid Build Coastguard Worker        s = Symbol("s", positive=True, integer=True)
1569*da0073e9SAndroid Build Coastguard Worker        exprs = {
1570*da0073e9SAndroid Build Coastguard Worker            s >= 2,
1571*da0073e9SAndroid Build Coastguard Worker            Ne(8 * s, 16),
1572*da0073e9SAndroid Build Coastguard Worker            Ne(s / 2, 1),
1573*da0073e9SAndroid Build Coastguard Worker            Ne(16 * s, 32),
1574*da0073e9SAndroid Build Coastguard Worker            s < 16,
1575*da0073e9SAndroid Build Coastguard Worker            Ne(s, 2),
1576*da0073e9SAndroid Build Coastguard Worker            s / 2 < 16,
1577*da0073e9SAndroid Build Coastguard Worker            s / 2 > 1,
1578*da0073e9SAndroid Build Coastguard Worker            s / 2 >= 2,
1579*da0073e9SAndroid Build Coastguard Worker            Ne(3 * s / 2, 3),
1580*da0073e9SAndroid Build Coastguard Worker        }
1581*da0073e9SAndroid Build Coastguard Worker        solution = reduce_inequalities(exprs, s).as_set()
1582*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(solution, Interval.Ropen(4, 16))
1583*da0073e9SAndroid Build Coastguard Worker
1584*da0073e9SAndroid Build Coastguard Worker        exprs.add(Eq(s / 2, 4))
1585*da0073e9SAndroid Build Coastguard Worker        solution = reduce_inequalities(exprs, s).as_set()
1586*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(solution, {8})
1587*da0073e9SAndroid Build Coastguard Worker
1588*da0073e9SAndroid Build Coastguard Worker    def test_dim_constraints_reduce_inequalities_error(self):
1589*da0073e9SAndroid Build Coastguard Worker        from collections import defaultdict
1590*da0073e9SAndroid Build Coastguard Worker
1591*da0073e9SAndroid Build Coastguard Worker        from sympy import Symbol
1592*da0073e9SAndroid Build Coastguard Worker        from sympy.solvers.inequalities import reduce_inequalities
1593*da0073e9SAndroid Build Coastguard Worker
1594*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.source import (
1595*da0073e9SAndroid Build Coastguard Worker            LocalSource,
1596*da0073e9SAndroid Build Coastguard Worker            TensorProperty,
1597*da0073e9SAndroid Build Coastguard Worker            TensorPropertySource,
1598*da0073e9SAndroid Build Coastguard Worker        )
1599*da0073e9SAndroid Build Coastguard Worker        from torch.fx.experimental.symbolic_shapes import DynamicDimConstraintPrinter
1600*da0073e9SAndroid Build Coastguard Worker
1601*da0073e9SAndroid Build Coastguard Worker        s0 = Symbol("s0", positive=True, integer=True)
1602*da0073e9SAndroid Build Coastguard Worker        exprs = {
1603*da0073e9SAndroid Build Coastguard Worker            4 * s0**3 - 4 * s0**2 + s0 <= 2147483647,
1604*da0073e9SAndroid Build Coastguard Worker            s0 >= 2,
1605*da0073e9SAndroid Build Coastguard Worker            s0**3 <= 2147483647,
1606*da0073e9SAndroid Build Coastguard Worker            s0 <= 2147483647,
1607*da0073e9SAndroid Build Coastguard Worker        }
1608*da0073e9SAndroid Build Coastguard Worker        answer = reduce_inequalities(exprs, s0)
1609*da0073e9SAndroid Build Coastguard Worker
1610*da0073e9SAndroid Build Coastguard Worker        symbol_to_source = defaultdict(list)
1611*da0073e9SAndroid Build Coastguard Worker        symbol_to_source[s0].append(
1612*da0073e9SAndroid Build Coastguard Worker            TensorPropertySource(
1613*da0073e9SAndroid Build Coastguard Worker                base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
1614*da0073e9SAndroid Build Coastguard Worker            )
1615*da0073e9SAndroid Build Coastguard Worker        )
1616*da0073e9SAndroid Build Coastguard Worker        dcp = DynamicDimConstraintPrinter(symbol_to_source, {})
1617*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1618*da0073e9SAndroid Build Coastguard Worker            AssertionError,
1619*da0073e9SAndroid Build Coastguard Worker            "Unknown symbol.*created by constraints solver",
1620*da0073e9SAndroid Build Coastguard Worker        ):
1621*da0073e9SAndroid Build Coastguard Worker            dcp.doprint(answer)
1622*da0073e9SAndroid Build Coastguard Worker
1623*da0073e9SAndroid Build Coastguard Worker    def test_dim_constraints_solve_full(self):
1624*da0073e9SAndroid Build Coastguard Worker        from sympy import Eq, Integer, Ne, Symbol
1625*da0073e9SAndroid Build Coastguard Worker
1626*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.source import (
1627*da0073e9SAndroid Build Coastguard Worker            LocalSource,
1628*da0073e9SAndroid Build Coastguard Worker            TensorProperty,
1629*da0073e9SAndroid Build Coastguard Worker            TensorPropertySource,
1630*da0073e9SAndroid Build Coastguard Worker        )
1631*da0073e9SAndroid Build Coastguard Worker
1632*da0073e9SAndroid Build Coastguard Worker        src0 = TensorPropertySource(
1633*da0073e9SAndroid Build Coastguard Worker            base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
1634*da0073e9SAndroid Build Coastguard Worker        )
1635*da0073e9SAndroid Build Coastguard Worker        src2 = TensorPropertySource(
1636*da0073e9SAndroid Build Coastguard Worker            base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=0
1637*da0073e9SAndroid Build Coastguard Worker        )
1638*da0073e9SAndroid Build Coastguard Worker        src3 = TensorPropertySource(
1639*da0073e9SAndroid Build Coastguard Worker            base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=0
1640*da0073e9SAndroid Build Coastguard Worker        )
1641*da0073e9SAndroid Build Coastguard Worker        src4 = TensorPropertySource(
1642*da0073e9SAndroid Build Coastguard Worker            base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=0
1643*da0073e9SAndroid Build Coastguard Worker        )
1644*da0073e9SAndroid Build Coastguard Worker
1645*da0073e9SAndroid Build Coastguard Worker        src1 = TensorPropertySource(
1646*da0073e9SAndroid Build Coastguard Worker            base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=2
1647*da0073e9SAndroid Build Coastguard Worker        )
1648*da0073e9SAndroid Build Coastguard Worker        src7 = TensorPropertySource(
1649*da0073e9SAndroid Build Coastguard Worker            base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=3
1650*da0073e9SAndroid Build Coastguard Worker        )
1651*da0073e9SAndroid Build Coastguard Worker
1652*da0073e9SAndroid Build Coastguard Worker        src5 = TensorPropertySource(
1653*da0073e9SAndroid Build Coastguard Worker            base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=1
1654*da0073e9SAndroid Build Coastguard Worker        )
1655*da0073e9SAndroid Build Coastguard Worker        src8 = TensorPropertySource(
1656*da0073e9SAndroid Build Coastguard Worker            base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=1
1657*da0073e9SAndroid Build Coastguard Worker        )
1658*da0073e9SAndroid Build Coastguard Worker
1659*da0073e9SAndroid Build Coastguard Worker        src6 = TensorPropertySource(
1660*da0073e9SAndroid Build Coastguard Worker            base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=1
1661*da0073e9SAndroid Build Coastguard Worker        )
1662*da0073e9SAndroid Build Coastguard Worker        src9 = TensorPropertySource(
1663*da0073e9SAndroid Build Coastguard Worker            base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=1
1664*da0073e9SAndroid Build Coastguard Worker        )
1665*da0073e9SAndroid Build Coastguard Worker        src10 = TensorPropertySource(
1666*da0073e9SAndroid Build Coastguard Worker            base=LocalSource(local_name="e"), prop=TensorProperty.SIZE, idx=1
1667*da0073e9SAndroid Build Coastguard Worker        )
1668*da0073e9SAndroid Build Coastguard Worker
1669*da0073e9SAndroid Build Coastguard Worker        src11 = TensorPropertySource(
1670*da0073e9SAndroid Build Coastguard Worker            base=LocalSource(local_name="f"), prop=TensorProperty.SIZE, idx=1
1671*da0073e9SAndroid Build Coastguard Worker        )
1672*da0073e9SAndroid Build Coastguard Worker        src12 = TensorPropertySource(
1673*da0073e9SAndroid Build Coastguard Worker            base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=2
1674*da0073e9SAndroid Build Coastguard Worker        )
1675*da0073e9SAndroid Build Coastguard Worker
1676*da0073e9SAndroid Build Coastguard Worker        s0 = Symbol("s0", positive=True, integer=True)
1677*da0073e9SAndroid Build Coastguard Worker        s1 = Symbol("s1", positive=True, integer=True)
1678*da0073e9SAndroid Build Coastguard Worker        s5 = Symbol("s5", positive=True, integer=True)
1679*da0073e9SAndroid Build Coastguard Worker        s6 = Symbol("s6", positive=True, integer=True)
1680*da0073e9SAndroid Build Coastguard Worker        symbol_to_source = {
1681*da0073e9SAndroid Build Coastguard Worker            s0: [src0, src2, src3, src4],
1682*da0073e9SAndroid Build Coastguard Worker            s1: [src1, src7],
1683*da0073e9SAndroid Build Coastguard Worker            s5: [src5, src8],
1684*da0073e9SAndroid Build Coastguard Worker            s6: [src6, src9, src10],
1685*da0073e9SAndroid Build Coastguard Worker        }
1686*da0073e9SAndroid Build Coastguard Worker        var_to_val = {s0: 8, s1: 96, s5: 22, s6: 21}
1687*da0073e9SAndroid Build Coastguard Worker        marked_dynamic = {s0, s1, s5, s6}
1688*da0073e9SAndroid Build Coastguard Worker        dim_constraints = DimConstraints(
1689*da0073e9SAndroid Build Coastguard Worker            symbol_to_source, var_to_val, marked_dynamic, {}
1690*da0073e9SAndroid Build Coastguard Worker        )
1691*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add_equality(src2, s0)
1692*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add_equality(src3, s0)
1693*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add_equality(src4, s0)
1694*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add_equality(src7, s1)
1695*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add_equality(src8, s5)
1696*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add_equality(src9, s6)
1697*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add_equality(src10, s6)
1698*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add_equality(src11, Integer(1))
1699*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add_equality(src12, Integer(3))
1700*da0073e9SAndroid Build Coastguard Worker
1701*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(s1**2 <= 2147483647)
1702*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(32 * s1**2 <= 2147483647)
1703*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(s0 < 16)
1704*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Eq(Mod(s1, 2), 0))
1705*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
1706*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne((FloorDiv(s1, 2)) ** 2, 1))
1707*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(32 * (FloorDiv(s1, 2)) ** 2 <= 2147483647)
1708*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add((FloorDiv(s1, 2)) ** 2 > 1)
1709*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
1710*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1711*da0073e9SAndroid Build Coastguard Worker            64 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
1712*da0073e9SAndroid Build Coastguard Worker            + 128 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
1713*da0073e9SAndroid Build Coastguard Worker            + 64
1714*da0073e9SAndroid Build Coastguard Worker            <= 2147483647
1715*da0073e9SAndroid Build Coastguard Worker        )
1716*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
1717*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1718*da0073e9SAndroid Build Coastguard Worker            Ne(
1719*da0073e9SAndroid Build Coastguard Worker                (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
1720*da0073e9SAndroid Build Coastguard Worker                + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
1721*da0073e9SAndroid Build Coastguard Worker                + 1,
1722*da0073e9SAndroid Build Coastguard Worker                1,
1723*da0073e9SAndroid Build Coastguard Worker            )
1724*da0073e9SAndroid Build Coastguard Worker        )
1725*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
1726*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1727*da0073e9SAndroid Build Coastguard Worker            (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
1728*da0073e9SAndroid Build Coastguard Worker            + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
1729*da0073e9SAndroid Build Coastguard Worker            + 1
1730*da0073e9SAndroid Build Coastguard Worker            > 1
1731*da0073e9SAndroid Build Coastguard Worker        )
1732*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1733*da0073e9SAndroid Build Coastguard Worker            128 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
1734*da0073e9SAndroid Build Coastguard Worker            + 256 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
1735*da0073e9SAndroid Build Coastguard Worker            + 128
1736*da0073e9SAndroid Build Coastguard Worker            <= 2147483647
1737*da0073e9SAndroid Build Coastguard Worker        )
1738*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
1739*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1740*da0073e9SAndroid Build Coastguard Worker            Ne(
1741*da0073e9SAndroid Build Coastguard Worker                (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
1742*da0073e9SAndroid Build Coastguard Worker                + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
1743*da0073e9SAndroid Build Coastguard Worker                + 1,
1744*da0073e9SAndroid Build Coastguard Worker                1,
1745*da0073e9SAndroid Build Coastguard Worker            )
1746*da0073e9SAndroid Build Coastguard Worker        )
1747*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
1748*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1749*da0073e9SAndroid Build Coastguard Worker            (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
1750*da0073e9SAndroid Build Coastguard Worker            + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
1751*da0073e9SAndroid Build Coastguard Worker            + 1
1752*da0073e9SAndroid Build Coastguard Worker            > 1
1753*da0073e9SAndroid Build Coastguard Worker        )
1754*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1755*da0073e9SAndroid Build Coastguard Worker            256 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1756*da0073e9SAndroid Build Coastguard Worker            + 512 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
1757*da0073e9SAndroid Build Coastguard Worker            + 256
1758*da0073e9SAndroid Build Coastguard Worker            <= 2147483647
1759*da0073e9SAndroid Build Coastguard Worker        )
1760*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
1761*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1762*da0073e9SAndroid Build Coastguard Worker            Ne(
1763*da0073e9SAndroid Build Coastguard Worker                (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1764*da0073e9SAndroid Build Coastguard Worker                + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
1765*da0073e9SAndroid Build Coastguard Worker                + 1,
1766*da0073e9SAndroid Build Coastguard Worker                1,
1767*da0073e9SAndroid Build Coastguard Worker            )
1768*da0073e9SAndroid Build Coastguard Worker        )
1769*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
1770*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1771*da0073e9SAndroid Build Coastguard Worker            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1772*da0073e9SAndroid Build Coastguard Worker            + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
1773*da0073e9SAndroid Build Coastguard Worker            + 1
1774*da0073e9SAndroid Build Coastguard Worker            > 1
1775*da0073e9SAndroid Build Coastguard Worker        )
1776*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1 >= 3)
1777*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1778*da0073e9SAndroid Build Coastguard Worker            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1779*da0073e9SAndroid Build Coastguard Worker            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1780*da0073e9SAndroid Build Coastguard Worker            + 60
1781*da0073e9SAndroid Build Coastguard Worker            <= 2147483647
1782*da0073e9SAndroid Build Coastguard Worker        )
1783*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 0)
1784*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 1)
1785*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1786*da0073e9SAndroid Build Coastguard Worker            Ne(
1787*da0073e9SAndroid Build Coastguard Worker                60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1788*da0073e9SAndroid Build Coastguard Worker                - 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1789*da0073e9SAndroid Build Coastguard Worker                + 60 * s0,
1790*da0073e9SAndroid Build Coastguard Worker                0,
1791*da0073e9SAndroid Build Coastguard Worker            )
1792*da0073e9SAndroid Build Coastguard Worker        )
1793*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
1794*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1795*da0073e9SAndroid Build Coastguard Worker            Ne(
1796*da0073e9SAndroid Build Coastguard Worker                (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1797*da0073e9SAndroid Build Coastguard Worker                - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1798*da0073e9SAndroid Build Coastguard Worker                + 1,
1799*da0073e9SAndroid Build Coastguard Worker                1,
1800*da0073e9SAndroid Build Coastguard Worker            )
1801*da0073e9SAndroid Build Coastguard Worker        )
1802*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1803*da0073e9SAndroid Build Coastguard Worker            Ne(
1804*da0073e9SAndroid Build Coastguard Worker                (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1805*da0073e9SAndroid Build Coastguard Worker                - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1806*da0073e9SAndroid Build Coastguard Worker                + 1,
1807*da0073e9SAndroid Build Coastguard Worker                0,
1808*da0073e9SAndroid Build Coastguard Worker            )
1809*da0073e9SAndroid Build Coastguard Worker        )
1810*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1811*da0073e9SAndroid Build Coastguard Worker            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1812*da0073e9SAndroid Build Coastguard Worker            - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1813*da0073e9SAndroid Build Coastguard Worker            + 1
1814*da0073e9SAndroid Build Coastguard Worker            >= 0
1815*da0073e9SAndroid Build Coastguard Worker        )
1816*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 0))
1817*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1818*da0073e9SAndroid Build Coastguard Worker            1
1819*da0073e9SAndroid Build Coastguard Worker            < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1820*da0073e9SAndroid Build Coastguard Worker            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1821*da0073e9SAndroid Build Coastguard Worker            + 60
1822*da0073e9SAndroid Build Coastguard Worker        )
1823*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, -1))
1824*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1825*da0073e9SAndroid Build Coastguard Worker            Ne(
1826*da0073e9SAndroid Build Coastguard Worker                60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1827*da0073e9SAndroid Build Coastguard Worker                - 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1828*da0073e9SAndroid Build Coastguard Worker                + 60 * s0,
1829*da0073e9SAndroid Build Coastguard Worker                120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1830*da0073e9SAndroid Build Coastguard Worker                - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1831*da0073e9SAndroid Build Coastguard Worker                + 120,
1832*da0073e9SAndroid Build Coastguard Worker            )
1833*da0073e9SAndroid Build Coastguard Worker        )
1834*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1835*da0073e9SAndroid Build Coastguard Worker            120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1836*da0073e9SAndroid Build Coastguard Worker            - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1837*da0073e9SAndroid Build Coastguard Worker            + 120
1838*da0073e9SAndroid Build Coastguard Worker            > 0
1839*da0073e9SAndroid Build Coastguard Worker        )
1840*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1841*da0073e9SAndroid Build Coastguard Worker            Eq(
1842*da0073e9SAndroid Build Coastguard Worker                60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 * (Mod(s0, 2))
1843*da0073e9SAndroid Build Coastguard Worker                - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) * Mod(s0, 2)
1844*da0073e9SAndroid Build Coastguard Worker                + 60 * (Mod(s0, 2)),
1845*da0073e9SAndroid Build Coastguard Worker                0,
1846*da0073e9SAndroid Build Coastguard Worker            )
1847*da0073e9SAndroid Build Coastguard Worker        )
1848*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1849*da0073e9SAndroid Build Coastguard Worker            Ne(
1850*da0073e9SAndroid Build Coastguard Worker                120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1851*da0073e9SAndroid Build Coastguard Worker                - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1852*da0073e9SAndroid Build Coastguard Worker                + 120,
1853*da0073e9SAndroid Build Coastguard Worker                0,
1854*da0073e9SAndroid Build Coastguard Worker            )
1855*da0073e9SAndroid Build Coastguard Worker        )
1856*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1857*da0073e9SAndroid Build Coastguard Worker            Ne(
1858*da0073e9SAndroid Build Coastguard Worker                60
1859*da0073e9SAndroid Build Coastguard Worker                * (FloorDiv(s0, 2))
1860*da0073e9SAndroid Build Coastguard Worker                * (FloorDiv(s0, (FloorDiv(s0, 2))))
1861*da0073e9SAndroid Build Coastguard Worker                * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1862*da0073e9SAndroid Build Coastguard Worker                - 120
1863*da0073e9SAndroid Build Coastguard Worker                * FloorDiv(s0, 2)
1864*da0073e9SAndroid Build Coastguard Worker                * FloorDiv(s0, (FloorDiv(s0, 2)))
1865*da0073e9SAndroid Build Coastguard Worker                * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1866*da0073e9SAndroid Build Coastguard Worker                + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
1867*da0073e9SAndroid Build Coastguard Worker                0,
1868*da0073e9SAndroid Build Coastguard Worker            )
1869*da0073e9SAndroid Build Coastguard Worker        )
1870*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
1871*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1872*da0073e9SAndroid Build Coastguard Worker            Ne(
1873*da0073e9SAndroid Build Coastguard Worker                60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1874*da0073e9SAndroid Build Coastguard Worker                - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1875*da0073e9SAndroid Build Coastguard Worker                + 60,
1876*da0073e9SAndroid Build Coastguard Worker                0,
1877*da0073e9SAndroid Build Coastguard Worker            )
1878*da0073e9SAndroid Build Coastguard Worker        )
1879*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1880*da0073e9SAndroid Build Coastguard Worker            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1881*da0073e9SAndroid Build Coastguard Worker            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1882*da0073e9SAndroid Build Coastguard Worker            + 60
1883*da0073e9SAndroid Build Coastguard Worker            >= 0
1884*da0073e9SAndroid Build Coastguard Worker        )
1885*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1886*da0073e9SAndroid Build Coastguard Worker            1
1887*da0073e9SAndroid Build Coastguard Worker            < 60
1888*da0073e9SAndroid Build Coastguard Worker            * (FloorDiv(s0, (FloorDiv(s0, 2))))
1889*da0073e9SAndroid Build Coastguard Worker            * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1890*da0073e9SAndroid Build Coastguard Worker            - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1891*da0073e9SAndroid Build Coastguard Worker            + 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
1892*da0073e9SAndroid Build Coastguard Worker        )
1893*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(16 * s0, 32))
1894*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
1895*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(16 * s0, 32))
1896*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
1897*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(FloorDiv(s0, 2) >= 2)
1898*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
1899*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(1 < FloorDiv(s0, 2))
1900*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(s0, 2))
1901*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1902*da0073e9SAndroid Build Coastguard Worker            60
1903*da0073e9SAndroid Build Coastguard Worker            * (FloorDiv(s0, (FloorDiv(s0, 2))))
1904*da0073e9SAndroid Build Coastguard Worker            * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1905*da0073e9SAndroid Build Coastguard Worker            - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1906*da0073e9SAndroid Build Coastguard Worker            + 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
1907*da0073e9SAndroid Build Coastguard Worker            >= 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1908*da0073e9SAndroid Build Coastguard Worker            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1909*da0073e9SAndroid Build Coastguard Worker            + 60
1910*da0073e9SAndroid Build Coastguard Worker        )
1911*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1912*da0073e9SAndroid Build Coastguard Worker            60
1913*da0073e9SAndroid Build Coastguard Worker            * (FloorDiv(s0, 2))
1914*da0073e9SAndroid Build Coastguard Worker            * (FloorDiv(s0, (FloorDiv(s0, 2))))
1915*da0073e9SAndroid Build Coastguard Worker            * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1916*da0073e9SAndroid Build Coastguard Worker            - 120
1917*da0073e9SAndroid Build Coastguard Worker            * FloorDiv(s0, 2)
1918*da0073e9SAndroid Build Coastguard Worker            * FloorDiv(s0, (FloorDiv(s0, 2)))
1919*da0073e9SAndroid Build Coastguard Worker            * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1920*da0073e9SAndroid Build Coastguard Worker            + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2))))
1921*da0073e9SAndroid Build Coastguard Worker            > 0
1922*da0073e9SAndroid Build Coastguard Worker        )
1923*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1924*da0073e9SAndroid Build Coastguard Worker            Ne(
1925*da0073e9SAndroid Build Coastguard Worker                60
1926*da0073e9SAndroid Build Coastguard Worker                * (FloorDiv(s0, 2))
1927*da0073e9SAndroid Build Coastguard Worker                * (FloorDiv(s0, (FloorDiv(s0, 2))))
1928*da0073e9SAndroid Build Coastguard Worker                * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1929*da0073e9SAndroid Build Coastguard Worker                - 120
1930*da0073e9SAndroid Build Coastguard Worker                * FloorDiv(s0, 2)
1931*da0073e9SAndroid Build Coastguard Worker                * FloorDiv(s0, (FloorDiv(s0, 2)))
1932*da0073e9SAndroid Build Coastguard Worker                * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1933*da0073e9SAndroid Build Coastguard Worker                + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
1934*da0073e9SAndroid Build Coastguard Worker                3 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
1935*da0073e9SAndroid Build Coastguard Worker            )
1936*da0073e9SAndroid Build Coastguard Worker        )
1937*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1938*da0073e9SAndroid Build Coastguard Worker            Ne(
1939*da0073e9SAndroid Build Coastguard Worker                20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1940*da0073e9SAndroid Build Coastguard Worker                - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1941*da0073e9SAndroid Build Coastguard Worker                + 20,
1942*da0073e9SAndroid Build Coastguard Worker                0,
1943*da0073e9SAndroid Build Coastguard Worker            )
1944*da0073e9SAndroid Build Coastguard Worker        )
1945*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1946*da0073e9SAndroid Build Coastguard Worker            20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1947*da0073e9SAndroid Build Coastguard Worker            - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1948*da0073e9SAndroid Build Coastguard Worker            + 20
1949*da0073e9SAndroid Build Coastguard Worker            >= 0
1950*da0073e9SAndroid Build Coastguard Worker        )
1951*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1952*da0073e9SAndroid Build Coastguard Worker            Ne(
1953*da0073e9SAndroid Build Coastguard Worker                20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1954*da0073e9SAndroid Build Coastguard Worker                - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1955*da0073e9SAndroid Build Coastguard Worker                + 20,
1956*da0073e9SAndroid Build Coastguard Worker                20,
1957*da0073e9SAndroid Build Coastguard Worker            )
1958*da0073e9SAndroid Build Coastguard Worker        )
1959*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1960*da0073e9SAndroid Build Coastguard Worker            Ne(
1961*da0073e9SAndroid Build Coastguard Worker                20
1962*da0073e9SAndroid Build Coastguard Worker                * (
1963*da0073e9SAndroid Build Coastguard Worker                    Mod(
1964*da0073e9SAndroid Build Coastguard Worker                        1,
1965*da0073e9SAndroid Build Coastguard Worker                        (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1966*da0073e9SAndroid Build Coastguard Worker                        - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1967*da0073e9SAndroid Build Coastguard Worker                        + 1,
1968*da0073e9SAndroid Build Coastguard Worker                    )
1969*da0073e9SAndroid Build Coastguard Worker                ),
1970*da0073e9SAndroid Build Coastguard Worker                0,
1971*da0073e9SAndroid Build Coastguard Worker            )
1972*da0073e9SAndroid Build Coastguard Worker        )
1973*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
1974*da0073e9SAndroid Build Coastguard Worker            Ne(
1975*da0073e9SAndroid Build Coastguard Worker                20
1976*da0073e9SAndroid Build Coastguard Worker                * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
1977*da0073e9SAndroid Build Coastguard Worker                * (
1978*da0073e9SAndroid Build Coastguard Worker                    Mod(
1979*da0073e9SAndroid Build Coastguard Worker                        1,
1980*da0073e9SAndroid Build Coastguard Worker                        (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1981*da0073e9SAndroid Build Coastguard Worker                        / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
1982*da0073e9SAndroid Build Coastguard Worker                        - 2
1983*da0073e9SAndroid Build Coastguard Worker                        * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1984*da0073e9SAndroid Build Coastguard Worker                        / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
1985*da0073e9SAndroid Build Coastguard Worker                        + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
1986*da0073e9SAndroid Build Coastguard Worker                    )
1987*da0073e9SAndroid Build Coastguard Worker                )
1988*da0073e9SAndroid Build Coastguard Worker                - 20
1989*da0073e9SAndroid Build Coastguard Worker                * Mod(
1990*da0073e9SAndroid Build Coastguard Worker                    1,
1991*da0073e9SAndroid Build Coastguard Worker                    (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
1992*da0073e9SAndroid Build Coastguard Worker                    / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
1993*da0073e9SAndroid Build Coastguard Worker                    - 2
1994*da0073e9SAndroid Build Coastguard Worker                    * FloorDiv((FloorDiv(s1, 2) - 1), 8)
1995*da0073e9SAndroid Build Coastguard Worker                    / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
1996*da0073e9SAndroid Build Coastguard Worker                    + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
1997*da0073e9SAndroid Build Coastguard Worker                ),
1998*da0073e9SAndroid Build Coastguard Worker                0,
1999*da0073e9SAndroid Build Coastguard Worker            )
2000*da0073e9SAndroid Build Coastguard Worker        )
2001*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
2002*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2003*da0073e9SAndroid Build Coastguard Worker            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2004*da0073e9SAndroid Build Coastguard Worker            - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2005*da0073e9SAndroid Build Coastguard Worker            + 1
2006*da0073e9SAndroid Build Coastguard Worker            >= 1
2007*da0073e9SAndroid Build Coastguard Worker        )
2008*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2009*da0073e9SAndroid Build Coastguard Worker            20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2010*da0073e9SAndroid Build Coastguard Worker            - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2011*da0073e9SAndroid Build Coastguard Worker            + 20
2012*da0073e9SAndroid Build Coastguard Worker            >= 0
2013*da0073e9SAndroid Build Coastguard Worker        )
2014*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2015*da0073e9SAndroid Build Coastguard Worker            20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2016*da0073e9SAndroid Build Coastguard Worker            - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2017*da0073e9SAndroid Build Coastguard Worker            + 20
2018*da0073e9SAndroid Build Coastguard Worker            >= 1
2019*da0073e9SAndroid Build Coastguard Worker        )
2020*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2021*da0073e9SAndroid Build Coastguard Worker            20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2022*da0073e9SAndroid Build Coastguard Worker            - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2023*da0073e9SAndroid Build Coastguard Worker            + 20
2024*da0073e9SAndroid Build Coastguard Worker            >= 2
2025*da0073e9SAndroid Build Coastguard Worker        )
2026*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2027*da0073e9SAndroid Build Coastguard Worker            20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2028*da0073e9SAndroid Build Coastguard Worker            - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2029*da0073e9SAndroid Build Coastguard Worker            + 20
2030*da0073e9SAndroid Build Coastguard Worker            > 1
2031*da0073e9SAndroid Build Coastguard Worker        )
2032*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2033*da0073e9SAndroid Build Coastguard Worker            20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2034*da0073e9SAndroid Build Coastguard Worker            - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2035*da0073e9SAndroid Build Coastguard Worker            + 20
2036*da0073e9SAndroid Build Coastguard Worker            < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2037*da0073e9SAndroid Build Coastguard Worker            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2038*da0073e9SAndroid Build Coastguard Worker            + 60
2039*da0073e9SAndroid Build Coastguard Worker        )
2040*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2041*da0073e9SAndroid Build Coastguard Worker            Ne(
2042*da0073e9SAndroid Build Coastguard Worker                60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2043*da0073e9SAndroid Build Coastguard Worker                - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2044*da0073e9SAndroid Build Coastguard Worker                + 60,
2045*da0073e9SAndroid Build Coastguard Worker                60,
2046*da0073e9SAndroid Build Coastguard Worker            )
2047*da0073e9SAndroid Build Coastguard Worker        )
2048*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2049*da0073e9SAndroid Build Coastguard Worker            Ne(
2050*da0073e9SAndroid Build Coastguard Worker                FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
2051*da0073e9SAndroid Build Coastguard Worker                (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2052*da0073e9SAndroid Build Coastguard Worker                - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2053*da0073e9SAndroid Build Coastguard Worker                + 1,
2054*da0073e9SAndroid Build Coastguard Worker            )
2055*da0073e9SAndroid Build Coastguard Worker        )
2056*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2057*da0073e9SAndroid Build Coastguard Worker            Eq(
2058*da0073e9SAndroid Build Coastguard Worker                (FloorDiv((FloorDiv(s1, 2) - 1), 8))
2059*da0073e9SAndroid Build Coastguard Worker                * (
2060*da0073e9SAndroid Build Coastguard Worker                    Mod(
2061*da0073e9SAndroid Build Coastguard Worker                        (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2062*da0073e9SAndroid Build Coastguard Worker                        / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
2063*da0073e9SAndroid Build Coastguard Worker                        - 2
2064*da0073e9SAndroid Build Coastguard Worker                        * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2065*da0073e9SAndroid Build Coastguard Worker                        / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
2066*da0073e9SAndroid Build Coastguard Worker                        + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
2067*da0073e9SAndroid Build Coastguard Worker                        1,
2068*da0073e9SAndroid Build Coastguard Worker                    )
2069*da0073e9SAndroid Build Coastguard Worker                )
2070*da0073e9SAndroid Build Coastguard Worker                - Mod(
2071*da0073e9SAndroid Build Coastguard Worker                    (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2072*da0073e9SAndroid Build Coastguard Worker                    / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
2073*da0073e9SAndroid Build Coastguard Worker                    - 2
2074*da0073e9SAndroid Build Coastguard Worker                    * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2075*da0073e9SAndroid Build Coastguard Worker                    / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
2076*da0073e9SAndroid Build Coastguard Worker                    + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
2077*da0073e9SAndroid Build Coastguard Worker                    1,
2078*da0073e9SAndroid Build Coastguard Worker                ),
2079*da0073e9SAndroid Build Coastguard Worker                0,
2080*da0073e9SAndroid Build Coastguard Worker            )
2081*da0073e9SAndroid Build Coastguard Worker        )
2082*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2083*da0073e9SAndroid Build Coastguard Worker            Ne(
2084*da0073e9SAndroid Build Coastguard Worker                (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2085*da0073e9SAndroid Build Coastguard Worker                - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2086*da0073e9SAndroid Build Coastguard Worker                + 1,
2087*da0073e9SAndroid Build Coastguard Worker                FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
2088*da0073e9SAndroid Build Coastguard Worker            )
2089*da0073e9SAndroid Build Coastguard Worker        )
2090*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(8 * s0, 16))
2091*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2092*da0073e9SAndroid Build Coastguard Worker            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2093*da0073e9SAndroid Build Coastguard Worker            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2094*da0073e9SAndroid Build Coastguard Worker            + 60
2095*da0073e9SAndroid Build Coastguard Worker            >= (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2096*da0073e9SAndroid Build Coastguard Worker            - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2097*da0073e9SAndroid Build Coastguard Worker            + 1
2098*da0073e9SAndroid Build Coastguard Worker        )
2099*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2100*da0073e9SAndroid Build Coastguard Worker            60
2101*da0073e9SAndroid Build Coastguard Worker            * (FloorDiv(s0, (FloorDiv(s0, 2))))
2102*da0073e9SAndroid Build Coastguard Worker            * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2103*da0073e9SAndroid Build Coastguard Worker            - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2104*da0073e9SAndroid Build Coastguard Worker            + 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
2105*da0073e9SAndroid Build Coastguard Worker            <= 2147483647
2106*da0073e9SAndroid Build Coastguard Worker        )
2107*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2108*da0073e9SAndroid Build Coastguard Worker            90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2109*da0073e9SAndroid Build Coastguard Worker            - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2110*da0073e9SAndroid Build Coastguard Worker            + 90
2111*da0073e9SAndroid Build Coastguard Worker            <= 2147483647
2112*da0073e9SAndroid Build Coastguard Worker        )
2113*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(FloorDiv(s0, 2) < 16)
2114*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(FloorDiv(s0, 2) > 1)
2115*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2116*da0073e9SAndroid Build Coastguard Worker            Ne(
2117*da0073e9SAndroid Build Coastguard Worker                90 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2118*da0073e9SAndroid Build Coastguard Worker                - 180 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2119*da0073e9SAndroid Build Coastguard Worker                + 90 * (FloorDiv(s0, 2)),
2120*da0073e9SAndroid Build Coastguard Worker                0,
2121*da0073e9SAndroid Build Coastguard Worker            )
2122*da0073e9SAndroid Build Coastguard Worker        )
2123*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2124*da0073e9SAndroid Build Coastguard Worker            1
2125*da0073e9SAndroid Build Coastguard Worker            < 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2126*da0073e9SAndroid Build Coastguard Worker            - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2127*da0073e9SAndroid Build Coastguard Worker            + 90
2128*da0073e9SAndroid Build Coastguard Worker        )
2129*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2130*da0073e9SAndroid Build Coastguard Worker            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2131*da0073e9SAndroid Build Coastguard Worker            - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2132*da0073e9SAndroid Build Coastguard Worker            + 1
2133*da0073e9SAndroid Build Coastguard Worker            > 1
2134*da0073e9SAndroid Build Coastguard Worker        )
2135*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2136*da0073e9SAndroid Build Coastguard Worker            60
2137*da0073e9SAndroid Build Coastguard Worker            * (FloorDiv(s0, (FloorDiv(s0, 2))))
2138*da0073e9SAndroid Build Coastguard Worker            * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2139*da0073e9SAndroid Build Coastguard Worker            - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2140*da0073e9SAndroid Build Coastguard Worker            + 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
2141*da0073e9SAndroid Build Coastguard Worker            > 1
2142*da0073e9SAndroid Build Coastguard Worker        )
2143*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2144*da0073e9SAndroid Build Coastguard Worker            Ne(
2145*da0073e9SAndroid Build Coastguard Worker                60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2146*da0073e9SAndroid Build Coastguard Worker                - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2147*da0073e9SAndroid Build Coastguard Worker                + 60 * (FloorDiv(s0, 2)),
2148*da0073e9SAndroid Build Coastguard Worker                0,
2149*da0073e9SAndroid Build Coastguard Worker            )
2150*da0073e9SAndroid Build Coastguard Worker        )
2151*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2152*da0073e9SAndroid Build Coastguard Worker            90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2153*da0073e9SAndroid Build Coastguard Worker            - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2154*da0073e9SAndroid Build Coastguard Worker            + 90
2155*da0073e9SAndroid Build Coastguard Worker            > 1
2156*da0073e9SAndroid Build Coastguard Worker        )
2157*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2158*da0073e9SAndroid Build Coastguard Worker            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2159*da0073e9SAndroid Build Coastguard Worker            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2160*da0073e9SAndroid Build Coastguard Worker            + 60
2161*da0073e9SAndroid Build Coastguard Worker            > 1
2162*da0073e9SAndroid Build Coastguard Worker        )
2163*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2164*da0073e9SAndroid Build Coastguard Worker            Ne(
2165*da0073e9SAndroid Build Coastguard Worker                60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2166*da0073e9SAndroid Build Coastguard Worker                - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2167*da0073e9SAndroid Build Coastguard Worker                + 60 * (FloorDiv(s0, 2)),
2168*da0073e9SAndroid Build Coastguard Worker                3 * (FloorDiv(s0, 2)),
2169*da0073e9SAndroid Build Coastguard Worker            )
2170*da0073e9SAndroid Build Coastguard Worker        )
2171*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2172*da0073e9SAndroid Build Coastguard Worker            60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2173*da0073e9SAndroid Build Coastguard Worker            - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2174*da0073e9SAndroid Build Coastguard Worker            + 60 * (FloorDiv(s0, 2))
2175*da0073e9SAndroid Build Coastguard Worker            > 0
2176*da0073e9SAndroid Build Coastguard Worker        )
2177*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2178*da0073e9SAndroid Build Coastguard Worker            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2179*da0073e9SAndroid Build Coastguard Worker            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2180*da0073e9SAndroid Build Coastguard Worker            + 60
2181*da0073e9SAndroid Build Coastguard Worker            > 0
2182*da0073e9SAndroid Build Coastguard Worker        )
2183*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2184*da0073e9SAndroid Build Coastguard Worker            Ne(
2185*da0073e9SAndroid Build Coastguard Worker                120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2186*da0073e9SAndroid Build Coastguard Worker                - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2187*da0073e9SAndroid Build Coastguard Worker                + 120,
2188*da0073e9SAndroid Build Coastguard Worker                0,
2189*da0073e9SAndroid Build Coastguard Worker            )
2190*da0073e9SAndroid Build Coastguard Worker        )
2191*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2192*da0073e9SAndroid Build Coastguard Worker            1
2193*da0073e9SAndroid Build Coastguard Worker            < 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2194*da0073e9SAndroid Build Coastguard Worker            - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2195*da0073e9SAndroid Build Coastguard Worker            + 120
2196*da0073e9SAndroid Build Coastguard Worker        )
2197*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2198*da0073e9SAndroid Build Coastguard Worker            Ne(
2199*da0073e9SAndroid Build Coastguard Worker                120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2200*da0073e9SAndroid Build Coastguard Worker                - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2201*da0073e9SAndroid Build Coastguard Worker                + 120,
2202*da0073e9SAndroid Build Coastguard Worker                6,
2203*da0073e9SAndroid Build Coastguard Worker            )
2204*da0073e9SAndroid Build Coastguard Worker        )
2205*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2206*da0073e9SAndroid Build Coastguard Worker            120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2207*da0073e9SAndroid Build Coastguard Worker            - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2208*da0073e9SAndroid Build Coastguard Worker            + 120
2209*da0073e9SAndroid Build Coastguard Worker            > 0
2210*da0073e9SAndroid Build Coastguard Worker        )
2211*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2212*da0073e9SAndroid Build Coastguard Worker            Ne(
2213*da0073e9SAndroid Build Coastguard Worker                120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2214*da0073e9SAndroid Build Coastguard Worker                - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2215*da0073e9SAndroid Build Coastguard Worker                + 120,
2216*da0073e9SAndroid Build Coastguard Worker                0,
2217*da0073e9SAndroid Build Coastguard Worker            )
2218*da0073e9SAndroid Build Coastguard Worker        )
2219*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2220*da0073e9SAndroid Build Coastguard Worker            120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2221*da0073e9SAndroid Build Coastguard Worker            - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2222*da0073e9SAndroid Build Coastguard Worker            + 120
2223*da0073e9SAndroid Build Coastguard Worker            <= 2147483647
2224*da0073e9SAndroid Build Coastguard Worker        )
2225*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2226*da0073e9SAndroid Build Coastguard Worker            120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2227*da0073e9SAndroid Build Coastguard Worker            - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2228*da0073e9SAndroid Build Coastguard Worker            + 120
2229*da0073e9SAndroid Build Coastguard Worker            <= 20480
2230*da0073e9SAndroid Build Coastguard Worker        )
2231*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2232*da0073e9SAndroid Build Coastguard Worker            Ne(
2233*da0073e9SAndroid Build Coastguard Worker                90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2234*da0073e9SAndroid Build Coastguard Worker                - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2235*da0073e9SAndroid Build Coastguard Worker                + 90,
2236*da0073e9SAndroid Build Coastguard Worker                0,
2237*da0073e9SAndroid Build Coastguard Worker            )
2238*da0073e9SAndroid Build Coastguard Worker        )
2239*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2240*da0073e9SAndroid Build Coastguard Worker            120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2241*da0073e9SAndroid Build Coastguard Worker            - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2242*da0073e9SAndroid Build Coastguard Worker            + 120
2243*da0073e9SAndroid Build Coastguard Worker            > 1
2244*da0073e9SAndroid Build Coastguard Worker        )
2245*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2246*da0073e9SAndroid Build Coastguard Worker            90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2247*da0073e9SAndroid Build Coastguard Worker            - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2248*da0073e9SAndroid Build Coastguard Worker            + 90
2249*da0073e9SAndroid Build Coastguard Worker            <= 20480
2250*da0073e9SAndroid Build Coastguard Worker        )
2251*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2252*da0073e9SAndroid Build Coastguard Worker            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2253*da0073e9SAndroid Build Coastguard Worker            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2254*da0073e9SAndroid Build Coastguard Worker            + 60
2255*da0073e9SAndroid Build Coastguard Worker            <= 20480
2256*da0073e9SAndroid Build Coastguard Worker        )
2257*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2258*da0073e9SAndroid Build Coastguard Worker            Ne(
2259*da0073e9SAndroid Build Coastguard Worker                240 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2260*da0073e9SAndroid Build Coastguard Worker                - 480 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2261*da0073e9SAndroid Build Coastguard Worker                + 240,
2262*da0073e9SAndroid Build Coastguard Worker                0,
2263*da0073e9SAndroid Build Coastguard Worker            )
2264*da0073e9SAndroid Build Coastguard Worker        )
2265*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Eq(6 * s5, 132))
2266*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
2267*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Eq(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 4))
2268*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2269*da0073e9SAndroid Build Coastguard Worker            Ne(
2270*da0073e9SAndroid Build Coastguard Worker                64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2271*da0073e9SAndroid Build Coastguard Worker                - 128 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2272*da0073e9SAndroid Build Coastguard Worker                + 64 * (FloorDiv(s0, 2)),
2273*da0073e9SAndroid Build Coastguard Worker                0,
2274*da0073e9SAndroid Build Coastguard Worker            )
2275*da0073e9SAndroid Build Coastguard Worker        )
2276*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2277*da0073e9SAndroid Build Coastguard Worker            1
2278*da0073e9SAndroid Build Coastguard Worker            < 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2279*da0073e9SAndroid Build Coastguard Worker            - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2280*da0073e9SAndroid Build Coastguard Worker            + 64
2281*da0073e9SAndroid Build Coastguard Worker        )
2282*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2283*da0073e9SAndroid Build Coastguard Worker            64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2284*da0073e9SAndroid Build Coastguard Worker            - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2285*da0073e9SAndroid Build Coastguard Worker            + 64
2286*da0073e9SAndroid Build Coastguard Worker            <= 2147483647
2287*da0073e9SAndroid Build Coastguard Worker        )
2288*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2289*da0073e9SAndroid Build Coastguard Worker            64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2290*da0073e9SAndroid Build Coastguard Worker            - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2291*da0073e9SAndroid Build Coastguard Worker            + 64
2292*da0073e9SAndroid Build Coastguard Worker            > 1
2293*da0073e9SAndroid Build Coastguard Worker        )
2294*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2295*da0073e9SAndroid Build Coastguard Worker            62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2296*da0073e9SAndroid Build Coastguard Worker            - 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2297*da0073e9SAndroid Build Coastguard Worker            + 62
2298*da0073e9SAndroid Build Coastguard Worker            <= 2147483647
2299*da0073e9SAndroid Build Coastguard Worker        )
2300*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2301*da0073e9SAndroid Build Coastguard Worker            Ne(
2302*da0073e9SAndroid Build Coastguard Worker                62 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2303*da0073e9SAndroid Build Coastguard Worker                - 124 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2304*da0073e9SAndroid Build Coastguard Worker                + 62 * (FloorDiv(s0, 2)),
2305*da0073e9SAndroid Build Coastguard Worker                0,
2306*da0073e9SAndroid Build Coastguard Worker            )
2307*da0073e9SAndroid Build Coastguard Worker        )
2308*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2309*da0073e9SAndroid Build Coastguard Worker            1
2310*da0073e9SAndroid Build Coastguard Worker            < 62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2311*da0073e9SAndroid Build Coastguard Worker            - 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2312*da0073e9SAndroid Build Coastguard Worker            + 62
2313*da0073e9SAndroid Build Coastguard Worker        )
2314*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
2315*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
2316*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
2317*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
2318*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
2319*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 3)
2320*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2321*da0073e9SAndroid Build Coastguard Worker            64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2322*da0073e9SAndroid Build Coastguard Worker            - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2323*da0073e9SAndroid Build Coastguard Worker            + 576
2324*da0073e9SAndroid Build Coastguard Worker            <= 2147483647
2325*da0073e9SAndroid Build Coastguard Worker        )
2326*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 0)
2327*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 1)
2328*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2329*da0073e9SAndroid Build Coastguard Worker            Ne(
2330*da0073e9SAndroid Build Coastguard Worker                64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2331*da0073e9SAndroid Build Coastguard Worker                - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2332*da0073e9SAndroid Build Coastguard Worker                + 576 * (FloorDiv(s0, 2)),
2333*da0073e9SAndroid Build Coastguard Worker                0,
2334*da0073e9SAndroid Build Coastguard Worker            )
2335*da0073e9SAndroid Build Coastguard Worker        )
2336*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
2337*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2338*da0073e9SAndroid Build Coastguard Worker            Ne(
2339*da0073e9SAndroid Build Coastguard Worker                (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2340*da0073e9SAndroid Build Coastguard Worker                - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2341*da0073e9SAndroid Build Coastguard Worker                + 9,
2342*da0073e9SAndroid Build Coastguard Worker                1,
2343*da0073e9SAndroid Build Coastguard Worker            )
2344*da0073e9SAndroid Build Coastguard Worker        )
2345*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2346*da0073e9SAndroid Build Coastguard Worker            Ne(
2347*da0073e9SAndroid Build Coastguard Worker                (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2348*da0073e9SAndroid Build Coastguard Worker                - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2349*da0073e9SAndroid Build Coastguard Worker                + 9,
2350*da0073e9SAndroid Build Coastguard Worker                0,
2351*da0073e9SAndroid Build Coastguard Worker            )
2352*da0073e9SAndroid Build Coastguard Worker        )
2353*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2354*da0073e9SAndroid Build Coastguard Worker            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2355*da0073e9SAndroid Build Coastguard Worker            - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2356*da0073e9SAndroid Build Coastguard Worker            + 9
2357*da0073e9SAndroid Build Coastguard Worker            >= 0
2358*da0073e9SAndroid Build Coastguard Worker        )
2359*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 0))
2360*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2361*da0073e9SAndroid Build Coastguard Worker            1
2362*da0073e9SAndroid Build Coastguard Worker            < 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2363*da0073e9SAndroid Build Coastguard Worker            - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2364*da0073e9SAndroid Build Coastguard Worker            + 576
2365*da0073e9SAndroid Build Coastguard Worker        )
2366*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
2367*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2368*da0073e9SAndroid Build Coastguard Worker            Ne(
2369*da0073e9SAndroid Build Coastguard Worker                64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2370*da0073e9SAndroid Build Coastguard Worker                - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2371*da0073e9SAndroid Build Coastguard Worker                + 576 * (FloorDiv(s0, 2)),
2372*da0073e9SAndroid Build Coastguard Worker                256,
2373*da0073e9SAndroid Build Coastguard Worker            )
2374*da0073e9SAndroid Build Coastguard Worker        )
2375*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2376*da0073e9SAndroid Build Coastguard Worker            Eq(
2377*da0073e9SAndroid Build Coastguard Worker                64
2378*da0073e9SAndroid Build Coastguard Worker                * (
2379*da0073e9SAndroid Build Coastguard Worker                    Mod(
2380*da0073e9SAndroid Build Coastguard Worker                        (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2381*da0073e9SAndroid Build Coastguard Worker                        - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2382*da0073e9SAndroid Build Coastguard Worker                        + 9 * (FloorDiv(s0, 2)),
2383*da0073e9SAndroid Build Coastguard Worker                        4,
2384*da0073e9SAndroid Build Coastguard Worker                    )
2385*da0073e9SAndroid Build Coastguard Worker                ),
2386*da0073e9SAndroid Build Coastguard Worker                0,
2387*da0073e9SAndroid Build Coastguard Worker            )
2388*da0073e9SAndroid Build Coastguard Worker        )
2389*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2390*da0073e9SAndroid Build Coastguard Worker            Eq(
2391*da0073e9SAndroid Build Coastguard Worker                FloorDiv(s0, 2),
2392*da0073e9SAndroid Build Coastguard Worker                FloorDiv(
2393*da0073e9SAndroid Build Coastguard Worker                    (
2394*da0073e9SAndroid Build Coastguard Worker                        (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2395*da0073e9SAndroid Build Coastguard Worker                        - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2396*da0073e9SAndroid Build Coastguard Worker                        + 9 * (FloorDiv(s0, 2))
2397*da0073e9SAndroid Build Coastguard Worker                    ),
2398*da0073e9SAndroid Build Coastguard Worker                    4,
2399*da0073e9SAndroid Build Coastguard Worker                ),
2400*da0073e9SAndroid Build Coastguard Worker            )
2401*da0073e9SAndroid Build Coastguard Worker        )
2402*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2403*da0073e9SAndroid Build Coastguard Worker            Eq(
2404*da0073e9SAndroid Build Coastguard Worker                FloorDiv(
2405*da0073e9SAndroid Build Coastguard Worker                    (
2406*da0073e9SAndroid Build Coastguard Worker                        (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2407*da0073e9SAndroid Build Coastguard Worker                        - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2408*da0073e9SAndroid Build Coastguard Worker                        + 9 * (FloorDiv(s0, 2))
2409*da0073e9SAndroid Build Coastguard Worker                    ),
2410*da0073e9SAndroid Build Coastguard Worker                    4,
2411*da0073e9SAndroid Build Coastguard Worker                ),
2412*da0073e9SAndroid Build Coastguard Worker                FloorDiv(s0, 2),
2413*da0073e9SAndroid Build Coastguard Worker            )
2414*da0073e9SAndroid Build Coastguard Worker        )
2415*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2416*da0073e9SAndroid Build Coastguard Worker            Ne(64 * (Mod(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 4)), 0)
2417*da0073e9SAndroid Build Coastguard Worker        )
2418*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2419*da0073e9SAndroid Build Coastguard Worker            Eq(
2420*da0073e9SAndroid Build Coastguard Worker                64
2421*da0073e9SAndroid Build Coastguard Worker                * (
2422*da0073e9SAndroid Build Coastguard Worker                    Mod(
2423*da0073e9SAndroid Build Coastguard Worker                        (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2424*da0073e9SAndroid Build Coastguard Worker                        - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2425*da0073e9SAndroid Build Coastguard Worker                        + 1,
2426*da0073e9SAndroid Build Coastguard Worker                        4,
2427*da0073e9SAndroid Build Coastguard Worker                    )
2428*da0073e9SAndroid Build Coastguard Worker                ),
2429*da0073e9SAndroid Build Coastguard Worker                0,
2430*da0073e9SAndroid Build Coastguard Worker            )
2431*da0073e9SAndroid Build Coastguard Worker        )
2432*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2433*da0073e9SAndroid Build Coastguard Worker            64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2434*da0073e9SAndroid Build Coastguard Worker            - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2435*da0073e9SAndroid Build Coastguard Worker            + 576 * (FloorDiv(s0, 2))
2436*da0073e9SAndroid Build Coastguard Worker            > 0
2437*da0073e9SAndroid Build Coastguard Worker        )
2438*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2439*da0073e9SAndroid Build Coastguard Worker            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2440*da0073e9SAndroid Build Coastguard Worker            - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2441*da0073e9SAndroid Build Coastguard Worker            + 9
2442*da0073e9SAndroid Build Coastguard Worker            >= 1
2443*da0073e9SAndroid Build Coastguard Worker        )
2444*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2445*da0073e9SAndroid Build Coastguard Worker            Eq(
2446*da0073e9SAndroid Build Coastguard Worker                64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2447*da0073e9SAndroid Build Coastguard Worker                - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2448*da0073e9SAndroid Build Coastguard Worker                + 576,
2449*da0073e9SAndroid Build Coastguard Worker                256,
2450*da0073e9SAndroid Build Coastguard Worker            )
2451*da0073e9SAndroid Build Coastguard Worker        )
2452*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2453*da0073e9SAndroid Build Coastguard Worker            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2454*da0073e9SAndroid Build Coastguard Worker            - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2455*da0073e9SAndroid Build Coastguard Worker            + 540
2456*da0073e9SAndroid Build Coastguard Worker            <= 2147483647
2457*da0073e9SAndroid Build Coastguard Worker        )
2458*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2459*da0073e9SAndroid Build Coastguard Worker            Ne(
2460*da0073e9SAndroid Build Coastguard Worker                60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2461*da0073e9SAndroid Build Coastguard Worker                - 360 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2462*da0073e9SAndroid Build Coastguard Worker                + 540 * (FloorDiv(s0, 2)),
2463*da0073e9SAndroid Build Coastguard Worker                0,
2464*da0073e9SAndroid Build Coastguard Worker            )
2465*da0073e9SAndroid Build Coastguard Worker        )
2466*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2467*da0073e9SAndroid Build Coastguard Worker            1
2468*da0073e9SAndroid Build Coastguard Worker            < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2469*da0073e9SAndroid Build Coastguard Worker            - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2470*da0073e9SAndroid Build Coastguard Worker            + 540
2471*da0073e9SAndroid Build Coastguard Worker        )
2472*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2473*da0073e9SAndroid Build Coastguard Worker            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2474*da0073e9SAndroid Build Coastguard Worker            - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2475*da0073e9SAndroid Build Coastguard Worker            + 9
2476*da0073e9SAndroid Build Coastguard Worker            <= 2147483647
2477*da0073e9SAndroid Build Coastguard Worker        )
2478*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2479*da0073e9SAndroid Build Coastguard Worker            Ne(
2480*da0073e9SAndroid Build Coastguard Worker                (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2481*da0073e9SAndroid Build Coastguard Worker                - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2482*da0073e9SAndroid Build Coastguard Worker                + 9 * (FloorDiv(s0, 2)),
2483*da0073e9SAndroid Build Coastguard Worker                0,
2484*da0073e9SAndroid Build Coastguard Worker            )
2485*da0073e9SAndroid Build Coastguard Worker        )
2486*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2487*da0073e9SAndroid Build Coastguard Worker            1
2488*da0073e9SAndroid Build Coastguard Worker            < (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2489*da0073e9SAndroid Build Coastguard Worker            - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2490*da0073e9SAndroid Build Coastguard Worker            + 9
2491*da0073e9SAndroid Build Coastguard Worker        )
2492*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2493*da0073e9SAndroid Build Coastguard Worker            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2494*da0073e9SAndroid Build Coastguard Worker            - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2495*da0073e9SAndroid Build Coastguard Worker            + 9
2496*da0073e9SAndroid Build Coastguard Worker            > 1
2497*da0073e9SAndroid Build Coastguard Worker        )
2498*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(
2499*da0073e9SAndroid Build Coastguard Worker            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
2500*da0073e9SAndroid Build Coastguard Worker            - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
2501*da0073e9SAndroid Build Coastguard Worker            + 540
2502*da0073e9SAndroid Build Coastguard Worker            > 1
2503*da0073e9SAndroid Build Coastguard Worker        )
2504*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(s0 >= 2)
2505*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(s1 >= 2)
2506*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(s6 >= 2)
2507*da0073e9SAndroid Build Coastguard Worker        dim_constraints.add(s5 >= 2)
2508*da0073e9SAndroid Build Coastguard Worker
2509*da0073e9SAndroid Build Coastguard Worker        dim_constraints.solve()
2510*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2511*da0073e9SAndroid Build Coastguard Worker            dim_constraints._static_results,
2512*da0073e9SAndroid Build Coastguard Worker            {
2513*da0073e9SAndroid Build Coastguard Worker                "L['c'].size()[0] == 8",
2514*da0073e9SAndroid Build Coastguard Worker                "L['d'].size()[0] == 8",
2515*da0073e9SAndroid Build Coastguard Worker                "L['a'].size()[2] == 96",
2516*da0073e9SAndroid Build Coastguard Worker                "L['f'].size()[1] == 1",
2517*da0073e9SAndroid Build Coastguard Worker                "L['a'].size()[3] == 96",
2518*da0073e9SAndroid Build Coastguard Worker                "L['b'].size()[2] == 3",
2519*da0073e9SAndroid Build Coastguard Worker                "L['b'].size()[1] == 22",
2520*da0073e9SAndroid Build Coastguard Worker                "L['b'].size()[0] == 8",
2521*da0073e9SAndroid Build Coastguard Worker                "L['a'].size()[1] == 22",
2522*da0073e9SAndroid Build Coastguard Worker                "L['a'].size()[0] == 8",
2523*da0073e9SAndroid Build Coastguard Worker            },
2524*da0073e9SAndroid Build Coastguard Worker        )
2525*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2526*da0073e9SAndroid Build Coastguard Worker            dim_constraints._dynamic_results,
2527*da0073e9SAndroid Build Coastguard Worker            {
2528*da0073e9SAndroid Build Coastguard Worker                "2 <= L['c'].size()[1]",
2529*da0073e9SAndroid Build Coastguard Worker                "L['d'].size()[1] == L['c'].size()[1]",
2530*da0073e9SAndroid Build Coastguard Worker                "L['e'].size()[1] == L['c'].size()[1]",
2531*da0073e9SAndroid Build Coastguard Worker            },
2532*da0073e9SAndroid Build Coastguard Worker        )
2533*da0073e9SAndroid Build Coastguard Worker
2534*da0073e9SAndroid Build Coastguard Worker
class TestGuardsExpressions(TestCase):
    """
    Tests the guards-related methods used by the inductor FX graph cache.

    Each test records guards on a symbolic int through a ``ShapeEnv``,
    renders them with ``produce_guards_expression`` and then checks
    ``evaluate_guards_expression`` against the original hint (must be
    accepted) and against at least one different concrete value (must be
    rejected). Checking both sides keeps a guard expression that trivially
    evaluates to True from passing.
    """

    def test_guards_gt_lt(self):
        """Guards recorded from ``s0 > 5`` and ``s0 < 7`` accept only 6."""
        shape_env = ShapeEnv()
        s0 = create_symint(shape_env, 6)
        # s1 / s2 only supply concrete values (7 and 5) that sit exactly on
        # the two strict bounds, so both must be rejected.
        s1 = create_symint(shape_env, 7)
        s2 = create_symint(shape_env, 5)

        # Forcing the comparisons to concrete ints is what records them as
        # guards on s0.
        guard_int(sym_int(s0 > 5))
        guard_int(sym_int(s0 < 7))

        guards = shape_env.produce_guards_expression([s0])

        self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
        self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
        self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s2)]))

    def test_guards_float_print(self):
        """A guard on a float division result round-trips through printing."""
        shape_env = ShapeEnv()
        s0 = create_symint(shape_env, 3)
        # Only supplies a different concrete value for the negative check.
        s1 = create_symint(shape_env, 4)

        guard_bool(2 / s0 == 2 / 3)
        guards = shape_env.produce_guards_expression([s0])

        self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
        # 2 / 4 != 2 / 3, so the printed guard must reject this value; without
        # this check a guard that always evaluates to True would pass.
        self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))

    def test_guards_float_div(self):
        """Guards built from ``s0 / 2.0`` are printed with float helpers."""
        shape_env = ShapeEnv()
        s0 = create_symint(shape_env, 8)
        # 7 / 2.0 == 3.5, whose int conversion (3) differs from 8 / 2.0 -> 4,
        # so this value must be rejected.
        s1 = create_symint(shape_env, 7)

        guard_int(sym_int(s0 / 2.0))
        guards = shape_env.produce_guards_expression([s0])

        # The rendered guard must reference the float conversion / division
        # helpers so it can be evaluated on plain ints later.
        self.assertIn("ToFloat", guards)
        self.assertIn("FloatTrueDiv", guards)
        self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
        self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
2574*da0073e9SAndroid Build Coastguard Worker
2575*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Entry point when this module is executed directly: hand off to
    # run_tests() (imported elsewhere in this file; presumably the shared
    # PyTorch test runner — confirm against the import block).
    run_tests()
2578