1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport contextlib 4*da0073e9SAndroid Build Coastguard Workerimport copy 5*da0073e9SAndroid Build Coastguard Workerimport itertools 6*da0073e9SAndroid Build Coastguard Workerimport math 7*da0073e9SAndroid Build Coastguard Workerimport operator 8*da0073e9SAndroid Build Coastguard Workerimport unittest 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Workerimport numpy as np 11*da0073e9SAndroid Build Coastguard Workerimport sympy 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Workerimport torch 14*da0073e9SAndroid Build Coastguard Workerimport torch.fx 15*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F 16*da0073e9SAndroid Build Coastguard Workerfrom torch import sym_int, SymBool, SymFloat, SymInt 17*da0073e9SAndroid Build Coastguard Workerfrom torch._C import _disabled_torch_function_impl 18*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental import sym_node 19*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.proxy_tensor import make_fx 20*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node 21*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.symbolic_shapes import ( 22*da0073e9SAndroid Build Coastguard Worker _constrain_range_for_size, 23*da0073e9SAndroid Build Coastguard Worker DimConstraints, 24*da0073e9SAndroid Build Coastguard Worker DimDynamic, 25*da0073e9SAndroid Build Coastguard Worker expect_true, 26*da0073e9SAndroid Build Coastguard Worker guard_bool, 27*da0073e9SAndroid Build Coastguard Worker guard_float, 28*da0073e9SAndroid Build Coastguard Worker guard_int, 29*da0073e9SAndroid Build Coastguard Worker GuardOnDataDependentSymNode, 30*da0073e9SAndroid Build Coastguard Worker hint_int, 
31*da0073e9SAndroid Build Coastguard Worker is_symbolic, 32*da0073e9SAndroid Build Coastguard Worker ShapeEnv, 33*da0073e9SAndroid Build Coastguard Worker StatelessSymbolicContext, 34*da0073e9SAndroid Build Coastguard Worker statically_known_true, 35*da0073e9SAndroid Build Coastguard Worker) 36*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 37*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests, 38*da0073e9SAndroid Build Coastguard Worker parametrize, 39*da0073e9SAndroid Build Coastguard Worker run_tests, 40*da0073e9SAndroid Build Coastguard Worker skipIfTorchDynamo, 41*da0073e9SAndroid Build Coastguard Worker TestCase, 42*da0073e9SAndroid Build Coastguard Worker) 43*da0073e9SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree 44*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._python_dispatch import TorchDispatchMode 45*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._sympy.functions import ( 46*da0073e9SAndroid Build Coastguard Worker FloorDiv, 47*da0073e9SAndroid Build Coastguard Worker IsNonOverlappingAndDenseIndicator, 48*da0073e9SAndroid Build Coastguard Worker Mod, 49*da0073e9SAndroid Build Coastguard Worker) 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Workeraten = torch.ops.aten 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Workermeta_funcs = {} 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Workerdef register_meta(op): 58*da0073e9SAndroid Build Coastguard Worker def decorator(f): 59*da0073e9SAndroid Build Coastguard Worker def add_func(op): 60*da0073e9SAndroid Build Coastguard Worker meta_funcs[op] = f 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker pytree.tree_map_(add_func, op) 63*da0073e9SAndroid Build Coastguard Worker 
return f 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker return decorator 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker@register_meta([aten.add.Tensor, aten.sub.Tensor]) 69*da0073e9SAndroid Build Coastguard Workerdef binary_meta(a, b): 70*da0073e9SAndroid Build Coastguard Worker return a.new_empty(a.shape) 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker@register_meta(aten.cat.default) 74*da0073e9SAndroid Build Coastguard Workerdef cat_meta(tensors, dim=0): 75*da0073e9SAndroid Build Coastguard Worker concat_length = 0 76*da0073e9SAndroid Build Coastguard Worker shape = tensors[0].shape 77*da0073e9SAndroid Build Coastguard Worker for tensor in tensors: 78*da0073e9SAndroid Build Coastguard Worker for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)): 79*da0073e9SAndroid Build Coastguard Worker if idx == dim: 80*da0073e9SAndroid Build Coastguard Worker concat_length = concat_length + length 81*da0073e9SAndroid Build Coastguard Worker else: 82*da0073e9SAndroid Build Coastguard Worker assert length == common_length 83*da0073e9SAndroid Build Coastguard Worker new_shape = list(shape) 84*da0073e9SAndroid Build Coastguard Worker new_shape[dim] = concat_length 85*da0073e9SAndroid Build Coastguard Worker return tensors[0].new_empty(new_shape) 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Worker@register_meta([aten.narrow_copy.default]) 89*da0073e9SAndroid Build Coastguard Workerdef narrow_copy_symint_meta(a, dim, start, length, **kwargs): 90*da0073e9SAndroid Build Coastguard Worker shape = [] 91*da0073e9SAndroid Build Coastguard Worker for i, x in enumerate(a.shape): 92*da0073e9SAndroid Build Coastguard Worker if i == dim: 93*da0073e9SAndroid Build Coastguard Worker 
shape.append(length) 94*da0073e9SAndroid Build Coastguard Worker else: 95*da0073e9SAndroid Build Coastguard Worker shape.append(x) 96*da0073e9SAndroid Build Coastguard Worker return a.new_empty(tuple(shape)) 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Worker@register_meta([aten.expand.default]) 100*da0073e9SAndroid Build Coastguard Workerdef expand_symint_meta(a, size, implicit=False): 101*da0073e9SAndroid Build Coastguard Worker return a.new_empty(size) 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Workerdef create_contiguous(shape): 105*da0073e9SAndroid Build Coastguard Worker strides = [1] 106*da0073e9SAndroid Build Coastguard Worker for dim in reversed(shape[:-1]): 107*da0073e9SAndroid Build Coastguard Worker strides.append(dim * strides[-1]) 108*da0073e9SAndroid Build Coastguard Worker return list(reversed(strides)) 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker 111*da0073e9SAndroid Build Coastguard Workerclass FakeSymbolicTensor(torch.Tensor): 112*da0073e9SAndroid Build Coastguard Worker @staticmethod 113*da0073e9SAndroid Build Coastguard Worker def __new__( 114*da0073e9SAndroid Build Coastguard Worker cls, 115*da0073e9SAndroid Build Coastguard Worker sym_shape, 116*da0073e9SAndroid Build Coastguard Worker sym_strides, 117*da0073e9SAndroid Build Coastguard Worker dtype, 118*da0073e9SAndroid Build Coastguard Worker layout, 119*da0073e9SAndroid Build Coastguard Worker requires_grad, 120*da0073e9SAndroid Build Coastguard Worker device, 121*da0073e9SAndroid Build Coastguard Worker storage_offset=0, 122*da0073e9SAndroid Build Coastguard Worker ): 123*da0073e9SAndroid Build Coastguard Worker # TODO: this is wrong in general 124*da0073e9SAndroid Build Coastguard Worker sym_stride = create_contiguous(sym_shape) 125*da0073e9SAndroid Build Coastguard Worker r = 
torch.Tensor._make_wrapper_subclass( 126*da0073e9SAndroid Build Coastguard Worker cls, 127*da0073e9SAndroid Build Coastguard Worker sym_shape, 128*da0073e9SAndroid Build Coastguard Worker sym_stride, 129*da0073e9SAndroid Build Coastguard Worker storage_offset, 130*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 131*da0073e9SAndroid Build Coastguard Worker layout=layout, 132*da0073e9SAndroid Build Coastguard Worker requires_grad=requires_grad, 133*da0073e9SAndroid Build Coastguard Worker device=device, 134*da0073e9SAndroid Build Coastguard Worker ) 135*da0073e9SAndroid Build Coastguard Worker return r 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker __torch_function__ = _disabled_torch_function_impl 138*da0073e9SAndroid Build Coastguard Worker 139*da0073e9SAndroid Build Coastguard Worker def new_empty(self, shape): 140*da0073e9SAndroid Build Coastguard Worker return FakeSymbolicTensor( 141*da0073e9SAndroid Build Coastguard Worker shape, None, self.dtype, self.layout, self.requires_grad, self.device 142*da0073e9SAndroid Build Coastguard Worker ) 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker @classmethod 145*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None): 146*da0073e9SAndroid Build Coastguard Worker if func_overload in meta_funcs: 147*da0073e9SAndroid Build Coastguard Worker return meta_funcs[func_overload](*args, **kwargs) 148*da0073e9SAndroid Build Coastguard Worker 149*da0073e9SAndroid Build Coastguard Worker if func_overload == torch.ops.aten.new_empty.default: 150*da0073e9SAndroid Build Coastguard Worker self = args[0] 151*da0073e9SAndroid Build Coastguard Worker shape = args[1] 152*da0073e9SAndroid Build Coastguard Worker return FakeSymbolicTensor( 153*da0073e9SAndroid Build Coastguard Worker shape, 154*da0073e9SAndroid Build Coastguard Worker self.stride(), 155*da0073e9SAndroid Build Coastguard Worker 
self.dtype, 156*da0073e9SAndroid Build Coastguard Worker self.layout, 157*da0073e9SAndroid Build Coastguard Worker self.requires_grad, 158*da0073e9SAndroid Build Coastguard Worker self.device, 159*da0073e9SAndroid Build Coastguard Worker ) 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"operator {func_overload} not supported") 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Workerdef create_symbolic_tensor(name, arg, shape_env, source=None, dynamic_dims=None): 165*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.source import ConstantSource 166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker if source is None: 168*da0073e9SAndroid Build Coastguard Worker source = ConstantSource(name) 169*da0073e9SAndroid Build Coastguard Worker constraint_dims = [None] * arg.dim() 170*da0073e9SAndroid Build Coastguard Worker if dynamic_dims is None: 171*da0073e9SAndroid Build Coastguard Worker dynamic_dims = [DimDynamic.DUCK] * arg.dim() 172*da0073e9SAndroid Build Coastguard Worker ( 173*da0073e9SAndroid Build Coastguard Worker sym_shapes, 174*da0073e9SAndroid Build Coastguard Worker sym_strides, 175*da0073e9SAndroid Build Coastguard Worker sym_storage_offset, 176*da0073e9SAndroid Build Coastguard Worker ) = shape_env.create_symbolic_sizes_strides_storage_offset( 177*da0073e9SAndroid Build Coastguard Worker arg, 178*da0073e9SAndroid Build Coastguard Worker source=source, 179*da0073e9SAndroid Build Coastguard Worker symbolic_context=StatelessSymbolicContext( 180*da0073e9SAndroid Build Coastguard Worker dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims 181*da0073e9SAndroid Build Coastguard Worker ), 182*da0073e9SAndroid Build Coastguard Worker ) 183*da0073e9SAndroid Build Coastguard Worker return FakeSymbolicTensor( 184*da0073e9SAndroid Build Coastguard Worker sym_shapes, 
185*da0073e9SAndroid Build Coastguard Worker sym_strides, 186*da0073e9SAndroid Build Coastguard Worker arg.dtype, 187*da0073e9SAndroid Build Coastguard Worker arg.layout, 188*da0073e9SAndroid Build Coastguard Worker arg.requires_grad, 189*da0073e9SAndroid Build Coastguard Worker arg.device, 190*da0073e9SAndroid Build Coastguard Worker sym_storage_offset, 191*da0073e9SAndroid Build Coastguard Worker ) 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker 194*da0073e9SAndroid Build Coastguard Workerdef create_symtype(cls, pytype, shape_env, val, duck=True, **kwargs): 195*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.source import ConstantSource 196*da0073e9SAndroid Build Coastguard Worker 197*da0073e9SAndroid Build Coastguard Worker symbol = shape_env.create_symbol( 198*da0073e9SAndroid Build Coastguard Worker val, 199*da0073e9SAndroid Build Coastguard Worker source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"), 200*da0073e9SAndroid Build Coastguard Worker dynamic_dim=DimDynamic.DUCK if duck else DimDynamic.DYNAMIC, 201*da0073e9SAndroid Build Coastguard Worker constraint_dim=None, 202*da0073e9SAndroid Build Coastguard Worker **kwargs, 203*da0073e9SAndroid Build Coastguard Worker ) 204*da0073e9SAndroid Build Coastguard Worker return cls(SymNode(symbol, shape_env, pytype, hint=val)) 205*da0073e9SAndroid Build Coastguard Worker 206*da0073e9SAndroid Build Coastguard Worker 207*da0073e9SAndroid Build Coastguard Worker# TODO: default duck to False 208*da0073e9SAndroid Build Coastguard Workerdef create_symint(shape_env, i: int, duck=True, **kwargs) -> SymInt: 209*da0073e9SAndroid Build Coastguard Worker return create_symtype(SymInt, int, shape_env, i, duck=duck, **kwargs) 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker 212*da0073e9SAndroid Build Coastguard Workerdef create_symbool(shape_env, b: bool) -> SymBool: 213*da0073e9SAndroid Build Coastguard Worker return 
create_symtype(SymBool, bool, shape_env, b) 214*da0073e9SAndroid Build Coastguard Worker 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Workerdef create_symfloat(shape_env, f: float) -> SymFloat: 217*da0073e9SAndroid Build Coastguard Worker return create_symtype(SymFloat, float, shape_env, f) 218*da0073e9SAndroid Build Coastguard Worker 219*da0073e9SAndroid Build Coastguard Worker 220*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo( 221*da0073e9SAndroid Build Coastguard Worker "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)" 222*da0073e9SAndroid Build Coastguard Worker) 223*da0073e9SAndroid Build Coastguard Workerclass TestPySymInt(TestCase): 224*da0073e9SAndroid Build Coastguard Worker def test_arith_ops(self): 225*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 226*da0073e9SAndroid Build Coastguard Worker symints = [] 227*da0073e9SAndroid Build Coastguard Worker for i in range(2, 5): 228*da0073e9SAndroid Build Coastguard Worker symints.append((i, create_symint(shape_env, i))) 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker ops = [ 231*da0073e9SAndroid Build Coastguard Worker operator.add, 232*da0073e9SAndroid Build Coastguard Worker operator.sub, 233*da0073e9SAndroid Build Coastguard Worker operator.floordiv, 234*da0073e9SAndroid Build Coastguard Worker operator.mul, 235*da0073e9SAndroid Build Coastguard Worker operator.mod, 236*da0073e9SAndroid Build Coastguard Worker ] 237*da0073e9SAndroid Build Coastguard Worker 238*da0073e9SAndroid Build Coastguard Worker for op in ops: 239*da0073e9SAndroid Build Coastguard Worker for args in itertools.permutations(symints, 2): 240*da0073e9SAndroid Build Coastguard Worker if not isinstance(args[0][1], int) and ( 241*da0073e9SAndroid Build Coastguard Worker (op != operator.mod or op != operator.floordiv) and args[1][0] != 0 242*da0073e9SAndroid Build Coastguard Worker ): 
243*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 244*da0073e9SAndroid Build Coastguard Worker op(args[0][1], args[1][1]) == op(args[0][0], args[1][0]) 245*da0073e9SAndroid Build Coastguard Worker ) 246*da0073e9SAndroid Build Coastguard Worker 247*da0073e9SAndroid Build Coastguard Worker def test_reverse_arith_ops(self): 248*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 249*da0073e9SAndroid Build Coastguard Worker 250*da0073e9SAndroid Build Coastguard Worker a = create_symint(shape_env, 2) 251*da0073e9SAndroid Build Coastguard Worker self.assertTrue(5 // a == 5 // 2) 252*da0073e9SAndroid Build Coastguard Worker 253*da0073e9SAndroid Build Coastguard Worker a = create_symint(shape_env, 2) 254*da0073e9SAndroid Build Coastguard Worker self.assertTrue(5 * a == 5 * 2) 255*da0073e9SAndroid Build Coastguard Worker 256*da0073e9SAndroid Build Coastguard Worker def test_sympify_symint(self): 257*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 258*da0073e9SAndroid Build Coastguard Worker a = create_symint(shape_env, 2) 259*da0073e9SAndroid Build Coastguard Worker self.assertIs(sympy.sympify(a), a.node.expr) 260*da0073e9SAndroid Build Coastguard Worker b = create_symfloat(shape_env, 3.0) 261*da0073e9SAndroid Build Coastguard Worker self.assertIs(sympy.sympify(b), b.node.expr) 262*da0073e9SAndroid Build Coastguard Worker c = create_symbool(shape_env, True) 263*da0073e9SAndroid Build Coastguard Worker self.assertIs(sympy.sympify(c), c.node.expr) 264*da0073e9SAndroid Build Coastguard Worker 265*da0073e9SAndroid Build Coastguard Worker def test_roundtrip(self): 266*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 267*da0073e9SAndroid Build Coastguard Worker x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) 268*da0073e9SAndroid Build Coastguard Worker 269*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not isinstance(x.shape[0], SymNode)) 270*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(isinstance(x.shape[0], SymInt)) 271*da0073e9SAndroid Build Coastguard Worker 272*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.shape[0] == 5) 273*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.shape[1] == 4) 274*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.shape[2], 3) 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.size()[0], 5) 277*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.size()[1], 4) 278*da0073e9SAndroid Build Coastguard Worker # Should be simplifiable to an integer. 279*da0073e9SAndroid Build Coastguard Worker # Ref: https://github.com/pytorch/pytorch/pull/107492 280*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(x.size()[1], SymInt)) 281*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 282*da0073e9SAndroid Build Coastguard Worker isinstance(x.size()[1].node.maybe_as_int(), int) 283*da0073e9SAndroid Build Coastguard Worker ) # due to guard above 284*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.size()[2] == 3) 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.size(0) == 5) 287*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.size(1) == 4) 288*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.size(2) == 3) 289*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(x.size(2), SymInt)) 290*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(x.size(2).node.maybe_as_int(), int)) 291*da0073e9SAndroid Build Coastguard Worker 292*da0073e9SAndroid Build Coastguard Worker y = create_symbolic_tensor("y", torch.randn(5, 4, 3)[1:], shape_env) 293*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(y.storage_offset(), SymInt)) 294*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y.storage_offset() == 12) 295*da0073e9SAndroid Build Coastguard Worker 296*da0073e9SAndroid 
Build Coastguard Worker def test_binary(self): 297*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 298*da0073e9SAndroid Build Coastguard Worker x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) 299*da0073e9SAndroid Build Coastguard Worker y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env) 300*da0073e9SAndroid Build Coastguard Worker 301*da0073e9SAndroid Build Coastguard Worker z = x + y 302*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[0] == 5) 303*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[1] == 4) 304*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[2] == 3) 305*da0073e9SAndroid Build Coastguard Worker 306*da0073e9SAndroid Build Coastguard Worker # broadcasting 307*da0073e9SAndroid Build Coastguard Worker y = create_symbolic_tensor("y2", torch.randn(1, 4, 1), shape_env) 308*da0073e9SAndroid Build Coastguard Worker z = x + y 309*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[0] == 5) 310*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[1] == 4) 311*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[2] == 3) 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker def test_symint_args(self): 314*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 315*da0073e9SAndroid Build Coastguard Worker x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) 316*da0073e9SAndroid Build Coastguard Worker y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env) 317*da0073e9SAndroid Build Coastguard Worker LAST_DIM = 2 318*da0073e9SAndroid Build Coastguard Worker z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM]) 319*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[2] == y.shape[2]) 320*da0073e9SAndroid Build Coastguard Worker 321*da0073e9SAndroid Build Coastguard Worker # arithmetic expr with two symints 322*da0073e9SAndroid Build Coastguard Worker z = 
x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM]) 323*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[2] == 2) 324*da0073e9SAndroid Build Coastguard Worker 325*da0073e9SAndroid Build Coastguard Worker # arithmetic expr with a symint and python int 326*da0073e9SAndroid Build Coastguard Worker z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1) 327*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[2] == 2) 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker def test_symint_vargs(self): 330*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 331*da0073e9SAndroid Build Coastguard Worker x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) 332*da0073e9SAndroid Build Coastguard Worker y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env) 333*da0073e9SAndroid Build Coastguard Worker 334*da0073e9SAndroid Build Coastguard Worker # varargs 335*da0073e9SAndroid Build Coastguard Worker z = y.expand(x.shape[0], y.shape[1], x.shape[2]) 336*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[0] == 5) 337*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[1] == 4) 338*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[2] == 3) 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker # shape list 341*da0073e9SAndroid Build Coastguard Worker z = y.expand((x.shape[0], y.shape[1], x.shape[2])) 342*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[0] == 5) 343*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[1] == 4) 344*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[2] == 3) 345*da0073e9SAndroid Build Coastguard Worker 346*da0073e9SAndroid Build Coastguard Worker # mixed python symints and ints 347*da0073e9SAndroid Build Coastguard Worker z = y.expand(x.shape[0], y.shape[1], 3) 348*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(z.shape[0] == 5) 349*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[1] == 4) 350*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[2] == 3) 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker # mixed python symints and ints in a list 353*da0073e9SAndroid Build Coastguard Worker z = y.expand((x.shape[0], y.shape[1], 3)) 354*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[0] == 5) 355*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[1] == 4) 356*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[2] == 3) 357*da0073e9SAndroid Build Coastguard Worker 358*da0073e9SAndroid Build Coastguard Worker # mixed python symints and ints 359*da0073e9SAndroid Build Coastguard Worker z = y.expand(5, y.shape[1], x.shape[2]) 360*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[0] == 5) 361*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[1] == 4) 362*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[2] == 3) 363*da0073e9SAndroid Build Coastguard Worker 364*da0073e9SAndroid Build Coastguard Worker # mixed python ints and symints in a list 365*da0073e9SAndroid Build Coastguard Worker z = y.expand((5, y.shape[1], x.shape[2])) 366*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[0] == 5) 367*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[1] == 4) 368*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.shape[2] == 3) 369*da0073e9SAndroid Build Coastguard Worker 370*da0073e9SAndroid Build Coastguard Worker z = y.expand((y.shape[1],)) 371*da0073e9SAndroid Build Coastguard Worker z = y.expand(y.shape[1]) 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Worker def test_stride(self): 374*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 375*da0073e9SAndroid Build Coastguard Worker x = create_symbolic_tensor("x", torch.randn(5, 
5), shape_env) 376*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(x.stride()[0], SymInt) 377*da0073e9SAndroid Build Coastguard Worker 378*da0073e9SAndroid Build Coastguard Worker def test_size_expressions(self): 379*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 380*da0073e9SAndroid Build Coastguard Worker x = create_symbolic_tensor("x", torch.randn(5), shape_env) 381*da0073e9SAndroid Build Coastguard Worker expand_x = x.expand(x.shape[0], x.shape[0]) 382*da0073e9SAndroid Build Coastguard Worker if expand_x.shape[0] > 3: 383*da0073e9SAndroid Build Coastguard Worker result = expand_x + expand_x 384*da0073e9SAndroid Build Coastguard Worker else: 385*da0073e9SAndroid Build Coastguard Worker result = expand_x + expand_x 386*da0073e9SAndroid Build Coastguard Worker 387*da0073e9SAndroid Build Coastguard Worker gt_op, _bt = shape_env.guards[-1] 388*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan)) 389*da0073e9SAndroid Build Coastguard Worker self.assertTrue(str(x.shape[0]), str(gt_op.args[0])) 390*da0073e9SAndroid Build Coastguard Worker self.assertTrue(str(expand_x.shape[1]), str(x.shape[0])) 391*da0073e9SAndroid Build Coastguard Worker self.assertTrue(str(expand_x.shape[1]), str(result.shape[0])) 392*da0073e9SAndroid Build Coastguard Worker 393*da0073e9SAndroid Build Coastguard Worker def test_floordiv_static(self): 394*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 395*da0073e9SAndroid Build Coastguard Worker s0 = create_symint(shape_env, 8) 396*da0073e9SAndroid Build Coastguard Worker # This was extracted from 397*da0073e9SAndroid Build Coastguard Worker # python test/inductor/test_cuda_cpp_wrapper.py -k 398*da0073e9SAndroid Build Coastguard Worker # DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper 399*da0073e9SAndroid Build Coastguard Worker bool(s0 % 2 == 0) 400*da0073e9SAndroid Build Coastguard Worker 
bool(s0 % (s0 // 2) == 0) 401*da0073e9SAndroid Build Coastguard Worker bool(2 * (s0 // 2) == s0) 402*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(s0 // (s0 // 2) == 2)) 403*da0073e9SAndroid Build Coastguard Worker 404*da0073e9SAndroid Build Coastguard Worker def test_numel(self): 405*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 406*da0073e9SAndroid Build Coastguard Worker x = create_symbolic_tensor("x", torch.randn(5), shape_env) 407*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(x.numel(), torch.SymInt) 408*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(torch.numel(x), torch.SymInt) 409*da0073e9SAndroid Build Coastguard Worker 410*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 3) 411*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(x.numel(), int) 412*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(torch.numel(x), int) 413*da0073e9SAndroid Build Coastguard Worker 414*da0073e9SAndroid Build Coastguard Worker def test_int_to_float(self): 415*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 416*da0073e9SAndroid Build Coastguard Worker x = create_symbolic_tensor("x", torch.randn(5), shape_env) 417*da0073e9SAndroid Build Coastguard Worker r = torch.sym_float(x.shape[0]) 418*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(r, torch.SymFloat, msg=type(r)) 419*da0073e9SAndroid Build Coastguard Worker 420*da0073e9SAndroid Build Coastguard Worker def test_aten_ops(self): 421*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 422*da0073e9SAndroid Build Coastguard Worker x = create_symbolic_tensor("x", torch.randn(5), shape_env) 423*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0]) 424*da0073e9SAndroid Build Coastguard Worker 425*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 426*da0073e9SAndroid Build Coastguard Worker x = 
create_symbolic_tensor("x2", torch.randn(5, 4, 3), shape_env) 427*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]]) 428*da0073e9SAndroid Build Coastguard Worker 429*da0073e9SAndroid Build Coastguard Worker def test_fx_trace_intlist(self): 430*da0073e9SAndroid Build Coastguard Worker class CustomModule(torch.nn.Module): 431*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 432*da0073e9SAndroid Build Coastguard Worker bs, c, h, w = x.shape 433*da0073e9SAndroid Build Coastguard Worker return F.pad(x, (0, w % 2, 0, h % 2, 0, 0)) 434*da0073e9SAndroid Build Coastguard Worker 435*da0073e9SAndroid Build Coastguard Worker m = CustomModule() 436*da0073e9SAndroid Build Coastguard Worker x = torch.rand(1, 3, 4, 4) 437*da0073e9SAndroid Build Coastguard Worker # should not TypeError: pad(): argument 'pad' (position 2) must be 438*da0073e9SAndroid Build Coastguard Worker # tuple of ints, not tuple 439*da0073e9SAndroid Build Coastguard Worker torch.fx.symbolic_trace(m) 440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker def test_meta_symint(self): 442*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 443*da0073e9SAndroid Build Coastguard Worker a0 = create_symint(shape_env, 2) 444*da0073e9SAndroid Build Coastguard Worker r = torch.empty(a0, device="meta") 445*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(r.shape[0], SymInt) 446*da0073e9SAndroid Build Coastguard Worker 447*da0073e9SAndroid Build Coastguard Worker def test_guard_int(self): 448*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 449*da0073e9SAndroid Build Coastguard Worker a0 = create_symint(shape_env, 2) 450*da0073e9SAndroid Build Coastguard Worker self.assertEqual(guard_int(a0), 2) 451*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""") 452*da0073e9SAndroid Build Coastguard Worker 453*da0073e9SAndroid 
def test_prefer_deferred_runtime_assertions_over_guards(self):
    # Both halves construct the env with
    # prefer_deferred_runtime_asserts_over_guards=True.
    # Half 1: guard_int is expected to show up in shape_env.guards.
    guard_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
    guarded = create_symint(guard_env, 2)
    self.assertEqual(guard_int(guarded), 2)
    recorded_guard = guard_env.guards[0][0]
    self.assertExpectedInline(str(recorded_guard), """Eq(s0, 2)""")

    # Half 2: expect_true is expected to leave guards empty and to file the
    # condition under deferred_runtime_asserts[None] instead.
    deferred_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
    deferred = create_symint(deferred_env, 2)
    self.assertTrue(expect_true(deferred == 2))
    self.assertEqual(len(deferred_env.guards), 0)
    pending = [entry.expr for entry in deferred_env.deferred_runtime_asserts[None]]
    self.assertExpectedInline(
        str(pending),
        """[Eq(s0, 2)]""",
    )

def test_sym_int(self):
    # The three conversions share one ShapeEnv, so their guards are read
    # back from indices 0, 1 and 2 in the order they are exercised here.
    env = ShapeEnv()

    # Plain symbolic int.
    five = create_symint(env, 5)
    converted = sym_int(five)
    self.assertEqual(converted, 5)
    self.assertIsInstance(converted, torch.SymInt, msg=type(converted))
    self.assertExpectedInline(str(env.guards[0][0]), """Eq(s0, 5)""")

    # True division of a symbolic int (7 / 2), checked against 3.
    seven = create_symint(env, 7)
    halved = seven / 2
    converted = sym_int(halved)
    self.assertEqual(guard_int(converted), 3)
    self.assertIsInstance(converted, torch.SymInt, msg=type(converted))
    self.assertExpectedInline(
        str(env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)"""
    )

    # Python float times a sym_float (2.0 * 3), checked against 6.
    three = create_symint(env, 3)
    doubled = 2.0 * torch.sym_float(three)
    converted = sym_int(doubled)
    self.assertEqual(guard_int(converted), 6)
    self.assertIsInstance(converted, torch.SymInt, msg=type(converted))
    self.assertExpectedInline(
        str(env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)"""
    )

def test_sym_sqrt(self):
    # torch._sym_sqrt of a symbolic 4 must compare equal to 2 and be a SymFloat.
    env = ShapeEnv()
    four = create_symint(env, 4)
    root = torch._sym_sqrt(four)
    self.assertEqual(root, 2)
    self.assertIsInstance(root, torch.SymFloat, msg=type(root))
    self.assertExpectedInline(
        str(env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2.0)"""
    )

def test_sym_floor(self):
    # math.floor on symbolic values must yield a SymInt and guard on the
    # FloorToInt expression.
    env = ShapeEnv()
    five = create_symint(env, 5)

    # floor(5 / 2) == 2
    floored = math.floor(five / 2)
    self.assertEqual(floored, 2)
    self.assertIsInstance(floored, torch.SymInt, msg=type(floored))
    self.assertExpectedInline(
        str(env.guards[0][0]),
        """Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""",
    )

    # floor(3.0 * 5) == 15
    floored = math.floor(3.0 * five)
    self.assertEqual(floored, 15)
    self.assertIsInstance(floored, torch.SymInt, msg=type(floored))
    self.assertExpectedInline(
        str(env.guards[1][0]),
        """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
    )

def test_sym_trunc(self):
    # Truncation via math.trunc and via torch.sym_int must both produce a
    # SymInt guarded through TruncToInt.
    env = ShapeEnv()
    five = create_symint(env, 5)

    # trunc(5 / 2) == 2
    truncated = math.trunc(five / 2)
    self.assertEqual(truncated, 2)
    self.assertIsInstance(truncated, torch.SymInt, msg=type(truncated))
    self.assertExpectedInline(
        str(env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)"""
    )

    # sym_int(sqrt(5)) == 2
    root = torch.sym_sqrt(five)
    truncated = torch.sym_int(root)
    self.assertEqual(truncated, 2)
    self.assertIsInstance(truncated, torch.SymInt, msg=type(truncated))
    self.assertExpectedInline(
        str(env.guards[1][0]), """Eq(TruncToInt(OpaqueUnaryFn_sqrt(s0)), 2)"""
    )

def test_sym_ceil(self):
    # math.ceil on a symbolic true division must yield a SymInt guarded
    # through CeilToInt.
    env = ShapeEnv()
    five = create_symint(env, 5)

    # ceil(5 / 2) == 3
    rounded = math.ceil(five / 2)
    self.assertEqual(rounded, 3)
    self.assertIsInstance(rounded, torch.SymInt, msg=type(rounded))
    self.assertExpectedInline(
        str(env.guards[0][0]),
        """Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""",
    )

    # NOTE(review): the second check goes through math.floor (not ceil) and
    # expects a FloorToInt guard.
    scaled = 3.0 * five
    rounded = math.floor(scaled)
    self.assertEqual(rounded, 15)
    self.assertIsInstance(rounded, torch.SymInt, msg=type(rounded))
    self.assertExpectedInline(
        str(env.guards[1][0]),
        """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
    )

def test_sym_ite(self):
    env = ShapeEnv()
    then_val = create_symint(env, 5)
    else_val = create_symint(env, 4)

    # With a plain Python bool the chosen branch object itself is returned.
    picked_true = torch.sym_ite(True, then_val, else_val)
    self.assertIs(picked_true, then_val)
    picked_false = torch.sym_ite(False, then_val, else_val)
    self.assertIs(picked_false, else_val)

    # Symbolic condition that holds for the hint (s0 == 5): no guard exists
    # until the result is compared against a concrete value.
    cond_holds = then_val == 5
    picked = torch.sym_ite(cond_holds, then_val, else_val)
    self.assertEqual(len(env.guards), 0)
    self.assertEqual(picked, 5)
    self.assertEqual(type(then_val), type(picked))
    self.assertExpectedInline(
        str(env.guards[0][0]),
        """Eq(Piecewise((s0, Eq(s0, 5)), (s1, True)), 5)""",
    )

    # Symbolic condition that fails for the hint (s1 == 5): still only the
    # one earlier guard until the result is compared.
    cond_fails = else_val == 5
    picked = torch.sym_ite(cond_fails, then_val, else_val)
    self.assertEqual(len(env.guards), 1)
    self.assertEqual(picked, 4)
    self.assertEqual(type(else_val), type(picked))
    self.assertExpectedInline(
        str(env.guards[1][0]),
        """Eq(Piecewise((s0, Eq(s1, 5)), (s1, True)), 4)""",
    )

def test_tracing_sym_ite(self):
    # The traced function keeps its original name/parameter (`f`, `x`) because
    # the generated code below names its placeholder after the parameter.
    def f(x):
        first_dim_is_five = x.shape[0] == 5
        chosen = torch.sym_ite(first_dim_is_five, x.shape[0], x.shape[1])
        return chosen

    traced = make_fx(f, tracing_mode="symbolic")(torch.ones(4, 5))
    self.assertEqual(len(traced.shape_env.guards), 0)
    # NOTE(review): whitespace inside the expected string (4-space indent,
    # two spaces after each ';') was reconstructed — confirm against the
    # actual FX codegen output.
    generated = traced.code.strip()
    self.assertExpectedInline(
        generated,
        """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    eq = sym_size_int == 5
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1);  x_1 = None
    sym_ite = torch.sym_ite(eq, sym_size_int, sym_size_int_1);  eq = sym_size_int = sym_size_int_1 = None
    return sym_ite""",
    )

    # shape (4, 5): first dim is not 5, so the second dim (5) is returned.
    out_else = traced(torch.ones(4, 5))
    self.assertIsInstance(out_else, int)
    self.assertEqual(out_else, 5)

    # shape (5, 4): first dim is 5, so the first dim (5) is returned.
    out_then = traced(torch.ones(5, 4))
    self.assertIsInstance(out_then, int)
    self.assertEqual(out_then, 5)

def test_int_conversion(self):
    # Calling int() on a symbolic int is expected to record an equality guard.
    env = ShapeEnv()
    two = create_symint(env, 2)
    int(two)
    first_guard = env.guards[0][0]
    self.assertExpectedInline(str(first_guard), """Eq(s0, 2)""")

615*da0073e9SAndroid Build Coastguard Worker def test_data_dependent_guard(self): 616*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 617*da0073e9SAndroid Build Coastguard Worker s0 = shape_env.create_unbacked_symint() 618*da0073e9SAndroid Build Coastguard Worker self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0)) 619*da0073e9SAndroid Build Coastguard Worker 620*da0073e9SAndroid Build Coastguard Worker def test_data_dependent_guard_propagate_real_tensors(self): 621*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 622*da0073e9SAndroid Build Coastguard Worker s0 = shape_env.create_unbacked_symint() 623*da0073e9SAndroid Build Coastguard Worker shape_env.set_unbacked_var_to_val(s0.node.expr, 0) 624*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bool(s0 == 0), True) 625*da0073e9SAndroid Build Coastguard Worker 626*da0073e9SAndroid Build Coastguard Worker def test_expect_true_basic(self): 627*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 628*da0073e9SAndroid Build Coastguard Worker i0 = shape_env.create_unbacked_symint() 629*da0073e9SAndroid Build Coastguard Worker i0_sym = i0.node.expr 630*da0073e9SAndroid Build Coastguard Worker # This doesn't error 631*da0073e9SAndroid Build Coastguard Worker self.assertTrue(expect_true(i0 == 0)) 632*da0073e9SAndroid Build Coastguard Worker # This generates a deferred runtime assert via replacement 633*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape_env.replacements[i0_sym], 0) 634*da0073e9SAndroid Build Coastguard Worker # After expecting true, guards now resolve given the runtime assert 635*da0073e9SAndroid Build Coastguard Worker bool(i0 == 0) 636*da0073e9SAndroid Build Coastguard Worker 637*da0073e9SAndroid Build Coastguard Worker def test_expect_true_with_s0(self): 638*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 639*da0073e9SAndroid Build Coastguard Worker s0 = create_symint(shape_env, 5) 640*da0073e9SAndroid Build 
Coastguard Worker i0 = shape_env.create_unbacked_symint() 641*da0073e9SAndroid Build Coastguard Worker self.assertTrue(expect_true(i0 < s0)) 642*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 643*da0073e9SAndroid Build Coastguard Worker str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]), 644*da0073e9SAndroid Build Coastguard Worker """[u0 < s0]""", 645*da0073e9SAndroid Build Coastguard Worker ) 646*da0073e9SAndroid Build Coastguard Worker self.assertTrue(i0 < s0) 647*da0073e9SAndroid Build Coastguard Worker self.assertTrue(i0 != s0) 648*da0073e9SAndroid Build Coastguard Worker self.assertFalse(i0 > s0) 649*da0073e9SAndroid Build Coastguard Worker self.assertFalse(i0 >= s0) 650*da0073e9SAndroid Build Coastguard Worker 651*da0073e9SAndroid Build Coastguard Worker def test_expect_true_prefer_later(self): 652*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 653*da0073e9SAndroid Build Coastguard Worker i0 = shape_env.create_unbacked_symint() 654*da0073e9SAndroid Build Coastguard Worker i1 = shape_env.create_unbacked_symint() 655*da0073e9SAndroid Build Coastguard Worker i1_sym = i1.node.expr 656*da0073e9SAndroid Build Coastguard Worker self.assertTrue(expect_true(i0 + i1 == 10)) 657*da0073e9SAndroid Build Coastguard Worker # Importantly, this is put in i1, not i0! 
658*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 659*da0073e9SAndroid Build Coastguard Worker str([ra.expr for ra in shape_env.deferred_runtime_asserts[i1_sym]]), 660*da0073e9SAndroid Build Coastguard Worker """[Eq(u0 + u1, 10)]""", 661*da0073e9SAndroid Build Coastguard Worker ) 662*da0073e9SAndroid Build Coastguard Worker self.assertTrue(i0 + i1 == 10) 663*da0073e9SAndroid Build Coastguard Worker # NB: We currently don't support deriving that we can substitute 664*da0073e9SAndroid Build Coastguard Worker # i0 + i1 with 10; maybe we should, but this means our rewriting 665*da0073e9SAndroid Build Coastguard Worker # system is no longer confluent (it's probably OK though, because 666*da0073e9SAndroid Build Coastguard Worker # you're unlikely to get other equalities like this on the 667*da0073e9SAndroid Build Coastguard Worker # unbacked SymInts.) 668*da0073e9SAndroid Build Coastguard Worker 669*da0073e9SAndroid Build Coastguard Worker def test_unbacked_substitution(self): 670*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 671*da0073e9SAndroid Build Coastguard Worker i0 = shape_env.create_unbacked_symint() 672*da0073e9SAndroid Build Coastguard Worker i1 = shape_env.create_unbacked_symint() 673*da0073e9SAndroid Build Coastguard Worker _constrain_range_for_size(i0) 674*da0073e9SAndroid Build Coastguard Worker _constrain_range_for_size(i1) 675*da0073e9SAndroid Build Coastguard Worker self.assertTrue(expect_true(i0 == i1 * 4)) 676*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(i0), """u0""") 677*da0073e9SAndroid Build Coastguard Worker 678*da0073e9SAndroid Build Coastguard Worker i2 = shape_env.create_unbacked_symint() 679*da0073e9SAndroid Build Coastguard Worker i3 = shape_env.create_unbacked_symint() 680*da0073e9SAndroid Build Coastguard Worker _constrain_range_for_size(i2) 681*da0073e9SAndroid Build Coastguard Worker _constrain_range_for_size(i3) 682*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(expect_true(i2 * 4 == i3)) 683*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(i3), """u3""") 684*da0073e9SAndroid Build Coastguard Worker 685*da0073e9SAndroid Build Coastguard Worker def test_avoid_unbacked_substitution(self): 686*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 687*da0073e9SAndroid Build Coastguard Worker i0 = shape_env.create_unbacked_symint() 688*da0073e9SAndroid Build Coastguard Worker _constrain_range_for_size(i0) 689*da0073e9SAndroid Build Coastguard Worker i1 = shape_env.create_unbacked_symint() 690*da0073e9SAndroid Build Coastguard Worker _constrain_range_for_size(i1) 691*da0073e9SAndroid Build Coastguard Worker self.assertTrue(expect_true(i0 == 10 - i1)) 692*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(i0), """u0""") 693*da0073e9SAndroid Build Coastguard Worker 694*da0073e9SAndroid Build Coastguard Worker def test_expect_true_double_digits(self): 695*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 696*da0073e9SAndroid Build Coastguard Worker ia = [shape_env.create_unbacked_symint() for _ in range(11)] # allocate 10 697*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(ia[-1]), "u10") 698*da0073e9SAndroid Build Coastguard Worker self.assertTrue(expect_true(sum(ia) == 20)) 699*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(shape_env.deferred_runtime_asserts[ia[-1].node.expr]), 1) 700*da0073e9SAndroid Build Coastguard Worker 701*da0073e9SAndroid Build Coastguard Worker def test_expect_true_refine_range(self): 702*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 703*da0073e9SAndroid Build Coastguard Worker for i, rel in enumerate( 704*da0073e9SAndroid Build Coastguard Worker [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x] 705*da0073e9SAndroid Build Coastguard Worker ): 706*da0073e9SAndroid Build Coastguard Worker with self.subTest(f"i = {i}"): 707*da0073e9SAndroid Build Coastguard 
Worker i0 = shape_env.create_unbacked_symint() 708*da0073e9SAndroid Build Coastguard Worker self.assertTrue(expect_true(rel(i0))) 709*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 != 3)) 710*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 != 4)) 711*da0073e9SAndroid Build Coastguard Worker self.assertFalse(statically_known_true(i0 != 5)) 712*da0073e9SAndroid Build Coastguard Worker self.assertFalse(statically_known_true(i0 != 6)) 713*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 > 4)) 714*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 >= 5)) 715*da0073e9SAndroid Build Coastguard Worker 716*da0073e9SAndroid Build Coastguard Worker for i, rel in enumerate( 717*da0073e9SAndroid Build Coastguard Worker [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x] 718*da0073e9SAndroid Build Coastguard Worker ): 719*da0073e9SAndroid Build Coastguard Worker with self.subTest(f"i = {i}"): 720*da0073e9SAndroid Build Coastguard Worker i0 = shape_env.create_unbacked_symint() 721*da0073e9SAndroid Build Coastguard Worker self.assertTrue(expect_true(rel(i0))) 722*da0073e9SAndroid Build Coastguard Worker self.assertFalse(statically_known_true(i0 != 2)) 723*da0073e9SAndroid Build Coastguard Worker self.assertFalse(statically_known_true(i0 != 3)) 724*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 != 4)) 725*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 != 5)) 726*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 < 4)) 727*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 <= 5)) 728*da0073e9SAndroid Build Coastguard Worker 729*da0073e9SAndroid Build Coastguard Worker def test_guard_refine_range(self): 730*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 731*da0073e9SAndroid Build 
Coastguard Worker for i, rel in enumerate( 732*da0073e9SAndroid Build Coastguard Worker [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x] 733*da0073e9SAndroid Build Coastguard Worker ): 734*da0073e9SAndroid Build Coastguard Worker with self.subTest(f"i = {i}"): 735*da0073e9SAndroid Build Coastguard Worker i0 = create_symint(shape_env, 10, duck=False) 736*da0073e9SAndroid Build Coastguard Worker self.assertTrue(bool(rel(i0))) 737*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 != 3)) 738*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 != 4)) 739*da0073e9SAndroid Build Coastguard Worker self.assertFalse(statically_known_true(i0 != 5)) 740*da0073e9SAndroid Build Coastguard Worker self.assertFalse(statically_known_true(i0 != 6)) 741*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 > 4)) 742*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 >= 5)) 743*da0073e9SAndroid Build Coastguard Worker 744*da0073e9SAndroid Build Coastguard Worker for i, rel in enumerate( 745*da0073e9SAndroid Build Coastguard Worker [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x] 746*da0073e9SAndroid Build Coastguard Worker ): 747*da0073e9SAndroid Build Coastguard Worker with self.subTest(f"i = {i}"): 748*da0073e9SAndroid Build Coastguard Worker i0 = create_symint(shape_env, 2, duck=False) 749*da0073e9SAndroid Build Coastguard Worker self.assertFalse(bool(rel(i0))) 750*da0073e9SAndroid Build Coastguard Worker self.assertFalse(statically_known_true(i0 != 3)) 751*da0073e9SAndroid Build Coastguard Worker self.assertFalse(statically_known_true(i0 != 4)) 752*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 != 5)) 753*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 != 6)) 754*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 <= 4)) 
755*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 < 5)) 756*da0073e9SAndroid Build Coastguard Worker 757*da0073e9SAndroid Build Coastguard Worker for i, rel in enumerate( 758*da0073e9SAndroid Build Coastguard Worker [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x] 759*da0073e9SAndroid Build Coastguard Worker ): 760*da0073e9SAndroid Build Coastguard Worker with self.subTest(f"i = {i}"): 761*da0073e9SAndroid Build Coastguard Worker i0 = create_symint(shape_env, 2, duck=False) 762*da0073e9SAndroid Build Coastguard Worker self.assertTrue(bool(rel(i0))) 763*da0073e9SAndroid Build Coastguard Worker self.assertFalse(statically_known_true(i0 != 2)) 764*da0073e9SAndroid Build Coastguard Worker self.assertFalse(statically_known_true(i0 != 3)) 765*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 != 4)) 766*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 != 5)) 767*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 < 4)) 768*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 <= 3)) 769*da0073e9SAndroid Build Coastguard Worker 770*da0073e9SAndroid Build Coastguard Worker for i, rel in enumerate( 771*da0073e9SAndroid Build Coastguard Worker [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x] 772*da0073e9SAndroid Build Coastguard Worker ): 773*da0073e9SAndroid Build Coastguard Worker with self.subTest(f"i = {i}"): 774*da0073e9SAndroid Build Coastguard Worker i0 = create_symint(shape_env, 10, duck=False) 775*da0073e9SAndroid Build Coastguard Worker self.assertFalse(bool(rel(i0))) 776*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 != 2)) 777*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 != 3)) 778*da0073e9SAndroid Build Coastguard Worker self.assertFalse(statically_known_true(i0 != 4)) 779*da0073e9SAndroid 
Build Coastguard Worker self.assertFalse(statically_known_true(i0 != 5)) 780*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 >= 4)) 781*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(i0 > 3)) 782*da0073e9SAndroid Build Coastguard Worker 783*da0073e9SAndroid Build Coastguard Worker def test_mul_int_oo_nan(self): 784*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 785*da0073e9SAndroid Build Coastguard Worker s0 = create_symint(shape_env, 5, duck=False) 786*da0073e9SAndroid Build Coastguard Worker s1 = create_symint(shape_env, 6, duck=False) 787*da0073e9SAndroid Build Coastguard Worker s2 = create_symint(shape_env, 5, duck=False) 788*da0073e9SAndroid Build Coastguard Worker bool(s0 * (s1 // s0) == s2) 789*da0073e9SAndroid Build Coastguard Worker 790*da0073e9SAndroid Build Coastguard Worker def test_non_overlapping_and_dense(self): 791*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 792*da0073e9SAndroid Build Coastguard Worker a0 = create_symint(shape_env, 5) 793*da0073e9SAndroid Build Coastguard Worker r = torch.empty_strided((a0, 7), (1, a0), device="meta") 794*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r)) 795*da0073e9SAndroid Build Coastguard Worker 796*da0073e9SAndroid Build Coastguard Worker def test_non_overlapping_and_dense_unbacked(self): 797*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 798*da0073e9SAndroid Build Coastguard Worker u0 = shape_env.create_unbacked_symint() 799*da0073e9SAndroid Build Coastguard Worker torch._check_is_size(u0) 800*da0073e9SAndroid Build Coastguard Worker cf = torch.ops.aten.is_non_overlapping_and_dense.default 801*da0073e9SAndroid Build Coastguard Worker 802*da0073e9SAndroid Build Coastguard Worker self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 2, 2, 1), 1) 803*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(IsNonOverlappingAndDenseIndicator(2, u0.node.expr, 1, 2), 1) 804*da0073e9SAndroid Build Coastguard Worker self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta"))) 805*da0073e9SAndroid Build Coastguard Worker self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta"))) 806*da0073e9SAndroid Build Coastguard Worker 807*da0073e9SAndroid Build Coastguard Worker self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 1), 1) 808*da0073e9SAndroid Build Coastguard Worker self.assertEqual(IsNonOverlappingAndDenseIndicator(1, u0.node.expr), 1) 809*da0073e9SAndroid Build Coastguard Worker self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta"))) 810*da0073e9SAndroid Build Coastguard Worker self.assertTrue(cf(torch.empty_strided((1,), (u0,), device="meta"))) 811*da0073e9SAndroid Build Coastguard Worker 812*da0073e9SAndroid Build Coastguard Worker Max = torch.sym_max 813*da0073e9SAndroid Build Coastguard Worker # NB: This only works because we're able to determine this tensor is 814*da0073e9SAndroid Build Coastguard Worker # contiguous. 
transpose(0, 1) makes it stop working 815*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 816*da0073e9SAndroid Build Coastguard Worker cf( 817*da0073e9SAndroid Build Coastguard Worker torch.empty_strided( 818*da0073e9SAndroid Build Coastguard Worker (2, 3, 1, u0), 819*da0073e9SAndroid Build Coastguard Worker (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1), 820*da0073e9SAndroid Build Coastguard Worker device="meta", 821*da0073e9SAndroid Build Coastguard Worker ) 822*da0073e9SAndroid Build Coastguard Worker ) 823*da0073e9SAndroid Build Coastguard Worker ) 824*da0073e9SAndroid Build Coastguard Worker 825*da0073e9SAndroid Build Coastguard Worker def test_numpy_sym_max(self): 826*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sym_max(np.int64(10), 12), 12) 827*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sym_max(np.int64(12), 10), 12) 828*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sym_max(np.int64(10), 12.5), 12.5) 829*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sym_max(np.int64(14), 12.5), 14.0) 830*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sym_max(np.float64(14.0), 12), 14.0) 831*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sym_max(np.float64(14.0), 16), 16.0) 832*da0073e9SAndroid Build Coastguard Worker 833*da0073e9SAndroid Build Coastguard Worker def test_numpy_sym_min(self): 834*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sym_min(np.int64(10), 12), 10) 835*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sym_min(np.int64(12), 10), 10) 836*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sym_min(np.int64(10), 12.5), 10.0) 837*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sym_min(np.int64(14), 12.5), 12.5) 838*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sym_min(np.float64(14.0), 12), 12.0) 839*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(torch.sym_min(np.float64(14.0), 16), 14.0) 840*da0073e9SAndroid Build Coastguard Worker 841*da0073e9SAndroid Build Coastguard Worker def test_debug_has_internal_overlap_unbacked(self): 842*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 843*da0073e9SAndroid Build Coastguard Worker u0 = shape_env.create_unbacked_symint() 844*da0073e9SAndroid Build Coastguard Worker torch._check_is_size(u0) 845*da0073e9SAndroid Build Coastguard Worker cf = torch._debug_has_internal_overlap 846*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0) 847*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0) 848*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cf(torch.empty_strided((u0,), (1,), device="meta")), 0) 849*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 0) 850*da0073e9SAndroid Build Coastguard Worker Max = torch.sym_max 851*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 852*da0073e9SAndroid Build Coastguard Worker cf( 853*da0073e9SAndroid Build Coastguard Worker torch.empty_strided( 854*da0073e9SAndroid Build Coastguard Worker (2, 3, 1, u0), 855*da0073e9SAndroid Build Coastguard Worker (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1), 856*da0073e9SAndroid Build Coastguard Worker device="meta", 857*da0073e9SAndroid Build Coastguard Worker ) 858*da0073e9SAndroid Build Coastguard Worker ), 859*da0073e9SAndroid Build Coastguard Worker 0, 860*da0073e9SAndroid Build Coastguard Worker ) 861*da0073e9SAndroid Build Coastguard Worker 862*da0073e9SAndroid Build Coastguard Worker # Wobbling these to zero is OK too 863*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cf(torch.empty_strided((u0, 2), (3, 1), device="meta")), 2) 864*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cf(torch.empty_strided((2, u0), (1, 3), device="meta")), 
2) 865*da0073e9SAndroid Build Coastguard Worker 866*da0073e9SAndroid Build Coastguard Worker def test_specialize_zero_one(self): 867*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv(specialize_zero_one=True) 868*da0073e9SAndroid Build Coastguard Worker a0 = create_symint(shape_env, 5) 869*da0073e9SAndroid Build Coastguard Worker assert a0 != 1 870*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(shape_env.guards), 0) 871*da0073e9SAndroid Build Coastguard Worker 872*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv(specialize_zero_one=False) 873*da0073e9SAndroid Build Coastguard Worker a0 = create_symint(shape_env, 5) 874*da0073e9SAndroid Build Coastguard Worker assert a0 != 1 875*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(shape_env.guards), 1) 876*da0073e9SAndroid Build Coastguard Worker 877*da0073e9SAndroid Build Coastguard Worker def test_duck_shape(self): 878*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv(duck_shape=True) 879*da0073e9SAndroid Build Coastguard Worker a0 = create_symint(shape_env, 5) 880*da0073e9SAndroid Build Coastguard Worker a1 = create_symint(shape_env, 5) 881*da0073e9SAndroid Build Coastguard Worker assert a0 == a1 882*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(shape_env.guards), 0) 883*da0073e9SAndroid Build Coastguard Worker 884*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv(duck_shape=False) 885*da0073e9SAndroid Build Coastguard Worker a0 = create_symint(shape_env, 5) 886*da0073e9SAndroid Build Coastguard Worker a1 = create_symint(shape_env, 5) 887*da0073e9SAndroid Build Coastguard Worker assert a0 == a1 888*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(shape_env.guards), 1) 889*da0073e9SAndroid Build Coastguard Worker 890*da0073e9SAndroid Build Coastguard Worker def test_int_bool(self): 891*da0073e9SAndroid Build Coastguard Worker # See https://github.com/pytorch/pytorch/issues/95981 892*da0073e9SAndroid Build 
def test_symint_as_scalar(self):
    """A SymInt passed as ``alpha`` must reach ``__torch_dispatch__`` with the same node.

    A dispatch mode intercepts ``torch.add(x, x, alpha=a0)`` and records
    whether the ``alpha`` keyword it receives wraps the very same ``SymNode``
    as ``a0``.
    """
    env = ShapeEnv()
    a0 = create_symint(env, 2)

    class AlphaProbe(TorchDispatchMode):
        # Set by __torch_dispatch__ once the add call has been intercepted.
        same_node_seen = False

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            assert func == torch.ops.aten.add.Tensor

            # WARNING: do not do identity tests on the outer
            # SymInt/SymFloat, they are NOT STABLE
            self.same_node_seen = kwargs["alpha"].node is a0.node
            kwargs["alpha"] = 0
            # Only the positional arguments are forwarded to the real op.
            return func(*args)

    operand = torch.rand([4, 4])
    probe = AlphaProbe()
    with probe:
        y = torch.add(operand, operand, alpha=a0)

    self.assertTrue(probe.same_node_seen)
test_deepcopy(self): 921*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 922*da0073e9SAndroid Build Coastguard Worker a0 = create_symint(shape_env, 2) 923*da0073e9SAndroid Build Coastguard Worker assert a0 < 4 924*da0073e9SAndroid Build Coastguard Worker new_shape_env = copy.deepcopy(shape_env) 925*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(new_shape_env.guards), 1) 926*da0073e9SAndroid Build Coastguard Worker 927*da0073e9SAndroid Build Coastguard Worker def test_print_readable_with_symints(self): 928*da0073e9SAndroid Build Coastguard Worker def f(a, b): 929*da0073e9SAndroid Build Coastguard Worker dim0 = a.shape[0] + b.shape[0] 930*da0073e9SAndroid Build Coastguard Worker dim1 = a.shape[1] + b.shape[1] 931*da0073e9SAndroid Build Coastguard Worker d = a.new_empty(dim0, dim1) 932*da0073e9SAndroid Build Coastguard Worker d = torch.ops.aten.native_dropout(d, 0.5, train=True) 933*da0073e9SAndroid Build Coastguard Worker return d 934*da0073e9SAndroid Build Coastguard Worker 935*da0073e9SAndroid Build Coastguard Worker fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3)) 936*da0073e9SAndroid Build Coastguard Worker out = fx_g.print_readable(print_output=False) 937*da0073e9SAndroid Build Coastguard Worker 938*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 939*da0073e9SAndroid Build Coastguard Worker out.strip(), 940*da0073e9SAndroid Build Coastguard Worker """\ 941*da0073e9SAndroid Build Coastguard Workerclass f(torch.nn.Module): 942*da0073e9SAndroid Build Coastguard Worker def forward(self, a_1: "f32[s0, s1]", b_1: "f32[s2, s1]"): 943*da0073e9SAndroid Build Coastguard Worker # No stacktrace found for following nodes 944*da0073e9SAndroid Build Coastguard Worker sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(a_1, 0) 945*da0073e9SAndroid Build Coastguard Worker sym_size_int_1: "Sym(s2)" = torch.ops.aten.sym_size.int(b_1, 0) 946*da0073e9SAndroid Build Coastguard Worker add: "Sym(s0 + 
s2)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None 947*da0073e9SAndroid Build Coastguard Worker sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(a_1, 1) 948*da0073e9SAndroid Build Coastguard Worker sym_size_int_3: "Sym(s1)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None 949*da0073e9SAndroid Build Coastguard Worker add_1: "Sym(2*s1)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None 950*da0073e9SAndroid Build Coastguard Worker new_empty: "f32[s0 + s2, 2*s1]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None 951*da0073e9SAndroid Build Coastguard Worker native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None 952*da0073e9SAndroid Build Coastguard Worker getitem: "f32[s0 + s2, 2*s1]" = native_dropout[0] 953*da0073e9SAndroid Build Coastguard Worker getitem_1: "b8[s0 + s2, 2*s1]" = native_dropout[1]; native_dropout = None 954*da0073e9SAndroid Build Coastguard Worker return (getitem, getitem_1)""", # noqa: B950 955*da0073e9SAndroid Build Coastguard Worker ) 956*da0073e9SAndroid Build Coastguard Worker 957*da0073e9SAndroid Build Coastguard Worker def test_statically_known_true(self): 958*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 959*da0073e9SAndroid Build Coastguard Worker s2, s3, s4 = (create_symint(shape_env, i) for i in range(2, 5)) 960*da0073e9SAndroid Build Coastguard Worker 961*da0073e9SAndroid Build Coastguard Worker # Statically known true 962*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(True)) 963*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(s2 == s2)) 964*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(s2 * s3 > s3)) 965*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true(s3 * s4 > s4)) 966*da0073e9SAndroid Build Coastguard Worker self.assertTrue(statically_known_true((s3 + 
def test_ephemeral_source_simplification(self):
    """Symbols with an ephemeral source must disappear after equality checks.

    Two symbolic tensors of the same shape are created, exactly one of them
    with an ``EphemeralSource``. After ``torch._check`` asserts that sizes,
    strides and storage offsets match, the tensor built from the ephemeral
    source must no longer mention any ephemeral-source symbol.
    """
    from torch._dynamo.source import EphemeralSource

    # For full robustness, ensure the ephemeral source symbols are simplified out regardless
    # of construction order or check order.
    for construct_ephemeral_first, x_first_in_check in itertools.product(
        (False, True), repeat=2
    ):
        shape_env = ShapeEnv()
        shape = (5, 10)
        dynamic_dims = [DimDynamic.DYNAMIC] * len(shape)

        x_source = EphemeralSource() if construct_ephemeral_first else None
        x = create_symbolic_tensor(
            "x",
            torch.randn(*shape),
            shape_env,
            source=x_source,
            dynamic_dims=dynamic_dims,
        )
        y_source = EphemeralSource() if not construct_ephemeral_first else None
        y = create_symbolic_tensor(
            "y",
            torch.randn(*shape),
            shape_env,
            source=y_source,
            dynamic_dims=dynamic_dims,
        )
        t_with_ephemeral = x if construct_ephemeral_first else y

        def _get_ephemeral_source_symbols(t):
            # Collect the expressions of all SymInts in t's metadata that
            # still have at least one ephemeral source registered.
            found = []
            metadata = itertools.chain(t.shape, t.stride(), (t.storage_offset(),))
            for s in metadata:
                if not isinstance(s, torch.SymInt):
                    continue
                expr = s.node.expr
                if expr not in shape_env.var_to_sources:
                    continue
                if any(src.is_ephemeral() for src in shape_env.var_to_sources[expr]):
                    found.append(expr)
            return found

        # these checks should simplify out the ephemeral symbols, regardless of the
        # ordering x == y or y == x
        self.assertTrue(len(_get_ephemeral_source_symbols(t_with_ephemeral)) > 0)
        lhs, rhs = (x, y) if x_first_in_check else (y, x)
        torch._check(lhs.size() == rhs.size())
        torch._check(lhs.stride() == rhs.stride())
        torch._check(lhs.storage_offset() == rhs.storage_offset())
        self.assertEqual(len(_get_ephemeral_source_symbols(t_with_ephemeral)), 0)
1043*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 1044*da0073e9SAndroid Build Coastguard Worker shape = (5, 10) 1045*da0073e9SAndroid Build Coastguard Worker # use duck sizing here to ensure symbol reuse across x and y 1046*da0073e9SAndroid Build Coastguard Worker duck_dims = [DimDynamic.DUCK for _ in shape] 1047*da0073e9SAndroid Build Coastguard Worker x = create_symbolic_tensor( 1048*da0073e9SAndroid Build Coastguard Worker "x", 1049*da0073e9SAndroid Build Coastguard Worker torch.randn(*shape), 1050*da0073e9SAndroid Build Coastguard Worker shape_env, 1051*da0073e9SAndroid Build Coastguard Worker source=(EphemeralSource() if construct_ephemeral_first else None), 1052*da0073e9SAndroid Build Coastguard Worker dynamic_dims=duck_dims, 1053*da0073e9SAndroid Build Coastguard Worker ) 1054*da0073e9SAndroid Build Coastguard Worker y = create_symbolic_tensor( 1055*da0073e9SAndroid Build Coastguard Worker "y", 1056*da0073e9SAndroid Build Coastguard Worker torch.randn(*shape), 1057*da0073e9SAndroid Build Coastguard Worker shape_env, 1058*da0073e9SAndroid Build Coastguard Worker source=(EphemeralSource() if not construct_ephemeral_first else None), 1059*da0073e9SAndroid Build Coastguard Worker dynamic_dims=duck_dims, 1060*da0073e9SAndroid Build Coastguard Worker ) 1061*da0073e9SAndroid Build Coastguard Worker 1062*da0073e9SAndroid Build Coastguard Worker # regardless of construction order, non-ephemeral sources should be preferred 1063*da0073e9SAndroid Build Coastguard Worker # first in the var_to_sources list for potential guarding later on 1064*da0073e9SAndroid Build Coastguard Worker for source_list in shape_env.var_to_sources.values(): 1065*da0073e9SAndroid Build Coastguard Worker self.assertFalse(source_list[0].is_ephemeral()) 1066*da0073e9SAndroid Build Coastguard Worker 1067*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.size(), y.size()) 1068*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.stride(), y.stride()) 
def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn):
    """Compare a magic method on symbolic inputs against plain Python numbers.

    The operator for ``fn`` is first applied to the concrete inputs to get a
    reference result. It is then applied with the first argument symbolic,
    with the second argument symbolic, and with both symbolic; each result is
    guarded back to a concrete value and compared with the reference.

    Args:
        fn: magic-method name understood by ``method_to_operator``.
        inp1: concrete first operand (bool, int or float).
        inp2: concrete second operand; unused when ``is_unary_fn`` is true.
        shape_env: ``ShapeEnv`` used to create the seed symbols.
        is_unary_fn: whether ``fn`` takes a single operand.
    """
    # Helper function
    # NB: don't use one as that will get specialized
    # TODO: We don't have to circuitously create the float, can just
    # create a symfloat directly
    seed_node = (create_symint(shape_env, 2) / 2.0).node
    bool_seed_node = (create_symint(shape_env, 2) == 2).node

    def get_sym_inp(inp):
        # NB: this must come before int
        if isinstance(inp, bool):
            return torch.SymBool(to_node(bool_seed_node, inp))
        elif isinstance(inp, int):
            return torch.SymInt(to_node(seed_node, inp))
        else:
            return torch.SymFloat(to_node(seed_node, inp))

    # Input combinations that are not exercised at all.
    if fn == "float_pow":
        if inp1 < 0:
            return

    if fn == "pow_by_natural":
        if isinstance(inp1, float) or isinstance(inp2, float):
            return
        if inp2 < 0:
            return

    def maybe_xfail(a, b):
        # Return the context manager matching the error (if any) that
        # applying `fn` to (a, b) is expected to raise.
        if fn == "sym_sqrt" and a < 0:
            # ValueError: math domain error
            return self.assertRaises((ValueError,))
        elif (
            fn in ("float_truediv", "int_truediv", "int_floordiv", "mod")
            and b == 0
        ):
            # ZeroDivisionError: division by zero
            return self.assertRaises((ZeroDivisionError,))
        elif fn in ["float_pow", "pow_by_natural"] and a == 0 and b < 0:
            # ZeroDivisionError: 0.0 cannot be raised to a negative power
            return self.assertRaises((ZeroDivisionError,))
        elif (
            # TODO: dear catastrophe waitress,
            # this doesn't work
            # The type tests used to read `type(x) is (A, B)`, which compares
            # a type with a tuple and is always False. With the early returns
            # above this branch is still not reached by the current inputs.
            fn in ["float_pow", "pow_by_natural"]
            and a < 0
            and (type(a) in (SymInt, SymFloat) or type(b) in (SymInt, SymFloat))
            and (type(a) in (SymFloat, float) or type(b) in (SymFloat, float))
        ):
            # Complex result, which we do not support:
            # TypeError: Cannot convert complex to float
            return self.assertRaises((RuntimeError,))
        elif fn in ("lshift", "rshift") and not (
            isinstance(a, (SymInt, int)) and isinstance(b, (SymInt, int))
        ):
            # TypeError: unsupported operand type(s)
            return self.assertRaises((TypeError,))
        elif fn in ("lshift", "rshift") and b < 0:
            # ValueError: negative shift count
            return self.assertRaises((ValueError,))
        else:
            return contextlib.nullcontext()

    lambda_apply = method_to_operator(fn)

    def guard_fn(v):
        # Exact type checks on purpose: bool is a subclass of int.
        if type(v) in (SymBool, bool):
            return guard_bool(v)
        elif type(v) in (SymFloat, float):
            return guard_float(v)
        else:  # SymInt, int
            return guard_int(v)

    def apply_fn(a, b):
        return lambda_apply(a) if is_unary_fn else lambda_apply(a, b)

    # Get reference result
    with maybe_xfail(inp1, inp2):
        ref_out = apply_fn(inp1, inp2)

    def check_against_reference(a, b):
        # When an error is expected, the call raises inside the `with` and
        # the comparisons below are skipped.
        with maybe_xfail(a, b):
            out = apply_fn(a, b)
            if fn not in sym_node.alternate_impl_if_hinted_methods:
                self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
            out = guard_fn(out)
            self.assertEqual(out, ref_out)

    # Symified first arg
    sym_inp1 = get_sym_inp(inp1)
    check_against_reference(sym_inp1, inp2)

    if is_unary_fn:
        return

    # Symified second arg
    sym_inp2 = get_sym_inp(inp2)
    check_against_reference(inp1, sym_inp2)

    # Symified both args
    check_against_reference(sym_inp1, sym_inp2)
Coastguard Worker shape_env = ShapeEnv() 1201*da0073e9SAndroid Build Coastguard Worker self._do_test(fn, True, False, shape_env, is_unary_fn) 1202*da0073e9SAndroid Build Coastguard Worker 1203*da0073e9SAndroid Build Coastguard Worker @parametrize("fn", list(sym_node.magic_methods.keys())) 1204*da0073e9SAndroid Build Coastguard Worker @parametrize("first_type", ["int", "float"]) 1205*da0073e9SAndroid Build Coastguard Worker @parametrize("second_type", ["int", "float"]) 1206*da0073e9SAndroid Build Coastguard Worker def test_method(self, fn, first_type, second_type): 1207*da0073e9SAndroid Build Coastguard Worker if first_type == "float": 1208*da0073e9SAndroid Build Coastguard Worker # TODO: Hmm, this looks like we skip all floats 1209*da0073e9SAndroid Build Coastguard Worker self.skipTest(f"{fn} is not a float magic method") 1210*da0073e9SAndroid Build Coastguard Worker 1211*da0073e9SAndroid Build Coastguard Worker if ( 1212*da0073e9SAndroid Build Coastguard Worker first_type == "int" or second_type == "int" 1213*da0073e9SAndroid Build Coastguard Worker ) and fn in sym_node.only_float_magic_methods: 1214*da0073e9SAndroid Build Coastguard Worker self.skipTest(f"{fn} is not an int method") 1215*da0073e9SAndroid Build Coastguard Worker 1216*da0073e9SAndroid Build Coastguard Worker if second_type == "float" and fn in ["mod"]: 1217*da0073e9SAndroid Build Coastguard Worker self.skipTest(f"{fn} only handles int") 1218*da0073e9SAndroid Build Coastguard Worker 1219*da0073e9SAndroid Build Coastguard Worker is_unary_fn = fn in sym_node.unary_methods or fn == "round" 1220*da0073e9SAndroid Build Coastguard Worker # Second argument is ignored for unary function. 
So only run for one type 1221*da0073e9SAndroid Build Coastguard Worker if is_unary_fn and second_type == "float": 1222*da0073e9SAndroid Build Coastguard Worker self.skipTest(f"{fn} is unary and already tested") 1223*da0073e9SAndroid Build Coastguard Worker 1224*da0073e9SAndroid Build Coastguard Worker if fn in sym_node.bool_magic_methods: 1225*da0073e9SAndroid Build Coastguard Worker self.skipTest(f"{fn} is bool") 1226*da0073e9SAndroid Build Coastguard Worker 1227*da0073e9SAndroid Build Coastguard Worker # Only floats here since these will be converted to int if necessary. 1228*da0073e9SAndroid Build Coastguard Worker # We also ignore complex and bool. 1229*da0073e9SAndroid Build Coastguard Worker values = ( 1230*da0073e9SAndroid Build Coastguard Worker 0.0, 1231*da0073e9SAndroid Build Coastguard Worker 1.0, 1232*da0073e9SAndroid Build Coastguard Worker 0.5 if fn in ("sym_acos", "sym_asin") else 2.5, # avoid math domain error 1233*da0073e9SAndroid Build Coastguard Worker ) 1234*da0073e9SAndroid Build Coastguard Worker 1235*da0073e9SAndroid Build Coastguard Worker neg_values = tuple(-x for x in values) 1236*da0073e9SAndroid Build Coastguard Worker 1237*da0073e9SAndroid Build Coastguard Worker for inp1, inp2 in itertools.chain( 1238*da0073e9SAndroid Build Coastguard Worker itertools.product(values, values), 1239*da0073e9SAndroid Build Coastguard Worker itertools.product(values, neg_values), 1240*da0073e9SAndroid Build Coastguard Worker itertools.product(neg_values, values), 1241*da0073e9SAndroid Build Coastguard Worker itertools.product(neg_values, neg_values), 1242*da0073e9SAndroid Build Coastguard Worker ): 1243*da0073e9SAndroid Build Coastguard Worker if first_type == "int": 1244*da0073e9SAndroid Build Coastguard Worker inp1 = int(inp1) 1245*da0073e9SAndroid Build Coastguard Worker if second_type == "int": 1246*da0073e9SAndroid Build Coastguard Worker inp2 = int(inp2) 1247*da0073e9SAndroid Build Coastguard Worker 1248*da0073e9SAndroid Build Coastguard Worker 
shape_env = ShapeEnv() 1249*da0073e9SAndroid Build Coastguard Worker 1250*da0073e9SAndroid Build Coastguard Worker self._do_test(fn, inp1, inp2, shape_env, is_unary_fn) 1251*da0073e9SAndroid Build Coastguard Worker 1252*da0073e9SAndroid Build Coastguard Worker def get_constant_bool(self, val): 1253*da0073e9SAndroid Build Coastguard Worker return SymBool(torch._C._get_constant_bool_symnode(val)) 1254*da0073e9SAndroid Build Coastguard Worker 1255*da0073e9SAndroid Build Coastguard Worker @unittest.expectedFailure 1256*da0073e9SAndroid Build Coastguard Worker def test_symint_hashing(self): 1257*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 1258*da0073e9SAndroid Build Coastguard Worker hash(create_symint(shape_env, 3)) 1259*da0073e9SAndroid Build Coastguard Worker 1260*da0073e9SAndroid Build Coastguard Worker def test_symnode_hashing(self): 1261*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 1262*da0073e9SAndroid Build Coastguard Worker 1263*da0073e9SAndroid Build Coastguard Worker # These all trigger specialization when hashed 1264*da0073e9SAndroid Build Coastguard Worker hash(create_symbool(shape_env, True)) 1265*da0073e9SAndroid Build Coastguard Worker # We should be passing in float here, but create_symbol currently 1266*da0073e9SAndroid Build Coastguard Worker # only supports int 1267*da0073e9SAndroid Build Coastguard Worker hash(create_symfloat(shape_env, 3.0)) 1268*da0073e9SAndroid Build Coastguard Worker 1269*da0073e9SAndroid Build Coastguard Worker # NestedInt (SymInt), constant SymBool, SymNode are hashable 1270*da0073e9SAndroid Build Coastguard Worker j1 = torch._C._get_nested_int(1, 1) 1271*da0073e9SAndroid Build Coastguard Worker j1_copy = torch._C._get_nested_int(1, 1) 1272*da0073e9SAndroid Build Coastguard Worker j2 = torch._C._get_nested_int(2, 1) 1273*da0073e9SAndroid Build Coastguard Worker t = self.get_constant_bool(True) 1274*da0073e9SAndroid Build Coastguard Worker t_copy = self.get_constant_bool(True) 
1275*da0073e9SAndroid Build Coastguard Worker f = self.get_constant_bool(False) 1276*da0073e9SAndroid Build Coastguard Worker n = create_symint(shape_env, 3).node 1277*da0073e9SAndroid Build Coastguard Worker m = self.get_constant_bool(True).node 1278*da0073e9SAndroid Build Coastguard Worker 1279*da0073e9SAndroid Build Coastguard Worker self.assertIs(j1 == j1_copy, True) 1280*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hash(j1), hash(j1_copy)) 1281*da0073e9SAndroid Build Coastguard Worker self.assertIs(j1 == j2, False) 1282*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(hash(j1), hash(j2)) 1283*da0073e9SAndroid Build Coastguard Worker self.assertIs(t == t_copy, True) 1284*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hash(t), hash(t_copy)) 1285*da0073e9SAndroid Build Coastguard Worker self.assertIs(t == f, False) 1286*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(hash(t), hash(f)) 1287*da0073e9SAndroid Build Coastguard Worker 1288*da0073e9SAndroid Build Coastguard Worker hash(n) 1289*da0073e9SAndroid Build Coastguard Worker hash(m) 1290*da0073e9SAndroid Build Coastguard Worker 1291*da0073e9SAndroid Build Coastguard Worker def test_symint_deepcopy(self): 1292*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 1293*da0073e9SAndroid Build Coastguard Worker 1294*da0073e9SAndroid Build Coastguard Worker symnodes = (torch._C._get_nested_int(1, 1),) 1295*da0073e9SAndroid Build Coastguard Worker deepcopied_symnodes = copy.deepcopy(symnodes) 1296*da0073e9SAndroid Build Coastguard Worker self.assertEqual(symnodes, deepcopied_symnodes) 1297*da0073e9SAndroid Build Coastguard Worker 1298*da0073e9SAndroid Build Coastguard Worker def test_non_symbolic_symnode(self): 1299*da0073e9SAndroid Build Coastguard Worker j1 = torch._C._get_nested_int(1, 1) 1300*da0073e9SAndroid Build Coastguard Worker j2 = torch._C._get_nested_int(1, 1) 1301*da0073e9SAndroid Build Coastguard Worker j3 = torch._C._get_nested_int(3, 1) 
def test_non_symbolic_symnode(self):
    """Exercise nested ints and constant SymBools, i.e. non-symbolic SymNodes.

    Covers: type identity of a nested int, unsupported arithmetic, equality
    and ordering comparisons, boolean operators mixing constant and symbolic
    SymBools, and ``torch.Size`` comparison when sizes contain nested ints.
    """
    j1 = torch._C._get_nested_int(1, 1)
    j2 = torch._C._get_nested_int(1, 1)
    j3 = torch._C._get_nested_int(3, 1)

    self.assertIsInstance(j1, torch.SymInt)
    self.assertNotIsInstance(j1, int)

    with self.assertRaisesRegex(
        RuntimeError, "add not supported by NestedIntSymNode"
    ):
        j1 + 3

    self.assertFalse(j1 == 3)
    # The comparison itself must raise; wrapping it in an assertion would be
    # dead code because the assertion could never be evaluated.
    with self.assertRaisesRegex(RuntimeError, "indeterminate"):
        3 >= j2

    self.assertIs(j1 == j1, True)
    self.assertIs(j1 == j2, True)
    self.assertIs(j1 == j3, False)
    self.assertIs(j1 != j3, True)
    self.assertIs(j1 != j2, False)

    const_true = self.get_constant_bool(True)
    #
    # Unary
    #
    # op(constant SymBool)
    self.assertIs(const_true.__sym_not__(), False)

    #
    # Binary
    #
    # op(constant SymBool, bool)
    # op(constant SymBool, constant SymBool)
    # op(bool, constant SymBool)
    self.assertIs(operator.and_(const_true, True), True)
    self.assertIs(operator.and_(const_true, const_true), True)
    self.assertIs(operator.and_(True, const_true), True)

    # op(symbolic SymBool, constant SymBool)
    # op(constant SymBool, symbolic SymBool)
    shape_env = ShapeEnv()
    lhs = create_symint(shape_env, 2)
    rhs = create_symint(shape_env, 2)
    symbolic_eq = lhs == rhs  # symbolic SymBool
    other_true = self.get_constant_bool(True)
    symbolic_and_const = operator.and_(symbolic_eq, other_true)
    const_and_symbolic = operator.and_(other_true, symbolic_eq)
    # Mixing in a symbolic operand must keep the result symbolic ...
    self.assertTrue(is_symbolic(symbolic_and_const))
    self.assertTrue(is_symbolic(const_and_symbolic))
    # ... while still evaluating to a plain True once guarded on.
    self.assertIs(symbolic_and_const.node.guard_bool("", 0), True)
    self.assertIs(const_and_symbolic.node.guard_bool("", 0), True)

    # Comparing sizes
    sz1 = torch.Size([j1, j1, j1])
    sz2 = torch.Size([j1, j1, j1])
    self.assertIs(sz1 == sz2, True)

    sz1 = torch.Size([3, j1, 4])
    sz2 = torch.Size([3, j2, 4])
    self.assertIs(sz1 == sz2, True)
    self.assertIs(sz1 != sz2, False)
j2, 4]) 1359*da0073e9SAndroid Build Coastguard Worker self.assertIs(sz1 == sz2, True) 1360*da0073e9SAndroid Build Coastguard Worker self.assertIs(sz1 != sz2, False) 1361*da0073e9SAndroid Build Coastguard Worker 1362*da0073e9SAndroid Build Coastguard Worker def test_stride_symnode(self): 1363*da0073e9SAndroid Build Coastguard Worker from torch._subclasses.fake_tensor import FakeTensorMode 1364*da0073e9SAndroid Build Coastguard Worker 1365*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 1366*da0073e9SAndroid Build Coastguard Worker 1367*da0073e9SAndroid Build Coastguard Worker def _create_symbolic_tensor(x, dynamic_sizes, dynamic_strides): 1368*da0073e9SAndroid Build Coastguard Worker with FakeTensorMode(shape_env=shape_env) as fake_mode: 1369*da0073e9SAndroid Build Coastguard Worker return fake_mode.from_tensor( 1370*da0073e9SAndroid Build Coastguard Worker x, 1371*da0073e9SAndroid Build Coastguard Worker symbolic_context=StatelessSymbolicContext( 1372*da0073e9SAndroid Build Coastguard Worker dynamic_sizes=dynamic_sizes, 1373*da0073e9SAndroid Build Coastguard Worker dynamic_strides=dynamic_strides, 1374*da0073e9SAndroid Build Coastguard Worker ), 1375*da0073e9SAndroid Build Coastguard Worker ) 1376*da0073e9SAndroid Build Coastguard Worker 1377*da0073e9SAndroid Build Coastguard Worker # check everything static 1378*da0073e9SAndroid Build Coastguard Worker t = _create_symbolic_tensor( 1379*da0073e9SAndroid Build Coastguard Worker x=torch.ones(3, 6), 1380*da0073e9SAndroid Build Coastguard Worker dynamic_sizes=[ 1381*da0073e9SAndroid Build Coastguard Worker DimDynamic.STATIC, 1382*da0073e9SAndroid Build Coastguard Worker DimDynamic.STATIC, 1383*da0073e9SAndroid Build Coastguard Worker ], 1384*da0073e9SAndroid Build Coastguard Worker dynamic_strides=[ 1385*da0073e9SAndroid Build Coastguard Worker DimDynamic.INFER_STRIDE, 1386*da0073e9SAndroid Build Coastguard Worker DimDynamic.INFER_STRIDE, 1387*da0073e9SAndroid Build Coastguard Worker ], 
1388*da0073e9SAndroid Build Coastguard Worker ) 1389*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(isinstance(size, int) for size in t.size())) 1390*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(isinstance(stride, int) for stride in t.stride())) 1391*da0073e9SAndroid Build Coastguard Worker 1392*da0073e9SAndroid Build Coastguard Worker # check dynamic size but static dims 1393*da0073e9SAndroid Build Coastguard Worker t = _create_symbolic_tensor( 1394*da0073e9SAndroid Build Coastguard Worker x=torch.ones(3, 6), 1395*da0073e9SAndroid Build Coastguard Worker dynamic_sizes=[ 1396*da0073e9SAndroid Build Coastguard Worker DimDynamic.DYNAMIC, 1397*da0073e9SAndroid Build Coastguard Worker DimDynamic.DYNAMIC, 1398*da0073e9SAndroid Build Coastguard Worker ], 1399*da0073e9SAndroid Build Coastguard Worker dynamic_strides=[ 1400*da0073e9SAndroid Build Coastguard Worker DimDynamic.INFER_STRIDE, 1401*da0073e9SAndroid Build Coastguard Worker DimDynamic.INFER_STRIDE, 1402*da0073e9SAndroid Build Coastguard Worker ], 1403*da0073e9SAndroid Build Coastguard Worker ) 1404*da0073e9SAndroid Build Coastguard Worker # Expect stride to be inferred 1405*da0073e9SAndroid Build Coastguard Worker s0, s1 = t.size() 1406*da0073e9SAndroid Build Coastguard Worker s2, s3 = t.stride() 1407*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s0, torch.SymInt)) 1408*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s1, torch.SymInt)) 1409*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s2, torch.SymInt)) 1410*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s1 == s2) 1411*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s3, 1) 1412*da0073e9SAndroid Build Coastguard Worker 1413*da0073e9SAndroid Build Coastguard Worker # Check dynamic stride but static dims 1414*da0073e9SAndroid Build Coastguard Worker t = _create_symbolic_tensor( 1415*da0073e9SAndroid Build Coastguard Worker x=torch.ones(3, 6), 
1416*da0073e9SAndroid Build Coastguard Worker dynamic_sizes=[ 1417*da0073e9SAndroid Build Coastguard Worker DimDynamic.STATIC, 1418*da0073e9SAndroid Build Coastguard Worker DimDynamic.STATIC, 1419*da0073e9SAndroid Build Coastguard Worker ], 1420*da0073e9SAndroid Build Coastguard Worker dynamic_strides=[ 1421*da0073e9SAndroid Build Coastguard Worker DimDynamic.DYNAMIC, 1422*da0073e9SAndroid Build Coastguard Worker DimDynamic.INFER_STRIDE, 1423*da0073e9SAndroid Build Coastguard Worker ], 1424*da0073e9SAndroid Build Coastguard Worker ) 1425*da0073e9SAndroid Build Coastguard Worker s0, s1 = t.size() 1426*da0073e9SAndroid Build Coastguard Worker s2, s3 = t.stride() 1427*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s0, int)) 1428*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s1, int)) 1429*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s2, torch.SymInt)) 1430*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s3, int)) 1431*da0073e9SAndroid Build Coastguard Worker 1432*da0073e9SAndroid Build Coastguard Worker # Check dynamic sizes and dims, and ensure different symbol 1433*da0073e9SAndroid Build Coastguard Worker t = _create_symbolic_tensor( 1434*da0073e9SAndroid Build Coastguard Worker x=torch.ones(3, 6), 1435*da0073e9SAndroid Build Coastguard Worker dynamic_sizes=[ 1436*da0073e9SAndroid Build Coastguard Worker DimDynamic.DYNAMIC, 1437*da0073e9SAndroid Build Coastguard Worker DimDynamic.DYNAMIC, 1438*da0073e9SAndroid Build Coastguard Worker ], 1439*da0073e9SAndroid Build Coastguard Worker dynamic_strides=[ 1440*da0073e9SAndroid Build Coastguard Worker DimDynamic.DYNAMIC, 1441*da0073e9SAndroid Build Coastguard Worker DimDynamic.INFER_STRIDE, 1442*da0073e9SAndroid Build Coastguard Worker ], 1443*da0073e9SAndroid Build Coastguard Worker ) 1444*da0073e9SAndroid Build Coastguard Worker s0, s1 = t.size() 1445*da0073e9SAndroid Build Coastguard Worker s2, s3 = t.stride() 1446*da0073e9SAndroid 
Build Coastguard Worker self.assertTrue(isinstance(s0, torch.SymInt)) 1447*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s1, torch.SymInt)) 1448*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s2, torch.SymInt)) 1449*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s3, int)) 1450*da0073e9SAndroid Build Coastguard Worker self.assertTrue(str(s1.node.expr) != str(s2.node.expr)) 1451*da0073e9SAndroid Build Coastguard Worker 1452*da0073e9SAndroid Build Coastguard Worker 1453*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestSymNumberMagicMethods) 1454*da0073e9SAndroid Build Coastguard Worker 1455*da0073e9SAndroid Build Coastguard Worker 1456*da0073e9SAndroid Build Coastguard Workerclass TestFloorDiv(TestCase): 1457*da0073e9SAndroid Build Coastguard Worker @staticmethod 1458*da0073e9SAndroid Build Coastguard Worker def python_floordiv(x, y): 1459*da0073e9SAndroid Build Coastguard Worker return x // y 1460*da0073e9SAndroid Build Coastguard Worker 1461*da0073e9SAndroid Build Coastguard Worker @staticmethod 1462*da0073e9SAndroid Build Coastguard Worker def torch_floordiv(x, y): 1463*da0073e9SAndroid Build Coastguard Worker # Note: we fully evaluate here since FloorDiv might not always do 1464*da0073e9SAndroid Build Coastguard Worker # that. 
1465*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 1466*da0073e9SAndroid Build Coastguard Worker return shape_env.evaluate_expr(FloorDiv(x, y)) 1467*da0073e9SAndroid Build Coastguard Worker 1468*da0073e9SAndroid Build Coastguard Worker @staticmethod 1469*da0073e9SAndroid Build Coastguard Worker def yield_test_cases(values, negate=True): 1470*da0073e9SAndroid Build Coastguard Worker for x, y in values: 1471*da0073e9SAndroid Build Coastguard Worker yield (x, y) 1472*da0073e9SAndroid Build Coastguard Worker if negate: 1473*da0073e9SAndroid Build Coastguard Worker yield (-x, y) 1474*da0073e9SAndroid Build Coastguard Worker yield (x, -y) 1475*da0073e9SAndroid Build Coastguard Worker yield (-x, -y) 1476*da0073e9SAndroid Build Coastguard Worker 1477*da0073e9SAndroid Build Coastguard Worker def test_floordiv_float_int(self): 1478*da0073e9SAndroid Build Coastguard Worker values = ((7, 2),) 1479*da0073e9SAndroid Build Coastguard Worker 1480*da0073e9SAndroid Build Coastguard Worker for x, y in TestFloorDiv.yield_test_cases(values): 1481*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1482*da0073e9SAndroid Build Coastguard Worker TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y) 1483*da0073e9SAndroid Build Coastguard Worker ) 1484*da0073e9SAndroid Build Coastguard Worker 1485*da0073e9SAndroid Build Coastguard Worker def test_floordiv_div_by_one(self): 1486*da0073e9SAndroid Build Coastguard Worker values = ((2, 1),) 1487*da0073e9SAndroid Build Coastguard Worker 1488*da0073e9SAndroid Build Coastguard Worker for x, y in TestFloorDiv.yield_test_cases(values): 1489*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1490*da0073e9SAndroid Build Coastguard Worker TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y) 1491*da0073e9SAndroid Build Coastguard Worker ) 1492*da0073e9SAndroid Build Coastguard Worker 1493*da0073e9SAndroid Build Coastguard Worker def test_floordiv_simplify(self): 1494*da0073e9SAndroid 
Build Coastguard Worker # Tests how we simplify or evaluate FloorDiv without free variables 1495*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 1496*da0073e9SAndroid Build Coastguard Worker result = 21 1497*da0073e9SAndroid Build Coastguard Worker exprs = (7 * FloorDiv(6, 2),) 1498*da0073e9SAndroid Build Coastguard Worker 1499*da0073e9SAndroid Build Coastguard Worker for expr in exprs: 1500*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expr, result) 1501*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expr.doit(deep=False), result) 1502*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expr.doit(deep=True), result) 1503*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sympy.simplify(expr), result) 1504*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape_env.simplify(expr), result) 1505*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape_env.evaluate_expr(expr), result) 1506*da0073e9SAndroid Build Coastguard Worker 1507*da0073e9SAndroid Build Coastguard Worker def test_floordiv_assumptions(self): 1508*da0073e9SAndroid Build Coastguard Worker cases = ( 1509*da0073e9SAndroid Build Coastguard Worker sympy.Symbol("i1", integer=True), 1510*da0073e9SAndroid Build Coastguard Worker sympy.Symbol("i2", integer=True), 1511*da0073e9SAndroid Build Coastguard Worker ) 1512*da0073e9SAndroid Build Coastguard Worker 1513*da0073e9SAndroid Build Coastguard Worker for base, divisor in itertools.product(cases, repeat=2): 1514*da0073e9SAndroid Build Coastguard Worker 1515*da0073e9SAndroid Build Coastguard Worker def op(): 1516*da0073e9SAndroid Build Coastguard Worker return FloorDiv(base, divisor) 1517*da0073e9SAndroid Build Coastguard Worker 1518*da0073e9SAndroid Build Coastguard Worker def is_complex(x): 1519*da0073e9SAndroid Build Coastguard Worker return x.is_integer is False and x.is_real is False and x.is_complex 1520*da0073e9SAndroid Build Coastguard Worker 1521*da0073e9SAndroid Build Coastguard 
Worker if is_complex(base) or is_complex(divisor): 1522*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 1523*da0073e9SAndroid Build Coastguard Worker TypeError, 1524*da0073e9SAndroid Build Coastguard Worker ( 1525*da0073e9SAndroid Build Coastguard Worker r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol'," 1526*da0073e9SAndroid Build Coastguard Worker r" expected integer or real" 1527*da0073e9SAndroid Build Coastguard Worker ), 1528*da0073e9SAndroid Build Coastguard Worker op, 1529*da0073e9SAndroid Build Coastguard Worker ) 1530*da0073e9SAndroid Build Coastguard Worker continue 1531*da0073e9SAndroid Build Coastguard Worker 1532*da0073e9SAndroid Build Coastguard Worker op = op() 1533*da0073e9SAndroid Build Coastguard Worker 1534*da0073e9SAndroid Build Coastguard Worker # In regular Python, x//x == 1.0 if x is a float, but FloorDiv 1535*da0073e9SAndroid Build Coastguard Worker # always returns an integer 1 when both args are the same object. 1536*da0073e9SAndroid Build Coastguard Worker # This even works for Symbols with no assumptions specified. 
1537*da0073e9SAndroid Build Coastguard Worker if base is divisor: 1538*da0073e9SAndroid Build Coastguard Worker self.assertTrue(op.is_integer) 1539*da0073e9SAndroid Build Coastguard Worker self.assertTrue(op.is_real) 1540*da0073e9SAndroid Build Coastguard Worker elif base.is_integer and divisor.is_integer: 1541*da0073e9SAndroid Build Coastguard Worker self.assertTrue(op.is_integer) 1542*da0073e9SAndroid Build Coastguard Worker self.assertTrue(op.is_real) 1543*da0073e9SAndroid Build Coastguard Worker else: 1544*da0073e9SAndroid Build Coastguard Worker self.assertEqual(op.is_integer, None) 1545*da0073e9SAndroid Build Coastguard Worker self.assertTrue(op.is_real) 1546*da0073e9SAndroid Build Coastguard Worker 1547*da0073e9SAndroid Build Coastguard Worker 1548*da0073e9SAndroid Build Coastguard Workerclass TestDimConstraints(TestCase): 1549*da0073e9SAndroid Build Coastguard Worker def test_dim_constraints_reduce_congruences_simple(self): 1550*da0073e9SAndroid Build Coastguard Worker from sympy import Symbol 1551*da0073e9SAndroid Build Coastguard Worker 1552*da0073e9SAndroid Build Coastguard Worker s = Symbol("s", positive=True, integer=True) 1553*da0073e9SAndroid Build Coastguard Worker dim_constraints = DimConstraints({}, {}, set(), {}) 1554*da0073e9SAndroid Build Coastguard Worker dim_constraints._congruences[s] = { 1555*da0073e9SAndroid Build Coastguard Worker (s / 2) % 2, 1556*da0073e9SAndroid Build Coastguard Worker (s / 2) % 8, 1557*da0073e9SAndroid Build Coastguard Worker (s / 2) % 4, 1558*da0073e9SAndroid Build Coastguard Worker s % 2, 1559*da0073e9SAndroid Build Coastguard Worker ((s / 16) + 2) % 4, 1560*da0073e9SAndroid Build Coastguard Worker } 1561*da0073e9SAndroid Build Coastguard Worker congruences = dim_constraints._reduce_congruences() 1562*da0073e9SAndroid Build Coastguard Worker self.assertEqual(congruences[s], {(s + 32) % 64}) 1563*da0073e9SAndroid Build Coastguard Worker 1564*da0073e9SAndroid Build Coastguard Worker def 
test_dim_constraints_reduce_inequalities_simple(self): 1565*da0073e9SAndroid Build Coastguard Worker from sympy import Eq, Interval, Ne, Symbol 1566*da0073e9SAndroid Build Coastguard Worker from sympy.solvers.inequalities import reduce_inequalities 1567*da0073e9SAndroid Build Coastguard Worker 1568*da0073e9SAndroid Build Coastguard Worker s = Symbol("s", positive=True, integer=True) 1569*da0073e9SAndroid Build Coastguard Worker exprs = { 1570*da0073e9SAndroid Build Coastguard Worker s >= 2, 1571*da0073e9SAndroid Build Coastguard Worker Ne(8 * s, 16), 1572*da0073e9SAndroid Build Coastguard Worker Ne(s / 2, 1), 1573*da0073e9SAndroid Build Coastguard Worker Ne(16 * s, 32), 1574*da0073e9SAndroid Build Coastguard Worker s < 16, 1575*da0073e9SAndroid Build Coastguard Worker Ne(s, 2), 1576*da0073e9SAndroid Build Coastguard Worker s / 2 < 16, 1577*da0073e9SAndroid Build Coastguard Worker s / 2 > 1, 1578*da0073e9SAndroid Build Coastguard Worker s / 2 >= 2, 1579*da0073e9SAndroid Build Coastguard Worker Ne(3 * s / 2, 3), 1580*da0073e9SAndroid Build Coastguard Worker } 1581*da0073e9SAndroid Build Coastguard Worker solution = reduce_inequalities(exprs, s).as_set() 1582*da0073e9SAndroid Build Coastguard Worker self.assertEqual(solution, Interval.Ropen(4, 16)) 1583*da0073e9SAndroid Build Coastguard Worker 1584*da0073e9SAndroid Build Coastguard Worker exprs.add(Eq(s / 2, 4)) 1585*da0073e9SAndroid Build Coastguard Worker solution = reduce_inequalities(exprs, s).as_set() 1586*da0073e9SAndroid Build Coastguard Worker self.assertEqual(solution, {8}) 1587*da0073e9SAndroid Build Coastguard Worker 1588*da0073e9SAndroid Build Coastguard Worker def test_dim_constraints_reduce_inequalities_error(self): 1589*da0073e9SAndroid Build Coastguard Worker from collections import defaultdict 1590*da0073e9SAndroid Build Coastguard Worker 1591*da0073e9SAndroid Build Coastguard Worker from sympy import Symbol 1592*da0073e9SAndroid Build Coastguard Worker from sympy.solvers.inequalities import 
reduce_inequalities 1593*da0073e9SAndroid Build Coastguard Worker 1594*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.source import ( 1595*da0073e9SAndroid Build Coastguard Worker LocalSource, 1596*da0073e9SAndroid Build Coastguard Worker TensorProperty, 1597*da0073e9SAndroid Build Coastguard Worker TensorPropertySource, 1598*da0073e9SAndroid Build Coastguard Worker ) 1599*da0073e9SAndroid Build Coastguard Worker from torch.fx.experimental.symbolic_shapes import DynamicDimConstraintPrinter 1600*da0073e9SAndroid Build Coastguard Worker 1601*da0073e9SAndroid Build Coastguard Worker s0 = Symbol("s0", positive=True, integer=True) 1602*da0073e9SAndroid Build Coastguard Worker exprs = { 1603*da0073e9SAndroid Build Coastguard Worker 4 * s0**3 - 4 * s0**2 + s0 <= 2147483647, 1604*da0073e9SAndroid Build Coastguard Worker s0 >= 2, 1605*da0073e9SAndroid Build Coastguard Worker s0**3 <= 2147483647, 1606*da0073e9SAndroid Build Coastguard Worker s0 <= 2147483647, 1607*da0073e9SAndroid Build Coastguard Worker } 1608*da0073e9SAndroid Build Coastguard Worker answer = reduce_inequalities(exprs, s0) 1609*da0073e9SAndroid Build Coastguard Worker 1610*da0073e9SAndroid Build Coastguard Worker symbol_to_source = defaultdict(list) 1611*da0073e9SAndroid Build Coastguard Worker symbol_to_source[s0].append( 1612*da0073e9SAndroid Build Coastguard Worker TensorPropertySource( 1613*da0073e9SAndroid Build Coastguard Worker base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0 1614*da0073e9SAndroid Build Coastguard Worker ) 1615*da0073e9SAndroid Build Coastguard Worker ) 1616*da0073e9SAndroid Build Coastguard Worker dcp = DynamicDimConstraintPrinter(symbol_to_source, {}) 1617*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1618*da0073e9SAndroid Build Coastguard Worker AssertionError, 1619*da0073e9SAndroid Build Coastguard Worker "Unknown symbol.*created by constraints solver", 1620*da0073e9SAndroid Build Coastguard Worker ): 1621*da0073e9SAndroid 
Build Coastguard Worker dcp.doprint(answer) 1622*da0073e9SAndroid Build Coastguard Worker 1623*da0073e9SAndroid Build Coastguard Worker def test_dim_constraints_solve_full(self): 1624*da0073e9SAndroid Build Coastguard Worker from sympy import Eq, Integer, Ne, Symbol 1625*da0073e9SAndroid Build Coastguard Worker 1626*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.source import ( 1627*da0073e9SAndroid Build Coastguard Worker LocalSource, 1628*da0073e9SAndroid Build Coastguard Worker TensorProperty, 1629*da0073e9SAndroid Build Coastguard Worker TensorPropertySource, 1630*da0073e9SAndroid Build Coastguard Worker ) 1631*da0073e9SAndroid Build Coastguard Worker 1632*da0073e9SAndroid Build Coastguard Worker src0 = TensorPropertySource( 1633*da0073e9SAndroid Build Coastguard Worker base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0 1634*da0073e9SAndroid Build Coastguard Worker ) 1635*da0073e9SAndroid Build Coastguard Worker src2 = TensorPropertySource( 1636*da0073e9SAndroid Build Coastguard Worker base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=0 1637*da0073e9SAndroid Build Coastguard Worker ) 1638*da0073e9SAndroid Build Coastguard Worker src3 = TensorPropertySource( 1639*da0073e9SAndroid Build Coastguard Worker base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=0 1640*da0073e9SAndroid Build Coastguard Worker ) 1641*da0073e9SAndroid Build Coastguard Worker src4 = TensorPropertySource( 1642*da0073e9SAndroid Build Coastguard Worker base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=0 1643*da0073e9SAndroid Build Coastguard Worker ) 1644*da0073e9SAndroid Build Coastguard Worker 1645*da0073e9SAndroid Build Coastguard Worker src1 = TensorPropertySource( 1646*da0073e9SAndroid Build Coastguard Worker base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=2 1647*da0073e9SAndroid Build Coastguard Worker ) 1648*da0073e9SAndroid Build Coastguard Worker src7 = TensorPropertySource( 1649*da0073e9SAndroid 
Build Coastguard Worker base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=3 1650*da0073e9SAndroid Build Coastguard Worker ) 1651*da0073e9SAndroid Build Coastguard Worker 1652*da0073e9SAndroid Build Coastguard Worker src5 = TensorPropertySource( 1653*da0073e9SAndroid Build Coastguard Worker base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=1 1654*da0073e9SAndroid Build Coastguard Worker ) 1655*da0073e9SAndroid Build Coastguard Worker src8 = TensorPropertySource( 1656*da0073e9SAndroid Build Coastguard Worker base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=1 1657*da0073e9SAndroid Build Coastguard Worker ) 1658*da0073e9SAndroid Build Coastguard Worker 1659*da0073e9SAndroid Build Coastguard Worker src6 = TensorPropertySource( 1660*da0073e9SAndroid Build Coastguard Worker base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=1 1661*da0073e9SAndroid Build Coastguard Worker ) 1662*da0073e9SAndroid Build Coastguard Worker src9 = TensorPropertySource( 1663*da0073e9SAndroid Build Coastguard Worker base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=1 1664*da0073e9SAndroid Build Coastguard Worker ) 1665*da0073e9SAndroid Build Coastguard Worker src10 = TensorPropertySource( 1666*da0073e9SAndroid Build Coastguard Worker base=LocalSource(local_name="e"), prop=TensorProperty.SIZE, idx=1 1667*da0073e9SAndroid Build Coastguard Worker ) 1668*da0073e9SAndroid Build Coastguard Worker 1669*da0073e9SAndroid Build Coastguard Worker src11 = TensorPropertySource( 1670*da0073e9SAndroid Build Coastguard Worker base=LocalSource(local_name="f"), prop=TensorProperty.SIZE, idx=1 1671*da0073e9SAndroid Build Coastguard Worker ) 1672*da0073e9SAndroid Build Coastguard Worker src12 = TensorPropertySource( 1673*da0073e9SAndroid Build Coastguard Worker base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=2 1674*da0073e9SAndroid Build Coastguard Worker ) 1675*da0073e9SAndroid Build Coastguard Worker 1676*da0073e9SAndroid Build 
Coastguard Worker s0 = Symbol("s0", positive=True, integer=True) 1677*da0073e9SAndroid Build Coastguard Worker s1 = Symbol("s1", positive=True, integer=True) 1678*da0073e9SAndroid Build Coastguard Worker s5 = Symbol("s5", positive=True, integer=True) 1679*da0073e9SAndroid Build Coastguard Worker s6 = Symbol("s6", positive=True, integer=True) 1680*da0073e9SAndroid Build Coastguard Worker symbol_to_source = { 1681*da0073e9SAndroid Build Coastguard Worker s0: [src0, src2, src3, src4], 1682*da0073e9SAndroid Build Coastguard Worker s1: [src1, src7], 1683*da0073e9SAndroid Build Coastguard Worker s5: [src5, src8], 1684*da0073e9SAndroid Build Coastguard Worker s6: [src6, src9, src10], 1685*da0073e9SAndroid Build Coastguard Worker } 1686*da0073e9SAndroid Build Coastguard Worker var_to_val = {s0: 8, s1: 96, s5: 22, s6: 21} 1687*da0073e9SAndroid Build Coastguard Worker marked_dynamic = {s0, s1, s5, s6} 1688*da0073e9SAndroid Build Coastguard Worker dim_constraints = DimConstraints( 1689*da0073e9SAndroid Build Coastguard Worker symbol_to_source, var_to_val, marked_dynamic, {} 1690*da0073e9SAndroid Build Coastguard Worker ) 1691*da0073e9SAndroid Build Coastguard Worker dim_constraints.add_equality(src2, s0) 1692*da0073e9SAndroid Build Coastguard Worker dim_constraints.add_equality(src3, s0) 1693*da0073e9SAndroid Build Coastguard Worker dim_constraints.add_equality(src4, s0) 1694*da0073e9SAndroid Build Coastguard Worker dim_constraints.add_equality(src7, s1) 1695*da0073e9SAndroid Build Coastguard Worker dim_constraints.add_equality(src8, s5) 1696*da0073e9SAndroid Build Coastguard Worker dim_constraints.add_equality(src9, s6) 1697*da0073e9SAndroid Build Coastguard Worker dim_constraints.add_equality(src10, s6) 1698*da0073e9SAndroid Build Coastguard Worker dim_constraints.add_equality(src11, Integer(1)) 1699*da0073e9SAndroid Build Coastguard Worker dim_constraints.add_equality(src12, Integer(3)) 1700*da0073e9SAndroid Build Coastguard Worker 1701*da0073e9SAndroid Build Coastguard 
Worker dim_constraints.add(s1**2 <= 2147483647) 1702*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(32 * s1**2 <= 2147483647) 1703*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(s0 < 16) 1704*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Eq(Mod(s1, 2), 0)) 1705*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(FloorDiv(s1, 2), 1)) 1706*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne((FloorDiv(s1, 2)) ** 2, 1)) 1707*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(32 * (FloorDiv(s1, 2)) ** 2 <= 2147483647) 1708*da0073e9SAndroid Build Coastguard Worker dim_constraints.add((FloorDiv(s1, 2)) ** 2 > 1) 1709*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(FloorDiv(s1, 2), 1)) 1710*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1711*da0073e9SAndroid Build Coastguard Worker 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2 1712*da0073e9SAndroid Build Coastguard Worker + 128 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) 1713*da0073e9SAndroid Build Coastguard Worker + 64 1714*da0073e9SAndroid Build Coastguard Worker <= 2147483647 1715*da0073e9SAndroid Build Coastguard Worker ) 1716*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1)) 1717*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1718*da0073e9SAndroid Build Coastguard Worker Ne( 1719*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2 1720*da0073e9SAndroid Build Coastguard Worker + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) 1721*da0073e9SAndroid Build Coastguard Worker + 1, 1722*da0073e9SAndroid Build Coastguard Worker 1, 1723*da0073e9SAndroid Build Coastguard Worker ) 1724*da0073e9SAndroid Build Coastguard Worker ) 1725*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1)) 1726*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1727*da0073e9SAndroid 
Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2 1728*da0073e9SAndroid Build Coastguard Worker + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) 1729*da0073e9SAndroid Build Coastguard Worker + 1 1730*da0073e9SAndroid Build Coastguard Worker > 1 1731*da0073e9SAndroid Build Coastguard Worker ) 1732*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1733*da0073e9SAndroid Build Coastguard Worker 128 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2 1734*da0073e9SAndroid Build Coastguard Worker + 256 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) 1735*da0073e9SAndroid Build Coastguard Worker + 128 1736*da0073e9SAndroid Build Coastguard Worker <= 2147483647 1737*da0073e9SAndroid Build Coastguard Worker ) 1738*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1)) 1739*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1740*da0073e9SAndroid Build Coastguard Worker Ne( 1741*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2 1742*da0073e9SAndroid Build Coastguard Worker + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) 1743*da0073e9SAndroid Build Coastguard Worker + 1, 1744*da0073e9SAndroid Build Coastguard Worker 1, 1745*da0073e9SAndroid Build Coastguard Worker ) 1746*da0073e9SAndroid Build Coastguard Worker ) 1747*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1)) 1748*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1749*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2 1750*da0073e9SAndroid Build Coastguard Worker + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) 1751*da0073e9SAndroid Build Coastguard Worker + 1 1752*da0073e9SAndroid Build Coastguard Worker > 1 1753*da0073e9SAndroid Build Coastguard Worker ) 1754*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1755*da0073e9SAndroid Build Coastguard Worker 256 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 
1756*da0073e9SAndroid Build Coastguard Worker + 512 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) 1757*da0073e9SAndroid Build Coastguard Worker + 256 1758*da0073e9SAndroid Build Coastguard Worker <= 2147483647 1759*da0073e9SAndroid Build Coastguard Worker ) 1760*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1)) 1761*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1762*da0073e9SAndroid Build Coastguard Worker Ne( 1763*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1764*da0073e9SAndroid Build Coastguard Worker + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) 1765*da0073e9SAndroid Build Coastguard Worker + 1, 1766*da0073e9SAndroid Build Coastguard Worker 1, 1767*da0073e9SAndroid Build Coastguard Worker ) 1768*da0073e9SAndroid Build Coastguard Worker ) 1769*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1)) 1770*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1771*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1772*da0073e9SAndroid Build Coastguard Worker + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) 1773*da0073e9SAndroid Build Coastguard Worker + 1 1774*da0073e9SAndroid Build Coastguard Worker > 1 1775*da0073e9SAndroid Build Coastguard Worker ) 1776*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1 >= 3) 1777*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1778*da0073e9SAndroid Build Coastguard Worker 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1779*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1780*da0073e9SAndroid Build Coastguard Worker + 60 1781*da0073e9SAndroid Build Coastguard Worker <= 2147483647 1782*da0073e9SAndroid Build Coastguard Worker ) 1783*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 0) 
1784*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 1) 1785*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1786*da0073e9SAndroid Build Coastguard Worker Ne( 1787*da0073e9SAndroid Build Coastguard Worker 60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1788*da0073e9SAndroid Build Coastguard Worker - 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1789*da0073e9SAndroid Build Coastguard Worker + 60 * s0, 1790*da0073e9SAndroid Build Coastguard Worker 0, 1791*da0073e9SAndroid Build Coastguard Worker ) 1792*da0073e9SAndroid Build Coastguard Worker ) 1793*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1)) 1794*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1795*da0073e9SAndroid Build Coastguard Worker Ne( 1796*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1797*da0073e9SAndroid Build Coastguard Worker - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1798*da0073e9SAndroid Build Coastguard Worker + 1, 1799*da0073e9SAndroid Build Coastguard Worker 1, 1800*da0073e9SAndroid Build Coastguard Worker ) 1801*da0073e9SAndroid Build Coastguard Worker ) 1802*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1803*da0073e9SAndroid Build Coastguard Worker Ne( 1804*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1805*da0073e9SAndroid Build Coastguard Worker - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1806*da0073e9SAndroid Build Coastguard Worker + 1, 1807*da0073e9SAndroid Build Coastguard Worker 0, 1808*da0073e9SAndroid Build Coastguard Worker ) 1809*da0073e9SAndroid Build Coastguard Worker ) 1810*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1811*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1812*da0073e9SAndroid Build Coastguard Worker - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1813*da0073e9SAndroid Build Coastguard Worker 
+ 1 1814*da0073e9SAndroid Build Coastguard Worker >= 0 1815*da0073e9SAndroid Build Coastguard Worker ) 1816*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 0)) 1817*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1818*da0073e9SAndroid Build Coastguard Worker 1 1819*da0073e9SAndroid Build Coastguard Worker < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1820*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1821*da0073e9SAndroid Build Coastguard Worker + 60 1822*da0073e9SAndroid Build Coastguard Worker ) 1823*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, -1)) 1824*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1825*da0073e9SAndroid Build Coastguard Worker Ne( 1826*da0073e9SAndroid Build Coastguard Worker 60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1827*da0073e9SAndroid Build Coastguard Worker - 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1828*da0073e9SAndroid Build Coastguard Worker + 60 * s0, 1829*da0073e9SAndroid Build Coastguard Worker 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1830*da0073e9SAndroid Build Coastguard Worker - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1831*da0073e9SAndroid Build Coastguard Worker + 120, 1832*da0073e9SAndroid Build Coastguard Worker ) 1833*da0073e9SAndroid Build Coastguard Worker ) 1834*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1835*da0073e9SAndroid Build Coastguard Worker 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1836*da0073e9SAndroid Build Coastguard Worker - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1837*da0073e9SAndroid Build Coastguard Worker + 120 1838*da0073e9SAndroid Build Coastguard Worker > 0 1839*da0073e9SAndroid Build Coastguard Worker ) 1840*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1841*da0073e9SAndroid Build Coastguard Worker Eq( 1842*da0073e9SAndroid Build Coastguard Worker 60 * 
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 * (Mod(s0, 2)) 1843*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) * Mod(s0, 2) 1844*da0073e9SAndroid Build Coastguard Worker + 60 * (Mod(s0, 2)), 1845*da0073e9SAndroid Build Coastguard Worker 0, 1846*da0073e9SAndroid Build Coastguard Worker ) 1847*da0073e9SAndroid Build Coastguard Worker ) 1848*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1849*da0073e9SAndroid Build Coastguard Worker Ne( 1850*da0073e9SAndroid Build Coastguard Worker 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1851*da0073e9SAndroid Build Coastguard Worker - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1852*da0073e9SAndroid Build Coastguard Worker + 120, 1853*da0073e9SAndroid Build Coastguard Worker 0, 1854*da0073e9SAndroid Build Coastguard Worker ) 1855*da0073e9SAndroid Build Coastguard Worker ) 1856*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1857*da0073e9SAndroid Build Coastguard Worker Ne( 1858*da0073e9SAndroid Build Coastguard Worker 60 1859*da0073e9SAndroid Build Coastguard Worker * (FloorDiv(s0, 2)) 1860*da0073e9SAndroid Build Coastguard Worker * (FloorDiv(s0, (FloorDiv(s0, 2)))) 1861*da0073e9SAndroid Build Coastguard Worker * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1862*da0073e9SAndroid Build Coastguard Worker - 120 1863*da0073e9SAndroid Build Coastguard Worker * FloorDiv(s0, 2) 1864*da0073e9SAndroid Build Coastguard Worker * FloorDiv(s0, (FloorDiv(s0, 2))) 1865*da0073e9SAndroid Build Coastguard Worker * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1866*da0073e9SAndroid Build Coastguard Worker + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))), 1867*da0073e9SAndroid Build Coastguard Worker 0, 1868*da0073e9SAndroid Build Coastguard Worker ) 1869*da0073e9SAndroid Build Coastguard Worker ) 1870*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(FloorDiv(s0, 2), 1)) 1871*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1872*da0073e9SAndroid Build 
Coastguard Worker Ne( 1873*da0073e9SAndroid Build Coastguard Worker 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1874*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1875*da0073e9SAndroid Build Coastguard Worker + 60, 1876*da0073e9SAndroid Build Coastguard Worker 0, 1877*da0073e9SAndroid Build Coastguard Worker ) 1878*da0073e9SAndroid Build Coastguard Worker ) 1879*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1880*da0073e9SAndroid Build Coastguard Worker 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1881*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1882*da0073e9SAndroid Build Coastguard Worker + 60 1883*da0073e9SAndroid Build Coastguard Worker >= 0 1884*da0073e9SAndroid Build Coastguard Worker ) 1885*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1886*da0073e9SAndroid Build Coastguard Worker 1 1887*da0073e9SAndroid Build Coastguard Worker < 60 1888*da0073e9SAndroid Build Coastguard Worker * (FloorDiv(s0, (FloorDiv(s0, 2)))) 1889*da0073e9SAndroid Build Coastguard Worker * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1890*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1891*da0073e9SAndroid Build Coastguard Worker + 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) 1892*da0073e9SAndroid Build Coastguard Worker ) 1893*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(16 * s0, 32)) 1894*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0)) 1895*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(16 * s0, 32)) 1896*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0)) 1897*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(FloorDiv(s0, 2) >= 2) 1898*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(FloorDiv(s0, 2), 1)) 1899*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(1 < 
FloorDiv(s0, 2)) 1900*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(s0, 2)) 1901*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1902*da0073e9SAndroid Build Coastguard Worker 60 1903*da0073e9SAndroid Build Coastguard Worker * (FloorDiv(s0, (FloorDiv(s0, 2)))) 1904*da0073e9SAndroid Build Coastguard Worker * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1905*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1906*da0073e9SAndroid Build Coastguard Worker + 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) 1907*da0073e9SAndroid Build Coastguard Worker >= 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1908*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1909*da0073e9SAndroid Build Coastguard Worker + 60 1910*da0073e9SAndroid Build Coastguard Worker ) 1911*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1912*da0073e9SAndroid Build Coastguard Worker 60 1913*da0073e9SAndroid Build Coastguard Worker * (FloorDiv(s0, 2)) 1914*da0073e9SAndroid Build Coastguard Worker * (FloorDiv(s0, (FloorDiv(s0, 2)))) 1915*da0073e9SAndroid Build Coastguard Worker * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1916*da0073e9SAndroid Build Coastguard Worker - 120 1917*da0073e9SAndroid Build Coastguard Worker * FloorDiv(s0, 2) 1918*da0073e9SAndroid Build Coastguard Worker * FloorDiv(s0, (FloorDiv(s0, 2))) 1919*da0073e9SAndroid Build Coastguard Worker * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1920*da0073e9SAndroid Build Coastguard Worker + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))) 1921*da0073e9SAndroid Build Coastguard Worker > 0 1922*da0073e9SAndroid Build Coastguard Worker ) 1923*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1924*da0073e9SAndroid Build Coastguard Worker Ne( 1925*da0073e9SAndroid Build Coastguard Worker 60 1926*da0073e9SAndroid Build Coastguard Worker * (FloorDiv(s0, 2)) 1927*da0073e9SAndroid Build Coastguard Worker * 
(FloorDiv(s0, (FloorDiv(s0, 2)))) 1928*da0073e9SAndroid Build Coastguard Worker * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1929*da0073e9SAndroid Build Coastguard Worker - 120 1930*da0073e9SAndroid Build Coastguard Worker * FloorDiv(s0, 2) 1931*da0073e9SAndroid Build Coastguard Worker * FloorDiv(s0, (FloorDiv(s0, 2))) 1932*da0073e9SAndroid Build Coastguard Worker * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1933*da0073e9SAndroid Build Coastguard Worker + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))), 1934*da0073e9SAndroid Build Coastguard Worker 3 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))), 1935*da0073e9SAndroid Build Coastguard Worker ) 1936*da0073e9SAndroid Build Coastguard Worker ) 1937*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1938*da0073e9SAndroid Build Coastguard Worker Ne( 1939*da0073e9SAndroid Build Coastguard Worker 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1940*da0073e9SAndroid Build Coastguard Worker - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1941*da0073e9SAndroid Build Coastguard Worker + 20, 1942*da0073e9SAndroid Build Coastguard Worker 0, 1943*da0073e9SAndroid Build Coastguard Worker ) 1944*da0073e9SAndroid Build Coastguard Worker ) 1945*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1946*da0073e9SAndroid Build Coastguard Worker 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1947*da0073e9SAndroid Build Coastguard Worker - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1948*da0073e9SAndroid Build Coastguard Worker + 20 1949*da0073e9SAndroid Build Coastguard Worker >= 0 1950*da0073e9SAndroid Build Coastguard Worker ) 1951*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1952*da0073e9SAndroid Build Coastguard Worker Ne( 1953*da0073e9SAndroid Build Coastguard Worker 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1954*da0073e9SAndroid Build Coastguard Worker - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1955*da0073e9SAndroid Build Coastguard Worker + 20, 1956*da0073e9SAndroid Build Coastguard 
Worker 20, 1957*da0073e9SAndroid Build Coastguard Worker ) 1958*da0073e9SAndroid Build Coastguard Worker ) 1959*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1960*da0073e9SAndroid Build Coastguard Worker Ne( 1961*da0073e9SAndroid Build Coastguard Worker 20 1962*da0073e9SAndroid Build Coastguard Worker * ( 1963*da0073e9SAndroid Build Coastguard Worker Mod( 1964*da0073e9SAndroid Build Coastguard Worker 1, 1965*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1966*da0073e9SAndroid Build Coastguard Worker - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1967*da0073e9SAndroid Build Coastguard Worker + 1, 1968*da0073e9SAndroid Build Coastguard Worker ) 1969*da0073e9SAndroid Build Coastguard Worker ), 1970*da0073e9SAndroid Build Coastguard Worker 0, 1971*da0073e9SAndroid Build Coastguard Worker ) 1972*da0073e9SAndroid Build Coastguard Worker ) 1973*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 1974*da0073e9SAndroid Build Coastguard Worker Ne( 1975*da0073e9SAndroid Build Coastguard Worker 20 1976*da0073e9SAndroid Build Coastguard Worker * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) 1977*da0073e9SAndroid Build Coastguard Worker * ( 1978*da0073e9SAndroid Build Coastguard Worker Mod( 1979*da0073e9SAndroid Build Coastguard Worker 1, 1980*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1981*da0073e9SAndroid Build Coastguard Worker / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) 1982*da0073e9SAndroid Build Coastguard Worker - 2 1983*da0073e9SAndroid Build Coastguard Worker * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1984*da0073e9SAndroid Build Coastguard Worker / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) 1985*da0073e9SAndroid Build Coastguard Worker + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1), 1986*da0073e9SAndroid Build Coastguard Worker ) 1987*da0073e9SAndroid Build Coastguard Worker ) 1988*da0073e9SAndroid Build Coastguard Worker - 20 1989*da0073e9SAndroid Build Coastguard Worker * Mod( 
1990*da0073e9SAndroid Build Coastguard Worker 1, 1991*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1992*da0073e9SAndroid Build Coastguard Worker / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) 1993*da0073e9SAndroid Build Coastguard Worker - 2 1994*da0073e9SAndroid Build Coastguard Worker * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1995*da0073e9SAndroid Build Coastguard Worker / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) 1996*da0073e9SAndroid Build Coastguard Worker + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1), 1997*da0073e9SAndroid Build Coastguard Worker ), 1998*da0073e9SAndroid Build Coastguard Worker 0, 1999*da0073e9SAndroid Build Coastguard Worker ) 2000*da0073e9SAndroid Build Coastguard Worker ) 2001*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1)) 2002*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2003*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2004*da0073e9SAndroid Build Coastguard Worker - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2005*da0073e9SAndroid Build Coastguard Worker + 1 2006*da0073e9SAndroid Build Coastguard Worker >= 1 2007*da0073e9SAndroid Build Coastguard Worker ) 2008*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2009*da0073e9SAndroid Build Coastguard Worker 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2010*da0073e9SAndroid Build Coastguard Worker - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2011*da0073e9SAndroid Build Coastguard Worker + 20 2012*da0073e9SAndroid Build Coastguard Worker >= 0 2013*da0073e9SAndroid Build Coastguard Worker ) 2014*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2015*da0073e9SAndroid Build Coastguard Worker 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2016*da0073e9SAndroid Build Coastguard Worker - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2017*da0073e9SAndroid Build Coastguard Worker + 20 2018*da0073e9SAndroid Build Coastguard Worker >= 1 
2019*da0073e9SAndroid Build Coastguard Worker ) 2020*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2021*da0073e9SAndroid Build Coastguard Worker 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2022*da0073e9SAndroid Build Coastguard Worker - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2023*da0073e9SAndroid Build Coastguard Worker + 20 2024*da0073e9SAndroid Build Coastguard Worker >= 2 2025*da0073e9SAndroid Build Coastguard Worker ) 2026*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2027*da0073e9SAndroid Build Coastguard Worker 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2028*da0073e9SAndroid Build Coastguard Worker - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2029*da0073e9SAndroid Build Coastguard Worker + 20 2030*da0073e9SAndroid Build Coastguard Worker > 1 2031*da0073e9SAndroid Build Coastguard Worker ) 2032*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2033*da0073e9SAndroid Build Coastguard Worker 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2034*da0073e9SAndroid Build Coastguard Worker - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2035*da0073e9SAndroid Build Coastguard Worker + 20 2036*da0073e9SAndroid Build Coastguard Worker < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2037*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2038*da0073e9SAndroid Build Coastguard Worker + 60 2039*da0073e9SAndroid Build Coastguard Worker ) 2040*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2041*da0073e9SAndroid Build Coastguard Worker Ne( 2042*da0073e9SAndroid Build Coastguard Worker 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2043*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2044*da0073e9SAndroid Build Coastguard Worker + 60, 2045*da0073e9SAndroid Build Coastguard Worker 60, 2046*da0073e9SAndroid Build Coastguard Worker ) 2047*da0073e9SAndroid Build Coastguard Worker ) 2048*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 
2049*da0073e9SAndroid Build Coastguard Worker Ne( 2050*da0073e9SAndroid Build Coastguard Worker FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 2051*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2052*da0073e9SAndroid Build Coastguard Worker - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2053*da0073e9SAndroid Build Coastguard Worker + 1, 2054*da0073e9SAndroid Build Coastguard Worker ) 2055*da0073e9SAndroid Build Coastguard Worker ) 2056*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2057*da0073e9SAndroid Build Coastguard Worker Eq( 2058*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) 2059*da0073e9SAndroid Build Coastguard Worker * ( 2060*da0073e9SAndroid Build Coastguard Worker Mod( 2061*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2062*da0073e9SAndroid Build Coastguard Worker / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) 2063*da0073e9SAndroid Build Coastguard Worker - 2 2064*da0073e9SAndroid Build Coastguard Worker * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2065*da0073e9SAndroid Build Coastguard Worker / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) 2066*da0073e9SAndroid Build Coastguard Worker + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1), 2067*da0073e9SAndroid Build Coastguard Worker 1, 2068*da0073e9SAndroid Build Coastguard Worker ) 2069*da0073e9SAndroid Build Coastguard Worker ) 2070*da0073e9SAndroid Build Coastguard Worker - Mod( 2071*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2072*da0073e9SAndroid Build Coastguard Worker / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) 2073*da0073e9SAndroid Build Coastguard Worker - 2 2074*da0073e9SAndroid Build Coastguard Worker * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2075*da0073e9SAndroid Build Coastguard Worker / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) 2076*da0073e9SAndroid Build Coastguard Worker + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1), 2077*da0073e9SAndroid Build Coastguard Worker 1, 
2078*da0073e9SAndroid Build Coastguard Worker ), 2079*da0073e9SAndroid Build Coastguard Worker 0, 2080*da0073e9SAndroid Build Coastguard Worker ) 2081*da0073e9SAndroid Build Coastguard Worker ) 2082*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2083*da0073e9SAndroid Build Coastguard Worker Ne( 2084*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2085*da0073e9SAndroid Build Coastguard Worker - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2086*da0073e9SAndroid Build Coastguard Worker + 1, 2087*da0073e9SAndroid Build Coastguard Worker FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 2088*da0073e9SAndroid Build Coastguard Worker ) 2089*da0073e9SAndroid Build Coastguard Worker ) 2090*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(8 * s0, 16)) 2091*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2092*da0073e9SAndroid Build Coastguard Worker 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2093*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2094*da0073e9SAndroid Build Coastguard Worker + 60 2095*da0073e9SAndroid Build Coastguard Worker >= (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2096*da0073e9SAndroid Build Coastguard Worker - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2097*da0073e9SAndroid Build Coastguard Worker + 1 2098*da0073e9SAndroid Build Coastguard Worker ) 2099*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2100*da0073e9SAndroid Build Coastguard Worker 60 2101*da0073e9SAndroid Build Coastguard Worker * (FloorDiv(s0, (FloorDiv(s0, 2)))) 2102*da0073e9SAndroid Build Coastguard Worker * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2103*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2104*da0073e9SAndroid Build Coastguard Worker + 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) 2105*da0073e9SAndroid Build Coastguard Worker <= 2147483647 2106*da0073e9SAndroid Build Coastguard Worker ) 
2107*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2108*da0073e9SAndroid Build Coastguard Worker 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2109*da0073e9SAndroid Build Coastguard Worker - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2110*da0073e9SAndroid Build Coastguard Worker + 90 2111*da0073e9SAndroid Build Coastguard Worker <= 2147483647 2112*da0073e9SAndroid Build Coastguard Worker ) 2113*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(FloorDiv(s0, 2) < 16) 2114*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(FloorDiv(s0, 2) > 1) 2115*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2116*da0073e9SAndroid Build Coastguard Worker Ne( 2117*da0073e9SAndroid Build Coastguard Worker 90 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2118*da0073e9SAndroid Build Coastguard Worker - 180 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2119*da0073e9SAndroid Build Coastguard Worker + 90 * (FloorDiv(s0, 2)), 2120*da0073e9SAndroid Build Coastguard Worker 0, 2121*da0073e9SAndroid Build Coastguard Worker ) 2122*da0073e9SAndroid Build Coastguard Worker ) 2123*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2124*da0073e9SAndroid Build Coastguard Worker 1 2125*da0073e9SAndroid Build Coastguard Worker < 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2126*da0073e9SAndroid Build Coastguard Worker - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2127*da0073e9SAndroid Build Coastguard Worker + 90 2128*da0073e9SAndroid Build Coastguard Worker ) 2129*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2130*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2131*da0073e9SAndroid Build Coastguard Worker - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2132*da0073e9SAndroid Build Coastguard Worker + 1 2133*da0073e9SAndroid Build Coastguard Worker > 1 2134*da0073e9SAndroid Build Coastguard Worker ) 2135*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 
2136*da0073e9SAndroid Build Coastguard Worker 60 2137*da0073e9SAndroid Build Coastguard Worker * (FloorDiv(s0, (FloorDiv(s0, 2)))) 2138*da0073e9SAndroid Build Coastguard Worker * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2139*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2140*da0073e9SAndroid Build Coastguard Worker + 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) 2141*da0073e9SAndroid Build Coastguard Worker > 1 2142*da0073e9SAndroid Build Coastguard Worker ) 2143*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2144*da0073e9SAndroid Build Coastguard Worker Ne( 2145*da0073e9SAndroid Build Coastguard Worker 60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2146*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2147*da0073e9SAndroid Build Coastguard Worker + 60 * (FloorDiv(s0, 2)), 2148*da0073e9SAndroid Build Coastguard Worker 0, 2149*da0073e9SAndroid Build Coastguard Worker ) 2150*da0073e9SAndroid Build Coastguard Worker ) 2151*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2152*da0073e9SAndroid Build Coastguard Worker 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2153*da0073e9SAndroid Build Coastguard Worker - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2154*da0073e9SAndroid Build Coastguard Worker + 90 2155*da0073e9SAndroid Build Coastguard Worker > 1 2156*da0073e9SAndroid Build Coastguard Worker ) 2157*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2158*da0073e9SAndroid Build Coastguard Worker 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2159*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2160*da0073e9SAndroid Build Coastguard Worker + 60 2161*da0073e9SAndroid Build Coastguard Worker > 1 2162*da0073e9SAndroid Build Coastguard Worker ) 2163*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2164*da0073e9SAndroid Build Coastguard Worker Ne( 
2165*da0073e9SAndroid Build Coastguard Worker 60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2166*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2167*da0073e9SAndroid Build Coastguard Worker + 60 * (FloorDiv(s0, 2)), 2168*da0073e9SAndroid Build Coastguard Worker 3 * (FloorDiv(s0, 2)), 2169*da0073e9SAndroid Build Coastguard Worker ) 2170*da0073e9SAndroid Build Coastguard Worker ) 2171*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2172*da0073e9SAndroid Build Coastguard Worker 60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2173*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2174*da0073e9SAndroid Build Coastguard Worker + 60 * (FloorDiv(s0, 2)) 2175*da0073e9SAndroid Build Coastguard Worker > 0 2176*da0073e9SAndroid Build Coastguard Worker ) 2177*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2178*da0073e9SAndroid Build Coastguard Worker 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2179*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2180*da0073e9SAndroid Build Coastguard Worker + 60 2181*da0073e9SAndroid Build Coastguard Worker > 0 2182*da0073e9SAndroid Build Coastguard Worker ) 2183*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2184*da0073e9SAndroid Build Coastguard Worker Ne( 2185*da0073e9SAndroid Build Coastguard Worker 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2186*da0073e9SAndroid Build Coastguard Worker - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2187*da0073e9SAndroid Build Coastguard Worker + 120, 2188*da0073e9SAndroid Build Coastguard Worker 0, 2189*da0073e9SAndroid Build Coastguard Worker ) 2190*da0073e9SAndroid Build Coastguard Worker ) 2191*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2192*da0073e9SAndroid Build Coastguard Worker 1 2193*da0073e9SAndroid Build Coastguard Worker < 120 * 
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2194*da0073e9SAndroid Build Coastguard Worker - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2195*da0073e9SAndroid Build Coastguard Worker + 120 2196*da0073e9SAndroid Build Coastguard Worker ) 2197*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2198*da0073e9SAndroid Build Coastguard Worker Ne( 2199*da0073e9SAndroid Build Coastguard Worker 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2200*da0073e9SAndroid Build Coastguard Worker - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2201*da0073e9SAndroid Build Coastguard Worker + 120, 2202*da0073e9SAndroid Build Coastguard Worker 6, 2203*da0073e9SAndroid Build Coastguard Worker ) 2204*da0073e9SAndroid Build Coastguard Worker ) 2205*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2206*da0073e9SAndroid Build Coastguard Worker 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2207*da0073e9SAndroid Build Coastguard Worker - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2208*da0073e9SAndroid Build Coastguard Worker + 120 2209*da0073e9SAndroid Build Coastguard Worker > 0 2210*da0073e9SAndroid Build Coastguard Worker ) 2211*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2212*da0073e9SAndroid Build Coastguard Worker Ne( 2213*da0073e9SAndroid Build Coastguard Worker 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2214*da0073e9SAndroid Build Coastguard Worker - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2215*da0073e9SAndroid Build Coastguard Worker + 120, 2216*da0073e9SAndroid Build Coastguard Worker 0, 2217*da0073e9SAndroid Build Coastguard Worker ) 2218*da0073e9SAndroid Build Coastguard Worker ) 2219*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2220*da0073e9SAndroid Build Coastguard Worker 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2221*da0073e9SAndroid Build Coastguard Worker - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2222*da0073e9SAndroid Build Coastguard Worker + 120 2223*da0073e9SAndroid Build Coastguard Worker <= 2147483647 
2224*da0073e9SAndroid Build Coastguard Worker ) 2225*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2226*da0073e9SAndroid Build Coastguard Worker 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2227*da0073e9SAndroid Build Coastguard Worker - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2228*da0073e9SAndroid Build Coastguard Worker + 120 2229*da0073e9SAndroid Build Coastguard Worker <= 20480 2230*da0073e9SAndroid Build Coastguard Worker ) 2231*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2232*da0073e9SAndroid Build Coastguard Worker Ne( 2233*da0073e9SAndroid Build Coastguard Worker 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2234*da0073e9SAndroid Build Coastguard Worker - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2235*da0073e9SAndroid Build Coastguard Worker + 90, 2236*da0073e9SAndroid Build Coastguard Worker 0, 2237*da0073e9SAndroid Build Coastguard Worker ) 2238*da0073e9SAndroid Build Coastguard Worker ) 2239*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2240*da0073e9SAndroid Build Coastguard Worker 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2241*da0073e9SAndroid Build Coastguard Worker - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2242*da0073e9SAndroid Build Coastguard Worker + 120 2243*da0073e9SAndroid Build Coastguard Worker > 1 2244*da0073e9SAndroid Build Coastguard Worker ) 2245*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2246*da0073e9SAndroid Build Coastguard Worker 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2247*da0073e9SAndroid Build Coastguard Worker - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2248*da0073e9SAndroid Build Coastguard Worker + 90 2249*da0073e9SAndroid Build Coastguard Worker <= 20480 2250*da0073e9SAndroid Build Coastguard Worker ) 2251*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2252*da0073e9SAndroid Build Coastguard Worker 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2253*da0073e9SAndroid Build Coastguard Worker - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 
8) 2254*da0073e9SAndroid Build Coastguard Worker + 60 2255*da0073e9SAndroid Build Coastguard Worker <= 20480 2256*da0073e9SAndroid Build Coastguard Worker ) 2257*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2258*da0073e9SAndroid Build Coastguard Worker Ne( 2259*da0073e9SAndroid Build Coastguard Worker 240 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2260*da0073e9SAndroid Build Coastguard Worker - 480 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2261*da0073e9SAndroid Build Coastguard Worker + 240, 2262*da0073e9SAndroid Build Coastguard Worker 0, 2263*da0073e9SAndroid Build Coastguard Worker ) 2264*da0073e9SAndroid Build Coastguard Worker ) 2265*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Eq(6 * s5, 132)) 2266*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Eq(4, FloorDiv(s0, 2))) 2267*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Eq(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 4)) 2268*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2269*da0073e9SAndroid Build Coastguard Worker Ne( 2270*da0073e9SAndroid Build Coastguard Worker 64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2271*da0073e9SAndroid Build Coastguard Worker - 128 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2272*da0073e9SAndroid Build Coastguard Worker + 64 * (FloorDiv(s0, 2)), 2273*da0073e9SAndroid Build Coastguard Worker 0, 2274*da0073e9SAndroid Build Coastguard Worker ) 2275*da0073e9SAndroid Build Coastguard Worker ) 2276*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2277*da0073e9SAndroid Build Coastguard Worker 1 2278*da0073e9SAndroid Build Coastguard Worker < 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2279*da0073e9SAndroid Build Coastguard Worker - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2280*da0073e9SAndroid Build Coastguard Worker + 64 2281*da0073e9SAndroid Build Coastguard Worker ) 2282*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2283*da0073e9SAndroid Build Coastguard 
Worker 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2284*da0073e9SAndroid Build Coastguard Worker - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2285*da0073e9SAndroid Build Coastguard Worker + 64 2286*da0073e9SAndroid Build Coastguard Worker <= 2147483647 2287*da0073e9SAndroid Build Coastguard Worker ) 2288*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2289*da0073e9SAndroid Build Coastguard Worker 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2290*da0073e9SAndroid Build Coastguard Worker - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2291*da0073e9SAndroid Build Coastguard Worker + 64 2292*da0073e9SAndroid Build Coastguard Worker > 1 2293*da0073e9SAndroid Build Coastguard Worker ) 2294*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2295*da0073e9SAndroid Build Coastguard Worker 62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2296*da0073e9SAndroid Build Coastguard Worker - 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2297*da0073e9SAndroid Build Coastguard Worker + 62 2298*da0073e9SAndroid Build Coastguard Worker <= 2147483647 2299*da0073e9SAndroid Build Coastguard Worker ) 2300*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2301*da0073e9SAndroid Build Coastguard Worker Ne( 2302*da0073e9SAndroid Build Coastguard Worker 62 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2303*da0073e9SAndroid Build Coastguard Worker - 124 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2304*da0073e9SAndroid Build Coastguard Worker + 62 * (FloorDiv(s0, 2)), 2305*da0073e9SAndroid Build Coastguard Worker 0, 2306*da0073e9SAndroid Build Coastguard Worker ) 2307*da0073e9SAndroid Build Coastguard Worker ) 2308*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2309*da0073e9SAndroid Build Coastguard Worker 1 2310*da0073e9SAndroid Build Coastguard Worker < 62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2311*da0073e9SAndroid Build Coastguard Worker - 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2312*da0073e9SAndroid Build Coastguard 
Worker + 62 2313*da0073e9SAndroid Build Coastguard Worker ) 2314*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3)) 2315*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3)) 2316*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Eq(FloorDiv(s0, 2), 4)) 2317*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Eq(4, FloorDiv(s0, 2))) 2318*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Eq(FloorDiv(s0, 2), 4)) 2319*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 3) 2320*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2321*da0073e9SAndroid Build Coastguard Worker 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2322*da0073e9SAndroid Build Coastguard Worker - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2323*da0073e9SAndroid Build Coastguard Worker + 576 2324*da0073e9SAndroid Build Coastguard Worker <= 2147483647 2325*da0073e9SAndroid Build Coastguard Worker ) 2326*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 0) 2327*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 1) 2328*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2329*da0073e9SAndroid Build Coastguard Worker Ne( 2330*da0073e9SAndroid Build Coastguard Worker 64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2331*da0073e9SAndroid Build Coastguard Worker - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2332*da0073e9SAndroid Build Coastguard Worker + 576 * (FloorDiv(s0, 2)), 2333*da0073e9SAndroid Build Coastguard Worker 0, 2334*da0073e9SAndroid Build Coastguard Worker ) 2335*da0073e9SAndroid Build Coastguard Worker ) 2336*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1)) 2337*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 
2338*da0073e9SAndroid Build Coastguard Worker Ne( 2339*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2340*da0073e9SAndroid Build Coastguard Worker - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2341*da0073e9SAndroid Build Coastguard Worker + 9, 2342*da0073e9SAndroid Build Coastguard Worker 1, 2343*da0073e9SAndroid Build Coastguard Worker ) 2344*da0073e9SAndroid Build Coastguard Worker ) 2345*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2346*da0073e9SAndroid Build Coastguard Worker Ne( 2347*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2348*da0073e9SAndroid Build Coastguard Worker - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2349*da0073e9SAndroid Build Coastguard Worker + 9, 2350*da0073e9SAndroid Build Coastguard Worker 0, 2351*da0073e9SAndroid Build Coastguard Worker ) 2352*da0073e9SAndroid Build Coastguard Worker ) 2353*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2354*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2355*da0073e9SAndroid Build Coastguard Worker - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2356*da0073e9SAndroid Build Coastguard Worker + 9 2357*da0073e9SAndroid Build Coastguard Worker >= 0 2358*da0073e9SAndroid Build Coastguard Worker ) 2359*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 0)) 2360*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2361*da0073e9SAndroid Build Coastguard Worker 1 2362*da0073e9SAndroid Build Coastguard Worker < 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2363*da0073e9SAndroid Build Coastguard Worker - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2364*da0073e9SAndroid Build Coastguard Worker + 576 2365*da0073e9SAndroid Build Coastguard Worker ) 2366*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1)) 2367*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 
2368*da0073e9SAndroid Build Coastguard Worker Ne( 2369*da0073e9SAndroid Build Coastguard Worker 64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2370*da0073e9SAndroid Build Coastguard Worker - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2371*da0073e9SAndroid Build Coastguard Worker + 576 * (FloorDiv(s0, 2)), 2372*da0073e9SAndroid Build Coastguard Worker 256, 2373*da0073e9SAndroid Build Coastguard Worker ) 2374*da0073e9SAndroid Build Coastguard Worker ) 2375*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2376*da0073e9SAndroid Build Coastguard Worker Eq( 2377*da0073e9SAndroid Build Coastguard Worker 64 2378*da0073e9SAndroid Build Coastguard Worker * ( 2379*da0073e9SAndroid Build Coastguard Worker Mod( 2380*da0073e9SAndroid Build Coastguard Worker (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2381*da0073e9SAndroid Build Coastguard Worker - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2382*da0073e9SAndroid Build Coastguard Worker + 9 * (FloorDiv(s0, 2)), 2383*da0073e9SAndroid Build Coastguard Worker 4, 2384*da0073e9SAndroid Build Coastguard Worker ) 2385*da0073e9SAndroid Build Coastguard Worker ), 2386*da0073e9SAndroid Build Coastguard Worker 0, 2387*da0073e9SAndroid Build Coastguard Worker ) 2388*da0073e9SAndroid Build Coastguard Worker ) 2389*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2390*da0073e9SAndroid Build Coastguard Worker Eq( 2391*da0073e9SAndroid Build Coastguard Worker FloorDiv(s0, 2), 2392*da0073e9SAndroid Build Coastguard Worker FloorDiv( 2393*da0073e9SAndroid Build Coastguard Worker ( 2394*da0073e9SAndroid Build Coastguard Worker (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2395*da0073e9SAndroid Build Coastguard Worker - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2396*da0073e9SAndroid Build Coastguard Worker + 9 * (FloorDiv(s0, 2)) 2397*da0073e9SAndroid Build Coastguard Worker ), 2398*da0073e9SAndroid Build Coastguard Worker 4, 
2399*da0073e9SAndroid Build Coastguard Worker ), 2400*da0073e9SAndroid Build Coastguard Worker ) 2401*da0073e9SAndroid Build Coastguard Worker ) 2402*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2403*da0073e9SAndroid Build Coastguard Worker Eq( 2404*da0073e9SAndroid Build Coastguard Worker FloorDiv( 2405*da0073e9SAndroid Build Coastguard Worker ( 2406*da0073e9SAndroid Build Coastguard Worker (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2407*da0073e9SAndroid Build Coastguard Worker - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2408*da0073e9SAndroid Build Coastguard Worker + 9 * (FloorDiv(s0, 2)) 2409*da0073e9SAndroid Build Coastguard Worker ), 2410*da0073e9SAndroid Build Coastguard Worker 4, 2411*da0073e9SAndroid Build Coastguard Worker ), 2412*da0073e9SAndroid Build Coastguard Worker FloorDiv(s0, 2), 2413*da0073e9SAndroid Build Coastguard Worker ) 2414*da0073e9SAndroid Build Coastguard Worker ) 2415*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2416*da0073e9SAndroid Build Coastguard Worker Ne(64 * (Mod(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 4)), 0) 2417*da0073e9SAndroid Build Coastguard Worker ) 2418*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2419*da0073e9SAndroid Build Coastguard Worker Eq( 2420*da0073e9SAndroid Build Coastguard Worker 64 2421*da0073e9SAndroid Build Coastguard Worker * ( 2422*da0073e9SAndroid Build Coastguard Worker Mod( 2423*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2424*da0073e9SAndroid Build Coastguard Worker - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2425*da0073e9SAndroid Build Coastguard Worker + 1, 2426*da0073e9SAndroid Build Coastguard Worker 4, 2427*da0073e9SAndroid Build Coastguard Worker ) 2428*da0073e9SAndroid Build Coastguard Worker ), 2429*da0073e9SAndroid Build Coastguard Worker 0, 2430*da0073e9SAndroid Build Coastguard Worker ) 2431*da0073e9SAndroid Build Coastguard Worker ) 2432*da0073e9SAndroid Build Coastguard 
Worker dim_constraints.add( 2433*da0073e9SAndroid Build Coastguard Worker 64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2434*da0073e9SAndroid Build Coastguard Worker - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2435*da0073e9SAndroid Build Coastguard Worker + 576 * (FloorDiv(s0, 2)) 2436*da0073e9SAndroid Build Coastguard Worker > 0 2437*da0073e9SAndroid Build Coastguard Worker ) 2438*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2439*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2440*da0073e9SAndroid Build Coastguard Worker - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2441*da0073e9SAndroid Build Coastguard Worker + 9 2442*da0073e9SAndroid Build Coastguard Worker >= 1 2443*da0073e9SAndroid Build Coastguard Worker ) 2444*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2445*da0073e9SAndroid Build Coastguard Worker Eq( 2446*da0073e9SAndroid Build Coastguard Worker 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2447*da0073e9SAndroid Build Coastguard Worker - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2448*da0073e9SAndroid Build Coastguard Worker + 576, 2449*da0073e9SAndroid Build Coastguard Worker 256, 2450*da0073e9SAndroid Build Coastguard Worker ) 2451*da0073e9SAndroid Build Coastguard Worker ) 2452*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2453*da0073e9SAndroid Build Coastguard Worker 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2454*da0073e9SAndroid Build Coastguard Worker - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2455*da0073e9SAndroid Build Coastguard Worker + 540 2456*da0073e9SAndroid Build Coastguard Worker <= 2147483647 2457*da0073e9SAndroid Build Coastguard Worker ) 2458*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2459*da0073e9SAndroid Build Coastguard Worker Ne( 2460*da0073e9SAndroid Build Coastguard Worker 60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2461*da0073e9SAndroid Build Coastguard Worker - 360 
* FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2462*da0073e9SAndroid Build Coastguard Worker + 540 * (FloorDiv(s0, 2)), 2463*da0073e9SAndroid Build Coastguard Worker 0, 2464*da0073e9SAndroid Build Coastguard Worker ) 2465*da0073e9SAndroid Build Coastguard Worker ) 2466*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2467*da0073e9SAndroid Build Coastguard Worker 1 2468*da0073e9SAndroid Build Coastguard Worker < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2469*da0073e9SAndroid Build Coastguard Worker - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2470*da0073e9SAndroid Build Coastguard Worker + 540 2471*da0073e9SAndroid Build Coastguard Worker ) 2472*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2473*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2474*da0073e9SAndroid Build Coastguard Worker - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2475*da0073e9SAndroid Build Coastguard Worker + 9 2476*da0073e9SAndroid Build Coastguard Worker <= 2147483647 2477*da0073e9SAndroid Build Coastguard Worker ) 2478*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2479*da0073e9SAndroid Build Coastguard Worker Ne( 2480*da0073e9SAndroid Build Coastguard Worker (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2481*da0073e9SAndroid Build Coastguard Worker - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2482*da0073e9SAndroid Build Coastguard Worker + 9 * (FloorDiv(s0, 2)), 2483*da0073e9SAndroid Build Coastguard Worker 0, 2484*da0073e9SAndroid Build Coastguard Worker ) 2485*da0073e9SAndroid Build Coastguard Worker ) 2486*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2487*da0073e9SAndroid Build Coastguard Worker 1 2488*da0073e9SAndroid Build Coastguard Worker < (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2489*da0073e9SAndroid Build Coastguard Worker - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2490*da0073e9SAndroid Build Coastguard Worker + 9 2491*da0073e9SAndroid Build Coastguard 
Worker ) 2492*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2493*da0073e9SAndroid Build Coastguard Worker (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2494*da0073e9SAndroid Build Coastguard Worker - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2495*da0073e9SAndroid Build Coastguard Worker + 9 2496*da0073e9SAndroid Build Coastguard Worker > 1 2497*da0073e9SAndroid Build Coastguard Worker ) 2498*da0073e9SAndroid Build Coastguard Worker dim_constraints.add( 2499*da0073e9SAndroid Build Coastguard Worker 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2500*da0073e9SAndroid Build Coastguard Worker - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2501*da0073e9SAndroid Build Coastguard Worker + 540 2502*da0073e9SAndroid Build Coastguard Worker > 1 2503*da0073e9SAndroid Build Coastguard Worker ) 2504*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(s0 >= 2) 2505*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(s1 >= 2) 2506*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(s6 >= 2) 2507*da0073e9SAndroid Build Coastguard Worker dim_constraints.add(s5 >= 2) 2508*da0073e9SAndroid Build Coastguard Worker 2509*da0073e9SAndroid Build Coastguard Worker dim_constraints.solve() 2510*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2511*da0073e9SAndroid Build Coastguard Worker dim_constraints._static_results, 2512*da0073e9SAndroid Build Coastguard Worker { 2513*da0073e9SAndroid Build Coastguard Worker "L['c'].size()[0] == 8", 2514*da0073e9SAndroid Build Coastguard Worker "L['d'].size()[0] == 8", 2515*da0073e9SAndroid Build Coastguard Worker "L['a'].size()[2] == 96", 2516*da0073e9SAndroid Build Coastguard Worker "L['f'].size()[1] == 1", 2517*da0073e9SAndroid Build Coastguard Worker "L['a'].size()[3] == 96", 2518*da0073e9SAndroid Build Coastguard Worker "L['b'].size()[2] == 3", 2519*da0073e9SAndroid Build Coastguard Worker "L['b'].size()[1] == 22", 2520*da0073e9SAndroid Build Coastguard Worker "L['b'].size()[0] == 8", 2521*da0073e9SAndroid 
class TestGuardsExpressions(TestCase):
    """
    Exercises ShapeEnv.produce_guards_expression and
    ShapeEnv.evaluate_guards_expression, the guards-related methods used by
    the inductor FX graph cache.
    """

    def test_guards_gt_lt(self):
        env = ShapeEnv()
        # Symbols created with 6, 7 and 5; only 6 lies strictly between 5 and 7.
        inside = create_symint(env, 6)
        above = create_symint(env, 7)
        below = create_symint(env, 5)

        # Record a lower and an upper bound as guards on the first symbol.
        guard_int(sym_int(inside > 5))
        guard_int(sym_int(inside < 7))

        expr = env.produce_guards_expression([inside])

        self.assertTrue(env.evaluate_guards_expression(expr, [hint_int(inside)]))
        # Both out-of-range values must be rejected by the same expression.
        for rejected in (above, below):
            self.assertFalse(
                env.evaluate_guards_expression(expr, [hint_int(rejected)])
            )

    def test_guards_float_print(self):
        env = ShapeEnv()
        sym = create_symint(env, 3)
        # The right-hand side (2 / 3) is a Python float, so the recorded guard
        # carries a float constant that the produced expression must handle.
        guard_bool(2 / sym == 2 / 3)
        expr = env.produce_guards_expression([sym])
        accepted = env.evaluate_guards_expression(expr, [hint_int(sym)])
        self.assertTrue(accepted)

    def test_guards_float_div(self):
        env = ShapeEnv()
        guarded = create_symint(env, 8)
        other = create_symint(env, 7)

        guard_int(sym_int(guarded / 2.0))
        expr = env.produce_guards_expression([guarded])

        # The produced expression is expected to mention both helpers.
        for fragment in ("ToFloat", "FloatTrueDiv"):
            self.assertIn(fragment, expr)
        # The guard was recorded for 8: it must hold for 8 and fail for 7.
        self.assertTrue(env.evaluate_guards_expression(expr, [hint_int(guarded)]))
        self.assertFalse(env.evaluate_guards_expression(expr, [hint_int(other)]))


if __name__ == "__main__":
    run_tests()