1# Owner(s): ["oncall: jit"] 2 3import contextlib 4import copy 5import itertools 6import math 7import operator 8import unittest 9 10import numpy as np 11import sympy 12 13import torch 14import torch.fx 15import torch.nn.functional as F 16from torch import sym_int, SymBool, SymFloat, SymInt 17from torch._C import _disabled_torch_function_impl 18from torch.fx.experimental import sym_node 19from torch.fx.experimental.proxy_tensor import make_fx 20from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node 21from torch.fx.experimental.symbolic_shapes import ( 22 _constrain_range_for_size, 23 DimConstraints, 24 DimDynamic, 25 expect_true, 26 guard_bool, 27 guard_float, 28 guard_int, 29 GuardOnDataDependentSymNode, 30 hint_int, 31 is_symbolic, 32 ShapeEnv, 33 StatelessSymbolicContext, 34 statically_known_true, 35) 36from torch.testing._internal.common_utils import ( 37 instantiate_parametrized_tests, 38 parametrize, 39 run_tests, 40 skipIfTorchDynamo, 41 TestCase, 42) 43from torch.utils import _pytree as pytree 44from torch.utils._python_dispatch import TorchDispatchMode 45from torch.utils._sympy.functions import ( 46 FloorDiv, 47 IsNonOverlappingAndDenseIndicator, 48 Mod, 49) 50 51 52aten = torch.ops.aten 53 54meta_funcs = {} 55 56 57def register_meta(op): 58 def decorator(f): 59 def add_func(op): 60 meta_funcs[op] = f 61 62 pytree.tree_map_(add_func, op) 63 return f 64 65 return decorator 66 67 68@register_meta([aten.add.Tensor, aten.sub.Tensor]) 69def binary_meta(a, b): 70 return a.new_empty(a.shape) 71 72 73@register_meta(aten.cat.default) 74def cat_meta(tensors, dim=0): 75 concat_length = 0 76 shape = tensors[0].shape 77 for tensor in tensors: 78 for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)): 79 if idx == dim: 80 concat_length = concat_length + length 81 else: 82 assert length == common_length 83 new_shape = list(shape) 84 new_shape[dim] = concat_length 85 return tensors[0].new_empty(new_shape) 86 87 
@register_meta([aten.narrow_copy.default])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
    """Meta for ``aten.narrow_copy``: ``a``'s shape with ``length`` at ``dim``.

    ``start`` does not affect the output shape. A negative ``dim`` counts from
    the last dimension.
    """
    if dim < 0:
        # Normalize so the index comparison below can match a negative dim.
        dim += len(a.shape)
    shape = [length if i == dim else size for i, size in enumerate(a.shape)]
    return a.new_empty(tuple(shape))


@register_meta([aten.expand.default])
def expand_symint_meta(a, size, implicit=False):
    """Meta for ``aten.expand``: a fresh tensor of the requested ``size``."""
    return a.new_empty(size)


def create_contiguous(shape):
    """Return row-major (C-contiguous) strides for ``shape``.

    The innermost stride is 1 and each outer stride is the product of all inner
    sizes (sizes may be SymInts). A 0-d shape has no strides.
    """
    if len(shape) == 0:
        return []
    strides = [1]
    for dim in reversed(shape[:-1]):
        strides.append(dim * strides[-1])
    return list(reversed(strides))


class FakeSymbolicTensor(torch.Tensor):
    """Storage-less wrapper tensor whose sizes may be SymInts.

    Ops are dispatched to the Python functions registered in ``meta_funcs``;
    only ``aten.new_empty`` is handled directly, anything else raises.
    """

    @staticmethod
    def __new__(
        cls,
        sym_shape,
        sym_strides,
        dtype,
        layout,
        requires_grad,
        device,
        storage_offset=0,
    ):
        # TODO: this is wrong in general
        # ``sym_strides`` is ignored: strides are always recomputed as contiguous.
        sym_stride = create_contiguous(sym_shape)
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            sym_shape,
            sym_stride,
            storage_offset,
            dtype=dtype,
            layout=layout,
            requires_grad=requires_grad,
            device=device,
        )
        return r

    __torch_function__ = _disabled_torch_function_impl

    def new_empty(self, shape):
        return FakeSymbolicTensor(
            shape, None, self.dtype, self.layout, self.requires_grad, self.device
        )

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        # The protocol allows kwargs=None; splatting None would raise TypeError.
        kwargs = kwargs or {}
        if func_overload in meta_funcs:
            return meta_funcs[func_overload](*args, **kwargs)

        if func_overload == torch.ops.aten.new_empty.default:
            self = args[0]
            shape = args[1]
            return FakeSymbolicTensor(
                shape,
                self.stride(),
                self.dtype,
                self.layout,
                self.requires_grad,
                self.device,
            )

        raise RuntimeError(f"operator {func_overload} not supported")


def create_symbolic_tensor(name, arg, shape_env, source=None, dynamic_dims=None):
    """Build a FakeSymbolicTensor whose sizes/strides/offset are symbols for ``arg``.

    Args:
        name: used to build a ``ConstantSource`` when ``source`` is not given.
        arg: real tensor supplying the hints, dtype, layout and device.
        shape_env: the ShapeEnv that allocates the symbols.
        source: optional explicit source for the symbols.
        dynamic_dims: per-dimension ``DimDynamic``; defaults to DUCK everywhere.
    """
    from torch._dynamo.source import ConstantSource

    if source is None:
        source = ConstantSource(name)
    constraint_dims = [None] * arg.dim()
    if dynamic_dims is None:
        dynamic_dims = [DimDynamic.DUCK] * arg.dim()
    (
        sym_shapes,
        sym_strides,
        sym_storage_offset,
    ) = shape_env.create_symbolic_sizes_strides_storage_offset(
        arg,
        source=source,
        symbolic_context=StatelessSymbolicContext(
            dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims
        ),
    )
    return FakeSymbolicTensor(
        sym_shapes,
        sym_strides,
        arg.dtype,
        arg.layout,
        arg.requires_grad,
        arg.device,
        sym_storage_offset,
    )


def create_symtype(cls, pytype, shape_env, val, duck=True, **kwargs):
    """Allocate a fresh symbol with hint ``val`` and wrap it as ``cls`` (SymInt/SymBool/SymFloat).

    The source name is unique per symbol because it embeds the current number
    of symbols in ``shape_env.var_to_val``. Extra kwargs go to ``create_symbol``.
    """
    from torch._dynamo.source import ConstantSource

    symbol = shape_env.create_symbol(
        val,
        source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"),
        dynamic_dim=DimDynamic.DUCK if duck else DimDynamic.DYNAMIC,
        constraint_dim=None,
        **kwargs,
    )
    return cls(SymNode(symbol, shape_env, pytype, hint=val))


# TODO: default duck to False
def create_symint(shape_env, i: int, duck=True, **kwargs) -> SymInt:
    """Create a backed SymInt with hint ``i``."""
    return create_symtype(SymInt, int, shape_env, i, duck=duck, **kwargs)


def create_symbool(shape_env, b: bool) -> SymBool:
    """Create a backed SymBool with hint ``b``."""
    return create_symtype(SymBool, bool, shape_env, b)


def create_symfloat(shape_env, f: float) -> SymFloat:
    """Create a backed SymFloat with hint ``f``."""
    return create_symtype(SymFloat, float, shape_env, f)


@skipIfTorchDynamo(
    "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
)
class TestPySymInt(TestCase):
    def test_arith_ops(self):
        shape_env = ShapeEnv()
        symints = []
        for i in range(2, 5):
            symints.append((i, create_symint(shape_env, i)))

        ops = [
            operator.add,
            operator.sub,
            operator.floordiv,
            operator.mul,
            operator.mod,
        ]

        for op in ops:
            for (lhs_val, lhs_sym), (rhs_val, rhs_sym) in itertools.permutations(
                symints, 2
            ):
                # Only the dividing ops need a nonzero right-hand side. (The old
                # condition ``op != mod or op != floordiv`` was always True.)
                if op in (operator.floordiv, operator.mod) and rhs_val == 0:
                    continue
                self.assertTrue(op(lhs_sym, rhs_sym) == op(lhs_val, rhs_val))

    def test_reverse_arith_ops(self):
        shape_env = ShapeEnv()

        a = create_symint(shape_env, 2)
        self.assertTrue(5 // a == 5 // 2)

        a = create_symint(shape_env, 2)
        self.assertTrue(5 * a == 5 * 2)

    def test_sympify_symint(self):
        shape_env = ShapeEnv()
        a = create_symint(shape_env, 2)
        self.assertIs(sympy.sympify(a), a.node.expr)
        b = create_symfloat(shape_env, 3.0)
        self.assertIs(sympy.sympify(b), b.node.expr)
        c = create_symbool(shape_env, True)
        self.assertIs(sympy.sympify(c), c.node.expr)

    def test_roundtrip(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)

        self.assertNotIsInstance(x.shape[0], SymNode)
        self.assertIsInstance(x.shape[0], SymInt)

        self.assertTrue(x.shape[0] == 5)
        self.assertTrue(x.shape[1] == 4)
        # Compare explicitly: assertTrue(value, expected) treats ``expected``
        # as the failure message and never checks it.
        self.assertTrue(x.shape[2] == 3)

        self.assertTrue(x.size()[0] == 5)
        self.assertTrue(x.size()[1] == 4)
        # Should be simplifiable to an integer.
        # Ref: https://github.com/pytorch/pytorch/pull/107492
        self.assertTrue(isinstance(x.size()[1], SymInt))
        self.assertTrue(
            isinstance(x.size()[1].node.maybe_as_int(), int)
        )  # due to guard above
        self.assertTrue(x.size()[2] == 3)

        self.assertTrue(x.size(0) == 5)
        self.assertTrue(x.size(1) == 4)
        self.assertTrue(x.size(2) == 3)
        self.assertTrue(isinstance(x.size(2), SymInt))
        self.assertTrue(isinstance(x.size(2).node.maybe_as_int(), int))

        y = create_symbolic_tensor("y", torch.randn(5, 4, 3)[1:], shape_env)
        self.assertTrue(isinstance(y.storage_offset(), SymInt))
        self.assertTrue(y.storage_offset() == 12)

    def test_binary(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)

        z = x + y
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # broadcasting
        y = create_symbolic_tensor("y2", torch.randn(1, 4, 1), shape_env)
        z = x + y
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

    def test_symint_args(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
        LAST_DIM = 2
        z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
        self.assertTrue(z.shape[2] == y.shape[2])

        # arithmetic expr with two symints
        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
        self.assertTrue(z.shape[2] == 2)

        # arithmetic expr with a symint and python int
        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
        self.assertTrue(z.shape[2] == 2)

    def test_symint_vargs(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)

        # varargs
        z = y.expand(x.shape[0], y.shape[1], x.shape[2])
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # shape list
        z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints
        z = y.expand(x.shape[0], y.shape[1], 3)
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints in a list
        z = y.expand((x.shape[0], y.shape[1], 3))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints
        z = y.expand(5, y.shape[1], x.shape[2])
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python ints and symints in a list
        z = y.expand((5, y.shape[1], x.shape[2]))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # Single-dimension forms must also be accepted (no error raised).
        y.expand((y.shape[1],))
        y.expand(y.shape[1])

    def test_stride(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
        self.assertIsInstance(x.stride()[0], SymInt)

    def test_size_expressions(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        expand_x = x.expand(x.shape[0], x.shape[0])
        if expand_x.shape[0] > 3:
            result = expand_x + expand_x
        else:
            result = expand_x + expand_x

        gt_op, _bt = shape_env.guards[-1]
        self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
        # These were assertTrue(a, b), which never compared a with b.
        self.assertEqual(str(x.shape[0]), str(gt_op.args[0]))
        self.assertEqual(str(expand_x.shape[1]), str(x.shape[0]))
        self.assertEqual(str(expand_x.shape[1]), str(result.shape[0]))

    def test_floordiv_static(self):
        shape_env = ShapeEnv()
        s0 = create_symint(shape_env, 8)
        # This was extracted from
        # python test/inductor/test_cuda_cpp_wrapper.py -k
        # DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper
        bool(s0 % 2 == 0)
        bool(s0 % (s0 // 2) == 0)
        bool(2 * (s0 // 2) == s0)
        self.assertTrue(statically_known_true(s0 // (s0 // 2) == 2))

    def test_numel(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        self.assertIsInstance(x.numel(), torch.SymInt)
        self.assertIsInstance(torch.numel(x), torch.SymInt)

        x = torch.rand(3, 3)
        self.assertIsInstance(x.numel(), int)
        self.assertIsInstance(torch.numel(x), int)

    def test_int_to_float(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        r = torch.sym_float(x.shape[0])
        self.assertIsInstance(r, torch.SymFloat, msg=type(r))

    def test_aten_ops(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])

        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x2", torch.randn(5, 4, 3), shape_env)
        torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])

    def test_fx_trace_intlist(self):
        class CustomModule(torch.nn.Module):
            def forward(self, x):
                bs, c, h, w = x.shape
                return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))

        m = CustomModule()
        x = torch.rand(1, 3, 4, 4)
        # should not TypeError: pad(): argument 'pad' (position 2) must be
        # tuple of ints, not tuple
        traced = torch.fx.symbolic_trace(m)
        # The traced module must also compute the same result as eager.
        self.assertEqual(traced(x), m(x))

    def test_meta_symint(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        r = torch.empty(a0, device="meta")
        self.assertIsInstance(r.shape[0], SymInt)

    def test_guard_int(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        self.assertEqual(guard_int(a0), 2)
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")

    def test_prefer_deferred_runtime_assertions_over_guards(self):
        shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
        s0 = create_symint(shape_env, 2)
        self.assertEqual(guard_int(s0), 2)
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")

        shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
        s0 = create_symint(shape_env, 2)
        self.assertTrue(expect_true(s0 == 2))
        self.assertEqual(len(shape_env.guards), 0)
        self.assertExpectedInline(
            str([ra.expr for ra in shape_env.deferred_runtime_asserts[None]]),
            """[Eq(s0, 2)]""",
        )

    def test_sym_int(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = sym_int(a0)
        self.assertEqual(r, 5)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 5)""")

        a1 = create_symint(shape_env, 7)
        r = sym_int(a1 / 2)
        self.assertEqual(guard_int(r), 3)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)"""
        )

        a3 = create_symint(shape_env, 3)
        r = sym_int(2.0 * torch.sym_float(a3))
        self.assertEqual(guard_int(r), 6)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)"""
        )

    def test_sym_sqrt(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 4)
        r = torch._sym_sqrt(a0)
        self.assertEqual(r, 2)
        self.assertIsInstance(r, torch.SymFloat, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2.0)"""
        )

    def test_sym_floor(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = math.floor(a0 / 2)
        self.assertEqual(r, 2)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[0][0]),
            """Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""",
        )
        r = math.floor(3.0 * a0)
        self.assertEqual(r, 15)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[1][0]),
            """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
        )

    def test_sym_trunc(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = math.trunc(a0 / 2)
        self.assertEqual(r, 2)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)"""
        )
        r = torch.sym_int(torch.sym_sqrt(a0))
        self.assertEqual(r, 2)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[1][0]), """Eq(TruncToInt(OpaqueUnaryFn_sqrt(s0)), 2)"""
        )

    def test_sym_ceil(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = math.ceil(a0 / 2)
        self.assertEqual(r, 3)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[0][0]),
            """Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""",
        )
        r1 = 3.0 * a0
        r = math.floor(r1)
        self.assertEqual(r, 15)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[1][0]),
            """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
        )

    def test_sym_ite(self):
        shape_env = ShapeEnv()
        t = create_symint(shape_env, 5)
        f = create_symint(shape_env, 4)
        b1 = True
        r1 = torch.sym_ite(b1, t, f)
        self.assertIs(r1, t)
        b2 = False
        r2 = torch.sym_ite(b2, t, f)
        self.assertIs(r2, f)
        b3 = t == 5
        r3 = torch.sym_ite(b3, t, f)
        self.assertEqual(len(shape_env.guards), 0)
        self.assertEqual(r3, 5)
        self.assertEqual(type(t), type(r3))
        self.assertExpectedInline(
            str(shape_env.guards[0][0]),
            """Eq(Piecewise((s0, Eq(s0, 5)), (s1, True)), 5)""",
        )
        b4 = f == 5
        r4 = torch.sym_ite(b4, t, f)
        self.assertEqual(len(shape_env.guards), 1)
        self.assertEqual(r4, 4)
        self.assertEqual(type(f), type(r4))
        self.assertExpectedInline(
            str(shape_env.guards[1][0]),
            """Eq(Piecewise((s0, Eq(s1, 5)), (s1, True)), 4)""",
        )

    def test_tracing_sym_ite(self):
        def f(x):
            b = x.shape[0] == 5
            ret = torch.sym_ite(b, x.shape[0], x.shape[1])
            return ret

        gm = make_fx(f, tracing_mode="symbolic")(torch.ones(4, 5))
        self.assertEqual(len(gm.shape_env.guards), 0)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    eq = sym_size_int == 5
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None
    sym_ite = torch.sym_ite(eq, sym_size_int, sym_size_int_1); eq = sym_size_int = sym_size_int_1 = None
    return sym_ite""",
        )
        r1 = gm(torch.ones(4, 5))
        self.assertIsInstance(r1, int)
        self.assertEqual(r1, 5)
        r2 = gm(torch.ones(5, 4))
        self.assertIsInstance(r2, int)
        self.assertEqual(r2, 5)

    def test_int_conversion(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        int(a0)
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")

    def test_data_dependent_guard(self):
        shape_env = ShapeEnv()
        s0 = shape_env.create_unbacked_symint()
        self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))

    def test_data_dependent_guard_propagate_real_tensors(self):
        shape_env = ShapeEnv()
        s0 = shape_env.create_unbacked_symint()
        shape_env.set_unbacked_var_to_val(s0.node.expr, 0)
        self.assertEqual(bool(s0 == 0), True)

626 def test_expect_true_basic(self): 627 shape_env = ShapeEnv() 628 i0 = shape_env.create_unbacked_symint() 629 i0_sym = i0.node.expr 630 # This doesn't error 631 self.assertTrue(expect_true(i0 == 0)) 632 # This generates a deferred runtime assert via replacement 633 self.assertEqual(shape_env.replacements[i0_sym], 0) 634 # After expecting true, guards now resolve given the runtime assert 635 bool(i0 == 0) 636 637 def test_expect_true_with_s0(self): 638 shape_env = ShapeEnv() 639 s0 = create_symint(shape_env, 5) 640 i0 = shape_env.create_unbacked_symint() 641 self.assertTrue(expect_true(i0 < s0)) 642 self.assertExpectedInline( 643 str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]), 644 """[u0 < s0]""", 645 ) 646 self.assertTrue(i0 < s0) 647 self.assertTrue(i0 != s0) 648 self.assertFalse(i0 > s0) 649 self.assertFalse(i0 >= s0) 650 651 def test_expect_true_prefer_later(self): 652 shape_env = ShapeEnv() 653 i0 = shape_env.create_unbacked_symint() 654 i1 = shape_env.create_unbacked_symint() 655 i1_sym = i1.node.expr 656 self.assertTrue(expect_true(i0 + i1 == 10)) 657 # Importantly, this is put in i1, not i0! 658 self.assertExpectedInline( 659 str([ra.expr for ra in shape_env.deferred_runtime_asserts[i1_sym]]), 660 """[Eq(u0 + u1, 10)]""", 661 ) 662 self.assertTrue(i0 + i1 == 10) 663 # NB: We currently don't support deriving that we can substitute 664 # i0 + i1 with 10; maybe we should, but this means our rewriting 665 # system is no longer confluent (it's probably OK though, because 666 # you're unlikely to get other equalities like this on the 667 # unbacked SymInts.) 
668 669 def test_unbacked_substitution(self): 670 shape_env = ShapeEnv() 671 i0 = shape_env.create_unbacked_symint() 672 i1 = shape_env.create_unbacked_symint() 673 _constrain_range_for_size(i0) 674 _constrain_range_for_size(i1) 675 self.assertTrue(expect_true(i0 == i1 * 4)) 676 self.assertExpectedInline(str(i0), """u0""") 677 678 i2 = shape_env.create_unbacked_symint() 679 i3 = shape_env.create_unbacked_symint() 680 _constrain_range_for_size(i2) 681 _constrain_range_for_size(i3) 682 self.assertTrue(expect_true(i2 * 4 == i3)) 683 self.assertExpectedInline(str(i3), """u3""") 684 685 def test_avoid_unbacked_substitution(self): 686 shape_env = ShapeEnv() 687 i0 = shape_env.create_unbacked_symint() 688 _constrain_range_for_size(i0) 689 i1 = shape_env.create_unbacked_symint() 690 _constrain_range_for_size(i1) 691 self.assertTrue(expect_true(i0 == 10 - i1)) 692 self.assertExpectedInline(str(i0), """u0""") 693 694 def test_expect_true_double_digits(self): 695 shape_env = ShapeEnv() 696 ia = [shape_env.create_unbacked_symint() for _ in range(11)] # allocate 10 697 self.assertEqual(str(ia[-1]), "u10") 698 self.assertTrue(expect_true(sum(ia) == 20)) 699 self.assertEqual(len(shape_env.deferred_runtime_asserts[ia[-1].node.expr]), 1) 700 701 def test_expect_true_refine_range(self): 702 shape_env = ShapeEnv() 703 for i, rel in enumerate( 704 [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x] 705 ): 706 with self.subTest(f"i = {i}"): 707 i0 = shape_env.create_unbacked_symint() 708 self.assertTrue(expect_true(rel(i0))) 709 self.assertTrue(statically_known_true(i0 != 3)) 710 self.assertTrue(statically_known_true(i0 != 4)) 711 self.assertFalse(statically_known_true(i0 != 5)) 712 self.assertFalse(statically_known_true(i0 != 6)) 713 self.assertTrue(statically_known_true(i0 > 4)) 714 self.assertTrue(statically_known_true(i0 >= 5)) 715 716 for i, rel in enumerate( 717 [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x] 718 ): 719 with 
self.subTest(f"i = {i}"): 720 i0 = shape_env.create_unbacked_symint() 721 self.assertTrue(expect_true(rel(i0))) 722 self.assertFalse(statically_known_true(i0 != 2)) 723 self.assertFalse(statically_known_true(i0 != 3)) 724 self.assertTrue(statically_known_true(i0 != 4)) 725 self.assertTrue(statically_known_true(i0 != 5)) 726 self.assertTrue(statically_known_true(i0 < 4)) 727 self.assertTrue(statically_known_true(i0 <= 5)) 728 729 def test_guard_refine_range(self): 730 shape_env = ShapeEnv() 731 for i, rel in enumerate( 732 [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x] 733 ): 734 with self.subTest(f"i = {i}"): 735 i0 = create_symint(shape_env, 10, duck=False) 736 self.assertTrue(bool(rel(i0))) 737 self.assertTrue(statically_known_true(i0 != 3)) 738 self.assertTrue(statically_known_true(i0 != 4)) 739 self.assertFalse(statically_known_true(i0 != 5)) 740 self.assertFalse(statically_known_true(i0 != 6)) 741 self.assertTrue(statically_known_true(i0 > 4)) 742 self.assertTrue(statically_known_true(i0 >= 5)) 743 744 for i, rel in enumerate( 745 [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x] 746 ): 747 with self.subTest(f"i = {i}"): 748 i0 = create_symint(shape_env, 2, duck=False) 749 self.assertFalse(bool(rel(i0))) 750 self.assertFalse(statically_known_true(i0 != 3)) 751 self.assertFalse(statically_known_true(i0 != 4)) 752 self.assertTrue(statically_known_true(i0 != 5)) 753 self.assertTrue(statically_known_true(i0 != 6)) 754 self.assertTrue(statically_known_true(i0 <= 4)) 755 self.assertTrue(statically_known_true(i0 < 5)) 756 757 for i, rel in enumerate( 758 [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x] 759 ): 760 with self.subTest(f"i = {i}"): 761 i0 = create_symint(shape_env, 2, duck=False) 762 self.assertTrue(bool(rel(i0))) 763 self.assertFalse(statically_known_true(i0 != 2)) 764 self.assertFalse(statically_known_true(i0 != 3)) 765 self.assertTrue(statically_known_true(i0 != 4)) 766 
self.assertTrue(statically_known_true(i0 != 5)) 767 self.assertTrue(statically_known_true(i0 < 4)) 768 self.assertTrue(statically_known_true(i0 <= 3)) 769 770 for i, rel in enumerate( 771 [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x] 772 ): 773 with self.subTest(f"i = {i}"): 774 i0 = create_symint(shape_env, 10, duck=False) 775 self.assertFalse(bool(rel(i0))) 776 self.assertTrue(statically_known_true(i0 != 2)) 777 self.assertTrue(statically_known_true(i0 != 3)) 778 self.assertFalse(statically_known_true(i0 != 4)) 779 self.assertFalse(statically_known_true(i0 != 5)) 780 self.assertTrue(statically_known_true(i0 >= 4)) 781 self.assertTrue(statically_known_true(i0 > 3)) 782 783 def test_mul_int_oo_nan(self): 784 shape_env = ShapeEnv() 785 s0 = create_symint(shape_env, 5, duck=False) 786 s1 = create_symint(shape_env, 6, duck=False) 787 s2 = create_symint(shape_env, 5, duck=False) 788 bool(s0 * (s1 // s0) == s2) 789 790 def test_non_overlapping_and_dense(self): 791 shape_env = ShapeEnv() 792 a0 = create_symint(shape_env, 5) 793 r = torch.empty_strided((a0, 7), (1, a0), device="meta") 794 self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r)) 795 796 def test_non_overlapping_and_dense_unbacked(self): 797 shape_env = ShapeEnv() 798 u0 = shape_env.create_unbacked_symint() 799 torch._check_is_size(u0) 800 cf = torch.ops.aten.is_non_overlapping_and_dense.default 801 802 self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 2, 2, 1), 1) 803 self.assertEqual(IsNonOverlappingAndDenseIndicator(2, u0.node.expr, 1, 2), 1) 804 self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta"))) 805 self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta"))) 806 807 self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 1), 1) 808 self.assertEqual(IsNonOverlappingAndDenseIndicator(1, u0.node.expr), 1) 809 self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta"))) 810 
self.assertTrue(cf(torch.empty_strided((1,), (u0,), device="meta"))) 811 812 Max = torch.sym_max 813 # NB: This only works because we're able to determine this tensor is 814 # contiguous. transpose(0, 1) makes it stop working 815 self.assertTrue( 816 cf( 817 torch.empty_strided( 818 (2, 3, 1, u0), 819 (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1), 820 device="meta", 821 ) 822 ) 823 ) 824 825 def test_numpy_sym_max(self): 826 self.assertEqual(torch.sym_max(np.int64(10), 12), 12) 827 self.assertEqual(torch.sym_max(np.int64(12), 10), 12) 828 self.assertEqual(torch.sym_max(np.int64(10), 12.5), 12.5) 829 self.assertEqual(torch.sym_max(np.int64(14), 12.5), 14.0) 830 self.assertEqual(torch.sym_max(np.float64(14.0), 12), 14.0) 831 self.assertEqual(torch.sym_max(np.float64(14.0), 16), 16.0) 832 833 def test_numpy_sym_min(self): 834 self.assertEqual(torch.sym_min(np.int64(10), 12), 10) 835 self.assertEqual(torch.sym_min(np.int64(12), 10), 10) 836 self.assertEqual(torch.sym_min(np.int64(10), 12.5), 10.0) 837 self.assertEqual(torch.sym_min(np.int64(14), 12.5), 12.5) 838 self.assertEqual(torch.sym_min(np.float64(14.0), 12), 12.0) 839 self.assertEqual(torch.sym_min(np.float64(14.0), 16), 14.0) 840 841 def test_debug_has_internal_overlap_unbacked(self): 842 shape_env = ShapeEnv() 843 u0 = shape_env.create_unbacked_symint() 844 torch._check_is_size(u0) 845 cf = torch._debug_has_internal_overlap 846 self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0) 847 self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0) 848 self.assertEqual(cf(torch.empty_strided((u0,), (1,), device="meta")), 0) 849 self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 0) 850 Max = torch.sym_max 851 self.assertEqual( 852 cf( 853 torch.empty_strided( 854 (2, 3, 1, u0), 855 (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1), 856 device="meta", 857 ) 858 ), 859 0, 860 ) 861 862 # Wobbling these to zero is OK too 863 self.assertEqual(cf(torch.empty_strided((u0, 
2), (3, 1), device="meta")), 2) 864 self.assertEqual(cf(torch.empty_strided((2, u0), (1, 3), device="meta")), 2) 865 866 def test_specialize_zero_one(self): 867 shape_env = ShapeEnv(specialize_zero_one=True) 868 a0 = create_symint(shape_env, 5) 869 assert a0 != 1 870 self.assertEqual(len(shape_env.guards), 0) 871 872 shape_env = ShapeEnv(specialize_zero_one=False) 873 a0 = create_symint(shape_env, 5) 874 assert a0 != 1 875 self.assertEqual(len(shape_env.guards), 1) 876 877 def test_duck_shape(self): 878 shape_env = ShapeEnv(duck_shape=True) 879 a0 = create_symint(shape_env, 5) 880 a1 = create_symint(shape_env, 5) 881 assert a0 == a1 882 self.assertEqual(len(shape_env.guards), 0) 883 884 shape_env = ShapeEnv(duck_shape=False) 885 a0 = create_symint(shape_env, 5) 886 a1 = create_symint(shape_env, 5) 887 assert a0 == a1 888 self.assertEqual(len(shape_env.guards), 1) 889 890 def test_int_bool(self): 891 # See https://github.com/pytorch/pytorch/issues/95981 892 shape_env = ShapeEnv(duck_shape=True) 893 a0 = create_symint(shape_env, 5) 894 assert a0 895 self.assertEqual(len(shape_env.guards), 0) 896 897 def test_symint_as_scalar(self): 898 shape_env = ShapeEnv() 899 a0 = create_symint(shape_env, 2) 900 901 sym_int_encountered = False 902 903 class TestSymInt(TorchDispatchMode): 904 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 905 assert func == torch.ops.aten.add.Tensor 906 907 nonlocal sym_int_encountered 908 # WARNING: do not do identity tests on the outer 909 # SymInt/SymFloat, they are NOT STABLE 910 sym_int_encountered = kwargs["alpha"].node is a0.node 911 kwargs["alpha"] = 0 912 return func(*args) 913 914 x = torch.rand([4, 4]) 915 with TestSymInt(): 916 y = torch.add(x, x, alpha=a0) 917 918 self.assertTrue(sym_int_encountered) 919 920 def test_deepcopy(self): 921 shape_env = ShapeEnv() 922 a0 = create_symint(shape_env, 2) 923 assert a0 < 4 924 new_shape_env = copy.deepcopy(shape_env) 925 self.assertEqual(len(new_shape_env.guards), 1) 926 927 
def test_print_readable_with_symints(self):
        """print_readable output for a symbolically traced graph.

        ``f`` is traced with ``make_fx(..., tracing_mode="symbolic")``. The
        readable module text is pinned exactly, including the ``Sym(...)``
        annotations on size nodes and the symbolic shapes of tensor nodes.
        """
        def f(a, b):
            dim0 = a.shape[0] + b.shape[0]
            dim1 = a.shape[1] + b.shape[1]
            d = a.new_empty(dim0, dim1)
            d = torch.ops.aten.native_dropout(d, 0.5, train=True)
            return d

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
        out = fx_g.print_readable(print_output=False)

        self.assertExpectedInline(
            out.strip(),
            """\
class f(torch.nn.Module):
    def forward(self, a_1: "f32[s0, s1]", b_1: "f32[s2, s1]"):
        # No stacktrace found for following nodes
        sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(a_1, 0)
        sym_size_int_1: "Sym(s2)" = torch.ops.aten.sym_size.int(b_1, 0)
        add: "Sym(s0 + s2)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None
        sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(a_1, 1)
        sym_size_int_3: "Sym(s1)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None
        add_1: "Sym(2*s1)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None
        new_empty: "f32[s0 + s2, 2*s1]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None
        native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
        getitem: "f32[s0 + s2, 2*s1]" = native_dropout[0]
        getitem_1: "b8[s0 + s2, 2*s1]" = native_dropout[1]; native_dropout = None
        return (getitem, getitem_1)""",  # noqa: B950
        )

    def test_statically_known_true(self):
        """statically_known_true accepts only facts provable without hints.

        The three symints are created from 2, 3 and 4. An expression that
        merely happens to hold for those values must still return False.
        None of the queries may record a guard on the ShapeEnv.
        """
        shape_env = ShapeEnv()
        s2, s3, s4 = (create_symint(shape_env, i) for i in range(2, 5))

        # Statically known true
        self.assertTrue(statically_known_true(True))
        self.assertTrue(statically_known_true(s2 == s2))
        self.assertTrue(statically_known_true(s2 * s3 > s3))
        self.assertTrue(statically_known_true(s3 * s4 > s4))
        self.assertTrue(statically_known_true((s3 + s3) % 2 == 0))

        # Statically known false
        self.assertFalse(statically_known_true(False))
        self.assertFalse(statically_known_true(s3 * s4 <= s4))
        self.assertFalse(statically_known_true((s3 + s3) % 2 == 1))

        # True for hints, but not known statically
        self.assertFalse(statically_known_true(s2 + s2 == s4))
        self.assertFalse(statically_known_true(s4 % s2 == 0))
        self.assertFalse(statically_known_true(s2 != s3))
        self.assertFalse(statically_known_true(s3 * s4 > s2))

        # False for hints, but not known statically
        self.assertFalse(statically_known_true(s2 == s3))
        self.assertFalse(statically_known_true(s2 > s3))
        self.assertFalse(statically_known_true(s3 + s3 == s4))

        # No guards should be generated
        self.assertEqual(len(shape_env.guards), 0)

    def test_ephemeral_source_simplification(self):
        from torch._dynamo.source import EphemeralSource

        # For full robustness, ensure the ephemeral source symbols are simplified out regardless
        # of construction order or check order.
992 for construct_ephemeral_first, x_first_in_check in itertools.product( 993 [False, True], [False, True] 994 ): 995 shape_env = ShapeEnv() 996 shape = (5, 10) 997 dynamic_dims = [DimDynamic.DYNAMIC for _ in shape] 998 x = create_symbolic_tensor( 999 "x", 1000 torch.randn(*shape), 1001 shape_env, 1002 source=(EphemeralSource() if construct_ephemeral_first else None), 1003 dynamic_dims=dynamic_dims, 1004 ) 1005 y = create_symbolic_tensor( 1006 "y", 1007 torch.randn(*shape), 1008 shape_env, 1009 source=(EphemeralSource() if not construct_ephemeral_first else None), 1010 dynamic_dims=dynamic_dims, 1011 ) 1012 t_with_ephemeral = x if construct_ephemeral_first else y 1013 1014 def _get_ephemeral_source_symbols(t): 1015 return [ 1016 s.node.expr 1017 for s in itertools.chain(t.shape, t.stride(), (t.storage_offset(),)) 1018 if isinstance(s, torch.SymInt) 1019 and s.node.expr in shape_env.var_to_sources 1020 and any( 1021 source.is_ephemeral() 1022 for source in shape_env.var_to_sources[s.node.expr] 1023 ) 1024 ] 1025 1026 # these checks should simplify out the ephemeral symbols, regardless of the 1027 # ordering x == y or y == x 1028 self.assertTrue(len(_get_ephemeral_source_symbols(t_with_ephemeral)) > 0) 1029 if x_first_in_check: 1030 torch._check(x.size() == y.size()) 1031 torch._check(x.stride() == y.stride()) 1032 torch._check(x.storage_offset() == y.storage_offset()) 1033 else: 1034 torch._check(y.size() == x.size()) 1035 torch._check(y.stride() == x.stride()) 1036 torch._check(y.storage_offset() == x.storage_offset()) 1037 self.assertEqual(len(_get_ephemeral_source_symbols(t_with_ephemeral)), 0) 1038 1039 def test_ephemeral_source_unified_with_non_ephemeral_source(self): 1040 from torch._dynamo.source import EphemeralSource 1041 1042 for construct_ephemeral_first in (False, True): 1043 shape_env = ShapeEnv() 1044 shape = (5, 10) 1045 # use duck sizing here to ensure symbol reuse across x and y 1046 duck_dims = [DimDynamic.DUCK for _ in shape] 1047 x = 
create_symbolic_tensor( 1048 "x", 1049 torch.randn(*shape), 1050 shape_env, 1051 source=(EphemeralSource() if construct_ephemeral_first else None), 1052 dynamic_dims=duck_dims, 1053 ) 1054 y = create_symbolic_tensor( 1055 "y", 1056 torch.randn(*shape), 1057 shape_env, 1058 source=(EphemeralSource() if not construct_ephemeral_first else None), 1059 dynamic_dims=duck_dims, 1060 ) 1061 1062 # regardless of construction order, non-ephemeral sources should be preferred 1063 # first in the var_to_sources list for potential guarding later on 1064 for source_list in shape_env.var_to_sources.values(): 1065 self.assertFalse(source_list[0].is_ephemeral()) 1066 1067 self.assertEqual(x.size(), y.size()) 1068 self.assertEqual(x.stride(), y.stride()) 1069 self.assertEqual(x.storage_offset(), y.storage_offset()) 1070 1071 1072@skipIfTorchDynamo( 1073 "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)" 1074) 1075class TestSymNumberMagicMethods(TestCase): 1076 def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn): 1077 with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn): 1078 return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn) 1079 1080 def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn): 1081 # Helper function 1082 # NB: don't use one as that will get specialized 1083 # TODO: We don't have to circuitously create the float, can just 1084 # create a symfloat directly 1085 seed_node = (create_symint(shape_env, 2) / 2.0).node 1086 bool_seed_node = (create_symint(shape_env, 2) == 2).node 1087 1088 def get_sym_inp(inp): 1089 # NB: this must come before int 1090 if isinstance(inp, bool): 1091 return torch.SymBool(to_node(bool_seed_node, inp)) 1092 elif isinstance(inp, int): 1093 return torch.SymInt(to_node(seed_node, inp)) 1094 else: 1095 return torch.SymFloat(to_node(seed_node, inp)) 1096 1097 if fn == "float_pow": 1098 if inp1 < 0: 1099 return 1100 1101 if fn == "pow_by_natural": 1102 if 
isinstance(inp1, float) or isinstance(inp2, float): 1103 return 1104 if inp2 < 0: 1105 return 1106 1107 def maybe_xfail(inp1, inp2): 1108 if fn == "sym_sqrt" and inp1 < 0: 1109 # ValueError: math domain error 1110 return self.assertRaises((ValueError,)) 1111 elif ( 1112 fn in ("float_truediv", "int_truediv", "int_floordiv", "mod") 1113 and inp2 == 0 1114 ): 1115 # ZeroDivisionError: division by zero 1116 return self.assertRaises((ZeroDivisionError,)) 1117 elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0: 1118 # ZeroDivisionError: 0.0 cannot be raised to a negative power 1119 return self.assertRaises((ZeroDivisionError,)) 1120 elif ( 1121 # TODO: dear catastrophe waitress, 1122 # this doesn't work 1123 fn in ["float_pow", "pow_by_natural"] 1124 and inp1 < 0 1125 and ( 1126 type(inp1) is (SymInt, SymFloat) or type(inp2) is (SymInt, SymFloat) 1127 ) 1128 and (type(inp1) is (SymFloat, float) or type(inp2) is (SymFloat, float)) 1129 ): 1130 # Complex result, which we do not support: 1131 # TypeError: Cannot convert complex to float 1132 return self.assertRaises((RuntimeError,)) 1133 elif fn in ("lshift", "rshift") and not ( 1134 isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int)) 1135 ): 1136 # TypeError: unsupported operand type(s) 1137 return self.assertRaises((TypeError,)) 1138 elif fn in ("lshift", "rshift") and inp2 < 0: 1139 # ValueError: math domain error 1140 return self.assertRaises((ValueError,)) 1141 else: 1142 return contextlib.nullcontext() 1143 1144 lambda_apply = method_to_operator(fn) 1145 1146 def guard_fn(v): 1147 if type(v) in (SymBool, bool): 1148 return guard_bool(v) 1149 elif type(v) in (SymFloat, float): 1150 return guard_float(v) 1151 else: # SymInt, int 1152 return guard_int(v) 1153 1154 # Get reference result 1155 with maybe_xfail(inp1, inp2): 1156 if is_unary_fn: 1157 ref_out = lambda_apply(inp1) 1158 else: 1159 ref_out = lambda_apply(inp1, inp2) 1160 1161 # Symified first arg 1162 sym_inp1 = 
get_sym_inp(inp1) 1163 with maybe_xfail(sym_inp1, inp2): 1164 if is_unary_fn: 1165 out = lambda_apply(sym_inp1) 1166 else: 1167 out = lambda_apply(sym_inp1, inp2) 1168 if fn not in sym_node.alternate_impl_if_hinted_methods: 1169 self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool))) 1170 out = guard_fn(out) 1171 self.assertEqual(out, ref_out) 1172 1173 if is_unary_fn: 1174 return 1175 1176 # Symified second arg 1177 sym_inp2 = get_sym_inp(inp2) 1178 with maybe_xfail(inp1, sym_inp2): 1179 out = lambda_apply(inp1, sym_inp2) 1180 if fn not in sym_node.alternate_impl_if_hinted_methods: 1181 self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool))) 1182 out = guard_fn(out) 1183 self.assertEqual(out, ref_out) 1184 1185 # Symified both args 1186 with maybe_xfail(sym_inp1, sym_inp2): 1187 out = lambda_apply(sym_inp1, sym_inp2) 1188 if fn not in sym_node.alternate_impl_if_hinted_methods: 1189 self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool))) 1190 out = guard_fn(out) 1191 self.assertEqual(out, ref_out) 1192 1193 @parametrize("fn", list(sym_node.magic_methods.keys())) 1194 def test_bool_method(self, fn): 1195 # sym_ite has its own tests 1196 if fn not in sym_node.bool_magic_methods or fn == "sym_ite": 1197 self.skipTest(f"{fn} is non-bool") 1198 1199 is_unary_fn = fn in sym_node.unary_methods 1200 shape_env = ShapeEnv() 1201 self._do_test(fn, True, False, shape_env, is_unary_fn) 1202 1203 @parametrize("fn", list(sym_node.magic_methods.keys())) 1204 @parametrize("first_type", ["int", "float"]) 1205 @parametrize("second_type", ["int", "float"]) 1206 def test_method(self, fn, first_type, second_type): 1207 if first_type == "float": 1208 # TODO: Hmm, this looks like we skip all floats 1209 self.skipTest(f"{fn} is not a float magic method") 1210 1211 if ( 1212 first_type == "int" or second_type == "int" 1213 ) and fn in sym_node.only_float_magic_methods: 1214 self.skipTest(f"{fn} is not an int method") 1215 1216 if second_type == "float" and fn in ["mod"]: 
1217 self.skipTest(f"{fn} only handles int") 1218 1219 is_unary_fn = fn in sym_node.unary_methods or fn == "round" 1220 # Second argument is ignored for unary function. So only run for one type 1221 if is_unary_fn and second_type == "float": 1222 self.skipTest(f"{fn} is unary and already tested") 1223 1224 if fn in sym_node.bool_magic_methods: 1225 self.skipTest(f"{fn} is bool") 1226 1227 # Only floats here since these will be converted to int if necessary. 1228 # We also ignore complex and bool. 1229 values = ( 1230 0.0, 1231 1.0, 1232 0.5 if fn in ("sym_acos", "sym_asin") else 2.5, # avoid math domain error 1233 ) 1234 1235 neg_values = tuple(-x for x in values) 1236 1237 for inp1, inp2 in itertools.chain( 1238 itertools.product(values, values), 1239 itertools.product(values, neg_values), 1240 itertools.product(neg_values, values), 1241 itertools.product(neg_values, neg_values), 1242 ): 1243 if first_type == "int": 1244 inp1 = int(inp1) 1245 if second_type == "int": 1246 inp2 = int(inp2) 1247 1248 shape_env = ShapeEnv() 1249 1250 self._do_test(fn, inp1, inp2, shape_env, is_unary_fn) 1251 1252 def get_constant_bool(self, val): 1253 return SymBool(torch._C._get_constant_bool_symnode(val)) 1254 1255 @unittest.expectedFailure 1256 def test_symint_hashing(self): 1257 shape_env = ShapeEnv() 1258 hash(create_symint(shape_env, 3)) 1259 1260 def test_symnode_hashing(self): 1261 shape_env = ShapeEnv() 1262 1263 # These all trigger specialization when hashed 1264 hash(create_symbool(shape_env, True)) 1265 # We should be passing in float here, but create_symbol currently 1266 # only supports int 1267 hash(create_symfloat(shape_env, 3.0)) 1268 1269 # NestedInt (SymInt), constant SymBool, SymNode are hashable 1270 j1 = torch._C._get_nested_int(1, 1) 1271 j1_copy = torch._C._get_nested_int(1, 1) 1272 j2 = torch._C._get_nested_int(2, 1) 1273 t = self.get_constant_bool(True) 1274 t_copy = self.get_constant_bool(True) 1275 f = self.get_constant_bool(False) 1276 n = 
create_symint(shape_env, 3).node 1277 m = self.get_constant_bool(True).node 1278 1279 self.assertIs(j1 == j1_copy, True) 1280 self.assertEqual(hash(j1), hash(j1_copy)) 1281 self.assertIs(j1 == j2, False) 1282 self.assertNotEqual(hash(j1), hash(j2)) 1283 self.assertIs(t == t_copy, True) 1284 self.assertEqual(hash(t), hash(t_copy)) 1285 self.assertIs(t == f, False) 1286 self.assertNotEqual(hash(t), hash(f)) 1287 1288 hash(n) 1289 hash(m) 1290 1291 def test_symint_deepcopy(self): 1292 shape_env = ShapeEnv() 1293 1294 symnodes = (torch._C._get_nested_int(1, 1),) 1295 deepcopied_symnodes = copy.deepcopy(symnodes) 1296 self.assertEqual(symnodes, deepcopied_symnodes) 1297 1298 def test_non_symbolic_symnode(self): 1299 j1 = torch._C._get_nested_int(1, 1) 1300 j2 = torch._C._get_nested_int(1, 1) 1301 j3 = torch._C._get_nested_int(3, 1) 1302 1303 self.assertIsInstance(j1, torch.SymInt) 1304 self.assertNotIsInstance(j1, int) 1305 1306 with self.assertRaisesRegex( 1307 RuntimeError, "add not supported by NestedIntSymNode" 1308 ): 1309 j1 + 3 1310 1311 self.assertFalse(j1 == 3) 1312 with self.assertRaisesRegex(RuntimeError, "indeterminate"): 1313 self.assertFalse(3 >= j2) 1314 1315 self.assertIs(j1 == j1, True) 1316 self.assertIs(j1 == j2, True) 1317 self.assertIs(j1 == j3, False) 1318 self.assertIs(j1 != j3, True) 1319 self.assertIs(j1 != j2, False) 1320 1321 x = self.get_constant_bool(True) 1322 # 1323 # Unary 1324 # 1325 # op(constant SymBool) 1326 self.assertIs(x.__sym_not__(), False) 1327 1328 # 1329 # Binary 1330 # 1331 # op(constant SymBool, bool) 1332 # op(constant SymBool, constant SymBool) 1333 # op(bool, constant SymBool) 1334 self.assertIs(operator.and_(x, True), True) 1335 self.assertIs(operator.and_(x, x), True) 1336 self.assertIs(operator.and_(True, x), True) 1337 1338 # op(symbolic SymBool, constant Symbool) 1339 # op(constant SymBool, symbolic Symbool) 1340 shape_env = ShapeEnv() 1341 a = create_symint(shape_env, 2) 1342 b = create_symint(shape_env, 2) 1343 c = 
a == b # symbolic SymBool 1344 d = self.get_constant_bool(True) 1345 e = operator.and_(c, d) 1346 f = operator.and_(d, c) 1347 self.assertTrue(is_symbolic(e)) 1348 self.assertTrue(is_symbolic(f)) 1349 self.assertIs(e.node.guard_bool("", 0), True) 1350 self.assertIs(f.node.guard_bool("", 0), True) 1351 1352 # Comparing sizes 1353 sz1 = torch.Size([j1, j1, j1]) 1354 sz2 = torch.Size([j1, j1, j1]) 1355 self.assertIs(sz1 == sz2, True) 1356 1357 sz1 = torch.Size([3, j1, 4]) 1358 sz2 = torch.Size([3, j2, 4]) 1359 self.assertIs(sz1 == sz2, True) 1360 self.assertIs(sz1 != sz2, False) 1361 1362 def test_stride_symnode(self): 1363 from torch._subclasses.fake_tensor import FakeTensorMode 1364 1365 shape_env = ShapeEnv() 1366 1367 def _create_symbolic_tensor(x, dynamic_sizes, dynamic_strides): 1368 with FakeTensorMode(shape_env=shape_env) as fake_mode: 1369 return fake_mode.from_tensor( 1370 x, 1371 symbolic_context=StatelessSymbolicContext( 1372 dynamic_sizes=dynamic_sizes, 1373 dynamic_strides=dynamic_strides, 1374 ), 1375 ) 1376 1377 # check everything static 1378 t = _create_symbolic_tensor( 1379 x=torch.ones(3, 6), 1380 dynamic_sizes=[ 1381 DimDynamic.STATIC, 1382 DimDynamic.STATIC, 1383 ], 1384 dynamic_strides=[ 1385 DimDynamic.INFER_STRIDE, 1386 DimDynamic.INFER_STRIDE, 1387 ], 1388 ) 1389 self.assertTrue(all(isinstance(size, int) for size in t.size())) 1390 self.assertTrue(all(isinstance(stride, int) for stride in t.stride())) 1391 1392 # check dynamic size but static dims 1393 t = _create_symbolic_tensor( 1394 x=torch.ones(3, 6), 1395 dynamic_sizes=[ 1396 DimDynamic.DYNAMIC, 1397 DimDynamic.DYNAMIC, 1398 ], 1399 dynamic_strides=[ 1400 DimDynamic.INFER_STRIDE, 1401 DimDynamic.INFER_STRIDE, 1402 ], 1403 ) 1404 # Expect stride to be inferred 1405 s0, s1 = t.size() 1406 s2, s3 = t.stride() 1407 self.assertTrue(isinstance(s0, torch.SymInt)) 1408 self.assertTrue(isinstance(s1, torch.SymInt)) 1409 self.assertTrue(isinstance(s2, torch.SymInt)) 1410 self.assertTrue(s1 == s2) 
1411 self.assertEqual(s3, 1) 1412 1413 # Check dynamic stride but static dims 1414 t = _create_symbolic_tensor( 1415 x=torch.ones(3, 6), 1416 dynamic_sizes=[ 1417 DimDynamic.STATIC, 1418 DimDynamic.STATIC, 1419 ], 1420 dynamic_strides=[ 1421 DimDynamic.DYNAMIC, 1422 DimDynamic.INFER_STRIDE, 1423 ], 1424 ) 1425 s0, s1 = t.size() 1426 s2, s3 = t.stride() 1427 self.assertTrue(isinstance(s0, int)) 1428 self.assertTrue(isinstance(s1, int)) 1429 self.assertTrue(isinstance(s2, torch.SymInt)) 1430 self.assertTrue(isinstance(s3, int)) 1431 1432 # Check dynamic sizes and dims, and ensure different symbol 1433 t = _create_symbolic_tensor( 1434 x=torch.ones(3, 6), 1435 dynamic_sizes=[ 1436 DimDynamic.DYNAMIC, 1437 DimDynamic.DYNAMIC, 1438 ], 1439 dynamic_strides=[ 1440 DimDynamic.DYNAMIC, 1441 DimDynamic.INFER_STRIDE, 1442 ], 1443 ) 1444 s0, s1 = t.size() 1445 s2, s3 = t.stride() 1446 self.assertTrue(isinstance(s0, torch.SymInt)) 1447 self.assertTrue(isinstance(s1, torch.SymInt)) 1448 self.assertTrue(isinstance(s2, torch.SymInt)) 1449 self.assertTrue(isinstance(s3, int)) 1450 self.assertTrue(str(s1.node.expr) != str(s2.node.expr)) 1451 1452 1453instantiate_parametrized_tests(TestSymNumberMagicMethods) 1454 1455 1456class TestFloorDiv(TestCase): 1457 @staticmethod 1458 def python_floordiv(x, y): 1459 return x // y 1460 1461 @staticmethod 1462 def torch_floordiv(x, y): 1463 # Note: we fully evaluate here since FloorDiv might not always do 1464 # that. 
1465 shape_env = ShapeEnv() 1466 return shape_env.evaluate_expr(FloorDiv(x, y)) 1467 1468 @staticmethod 1469 def yield_test_cases(values, negate=True): 1470 for x, y in values: 1471 yield (x, y) 1472 if negate: 1473 yield (-x, y) 1474 yield (x, -y) 1475 yield (-x, -y) 1476 1477 def test_floordiv_float_int(self): 1478 values = ((7, 2),) 1479 1480 for x, y in TestFloorDiv.yield_test_cases(values): 1481 self.assertEqual( 1482 TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y) 1483 ) 1484 1485 def test_floordiv_div_by_one(self): 1486 values = ((2, 1),) 1487 1488 for x, y in TestFloorDiv.yield_test_cases(values): 1489 self.assertEqual( 1490 TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y) 1491 ) 1492 1493 def test_floordiv_simplify(self): 1494 # Tests how we simplify or evaluate FloorDiv without free variables 1495 shape_env = ShapeEnv() 1496 result = 21 1497 exprs = (7 * FloorDiv(6, 2),) 1498 1499 for expr in exprs: 1500 self.assertEqual(expr, result) 1501 self.assertEqual(expr.doit(deep=False), result) 1502 self.assertEqual(expr.doit(deep=True), result) 1503 self.assertEqual(sympy.simplify(expr), result) 1504 self.assertEqual(shape_env.simplify(expr), result) 1505 self.assertEqual(shape_env.evaluate_expr(expr), result) 1506 1507 def test_floordiv_assumptions(self): 1508 cases = ( 1509 sympy.Symbol("i1", integer=True), 1510 sympy.Symbol("i2", integer=True), 1511 ) 1512 1513 for base, divisor in itertools.product(cases, repeat=2): 1514 1515 def op(): 1516 return FloorDiv(base, divisor) 1517 1518 def is_complex(x): 1519 return x.is_integer is False and x.is_real is False and x.is_complex 1520 1521 if is_complex(base) or is_complex(divisor): 1522 self.assertRaisesRegex( 1523 TypeError, 1524 ( 1525 r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol'," 1526 r" expected integer or real" 1527 ), 1528 op, 1529 ) 1530 continue 1531 1532 op = op() 1533 1534 # In regular Python, x//x == 1.0 if x is a float, but FloorDiv 1535 # 
always returns an integer 1 when both args are the same object. 1536 # This even works for Symbols with no assumptions specified. 1537 if base is divisor: 1538 self.assertTrue(op.is_integer) 1539 self.assertTrue(op.is_real) 1540 elif base.is_integer and divisor.is_integer: 1541 self.assertTrue(op.is_integer) 1542 self.assertTrue(op.is_real) 1543 else: 1544 self.assertEqual(op.is_integer, None) 1545 self.assertTrue(op.is_real) 1546 1547 1548class TestDimConstraints(TestCase): 1549 def test_dim_constraints_reduce_congruences_simple(self): 1550 from sympy import Symbol 1551 1552 s = Symbol("s", positive=True, integer=True) 1553 dim_constraints = DimConstraints({}, {}, set(), {}) 1554 dim_constraints._congruences[s] = { 1555 (s / 2) % 2, 1556 (s / 2) % 8, 1557 (s / 2) % 4, 1558 s % 2, 1559 ((s / 16) + 2) % 4, 1560 } 1561 congruences = dim_constraints._reduce_congruences() 1562 self.assertEqual(congruences[s], {(s + 32) % 64}) 1563 1564 def test_dim_constraints_reduce_inequalities_simple(self): 1565 from sympy import Eq, Interval, Ne, Symbol 1566 from sympy.solvers.inequalities import reduce_inequalities 1567 1568 s = Symbol("s", positive=True, integer=True) 1569 exprs = { 1570 s >= 2, 1571 Ne(8 * s, 16), 1572 Ne(s / 2, 1), 1573 Ne(16 * s, 32), 1574 s < 16, 1575 Ne(s, 2), 1576 s / 2 < 16, 1577 s / 2 > 1, 1578 s / 2 >= 2, 1579 Ne(3 * s / 2, 3), 1580 } 1581 solution = reduce_inequalities(exprs, s).as_set() 1582 self.assertEqual(solution, Interval.Ropen(4, 16)) 1583 1584 exprs.add(Eq(s / 2, 4)) 1585 solution = reduce_inequalities(exprs, s).as_set() 1586 self.assertEqual(solution, {8}) 1587 1588 def test_dim_constraints_reduce_inequalities_error(self): 1589 from collections import defaultdict 1590 1591 from sympy import Symbol 1592 from sympy.solvers.inequalities import reduce_inequalities 1593 1594 from torch._dynamo.source import ( 1595 LocalSource, 1596 TensorProperty, 1597 TensorPropertySource, 1598 ) 1599 from torch.fx.experimental.symbolic_shapes import 
DynamicDimConstraintPrinter 1600 1601 s0 = Symbol("s0", positive=True, integer=True) 1602 exprs = { 1603 4 * s0**3 - 4 * s0**2 + s0 <= 2147483647, 1604 s0 >= 2, 1605 s0**3 <= 2147483647, 1606 s0 <= 2147483647, 1607 } 1608 answer = reduce_inequalities(exprs, s0) 1609 1610 symbol_to_source = defaultdict(list) 1611 symbol_to_source[s0].append( 1612 TensorPropertySource( 1613 base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0 1614 ) 1615 ) 1616 dcp = DynamicDimConstraintPrinter(symbol_to_source, {}) 1617 with self.assertRaisesRegex( 1618 AssertionError, 1619 "Unknown symbol.*created by constraints solver", 1620 ): 1621 dcp.doprint(answer) 1622 1623 def test_dim_constraints_solve_full(self): 1624 from sympy import Eq, Integer, Ne, Symbol 1625 1626 from torch._dynamo.source import ( 1627 LocalSource, 1628 TensorProperty, 1629 TensorPropertySource, 1630 ) 1631 1632 src0 = TensorPropertySource( 1633 base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0 1634 ) 1635 src2 = TensorPropertySource( 1636 base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=0 1637 ) 1638 src3 = TensorPropertySource( 1639 base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=0 1640 ) 1641 src4 = TensorPropertySource( 1642 base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=0 1643 ) 1644 1645 src1 = TensorPropertySource( 1646 base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=2 1647 ) 1648 src7 = TensorPropertySource( 1649 base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=3 1650 ) 1651 1652 src5 = TensorPropertySource( 1653 base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=1 1654 ) 1655 src8 = TensorPropertySource( 1656 base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=1 1657 ) 1658 1659 src6 = TensorPropertySource( 1660 base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=1 1661 ) 1662 src9 = TensorPropertySource( 1663 base=LocalSource(local_name="d"), 
prop=TensorProperty.SIZE, idx=1 1664 ) 1665 src10 = TensorPropertySource( 1666 base=LocalSource(local_name="e"), prop=TensorProperty.SIZE, idx=1 1667 ) 1668 1669 src11 = TensorPropertySource( 1670 base=LocalSource(local_name="f"), prop=TensorProperty.SIZE, idx=1 1671 ) 1672 src12 = TensorPropertySource( 1673 base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=2 1674 ) 1675 1676 s0 = Symbol("s0", positive=True, integer=True) 1677 s1 = Symbol("s1", positive=True, integer=True) 1678 s5 = Symbol("s5", positive=True, integer=True) 1679 s6 = Symbol("s6", positive=True, integer=True) 1680 symbol_to_source = { 1681 s0: [src0, src2, src3, src4], 1682 s1: [src1, src7], 1683 s5: [src5, src8], 1684 s6: [src6, src9, src10], 1685 } 1686 var_to_val = {s0: 8, s1: 96, s5: 22, s6: 21} 1687 marked_dynamic = {s0, s1, s5, s6} 1688 dim_constraints = DimConstraints( 1689 symbol_to_source, var_to_val, marked_dynamic, {} 1690 ) 1691 dim_constraints.add_equality(src2, s0) 1692 dim_constraints.add_equality(src3, s0) 1693 dim_constraints.add_equality(src4, s0) 1694 dim_constraints.add_equality(src7, s1) 1695 dim_constraints.add_equality(src8, s5) 1696 dim_constraints.add_equality(src9, s6) 1697 dim_constraints.add_equality(src10, s6) 1698 dim_constraints.add_equality(src11, Integer(1)) 1699 dim_constraints.add_equality(src12, Integer(3)) 1700 1701 dim_constraints.add(s1**2 <= 2147483647) 1702 dim_constraints.add(32 * s1**2 <= 2147483647) 1703 dim_constraints.add(s0 < 16) 1704 dim_constraints.add(Eq(Mod(s1, 2), 0)) 1705 dim_constraints.add(Ne(FloorDiv(s1, 2), 1)) 1706 dim_constraints.add(Ne((FloorDiv(s1, 2)) ** 2, 1)) 1707 dim_constraints.add(32 * (FloorDiv(s1, 2)) ** 2 <= 2147483647) 1708 dim_constraints.add((FloorDiv(s1, 2)) ** 2 > 1) 1709 dim_constraints.add(Ne(FloorDiv(s1, 2), 1)) 1710 dim_constraints.add( 1711 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2 1712 + 128 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) 1713 + 64 1714 <= 2147483647 1715 ) 1716 
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1)) 1717 dim_constraints.add( 1718 Ne( 1719 (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2 1720 + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) 1721 + 1, 1722 1, 1723 ) 1724 ) 1725 dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1)) 1726 dim_constraints.add( 1727 (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2 1728 + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) 1729 + 1 1730 > 1 1731 ) 1732 dim_constraints.add( 1733 128 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2 1734 + 256 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) 1735 + 128 1736 <= 2147483647 1737 ) 1738 dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1)) 1739 dim_constraints.add( 1740 Ne( 1741 (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2 1742 + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) 1743 + 1, 1744 1, 1745 ) 1746 ) 1747 dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1)) 1748 dim_constraints.add( 1749 (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2 1750 + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) 1751 + 1 1752 > 1 1753 ) 1754 dim_constraints.add( 1755 256 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1756 + 512 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) 1757 + 256 1758 <= 2147483647 1759 ) 1760 dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1)) 1761 dim_constraints.add( 1762 Ne( 1763 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1764 + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) 1765 + 1, 1766 1, 1767 ) 1768 ) 1769 dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1)) 1770 dim_constraints.add( 1771 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1772 + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) 1773 + 1 1774 > 1 1775 ) 1776 dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1 >= 3) 1777 dim_constraints.add( 1778 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1779 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1780 + 60 1781 <= 2147483647 1782 ) 1783 dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 0) 1784 
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 1) 1785 dim_constraints.add( 1786 Ne( 1787 60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1788 - 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1789 + 60 * s0, 1790 0, 1791 ) 1792 ) 1793 dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1)) 1794 dim_constraints.add( 1795 Ne( 1796 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1797 - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1798 + 1, 1799 1, 1800 ) 1801 ) 1802 dim_constraints.add( 1803 Ne( 1804 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1805 - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1806 + 1, 1807 0, 1808 ) 1809 ) 1810 dim_constraints.add( 1811 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1812 - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1813 + 1 1814 >= 0 1815 ) 1816 dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 0)) 1817 dim_constraints.add( 1818 1 1819 < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1820 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1821 + 60 1822 ) 1823 dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, -1)) 1824 dim_constraints.add( 1825 Ne( 1826 60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1827 - 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1828 + 60 * s0, 1829 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1830 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1831 + 120, 1832 ) 1833 ) 1834 dim_constraints.add( 1835 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1836 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1837 + 120 1838 > 0 1839 ) 1840 dim_constraints.add( 1841 Eq( 1842 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 * (Mod(s0, 2)) 1843 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) * Mod(s0, 2) 1844 + 60 * (Mod(s0, 2)), 1845 0, 1846 ) 1847 ) 1848 dim_constraints.add( 1849 Ne( 1850 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1851 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1852 + 120, 1853 0, 1854 ) 1855 ) 1856 dim_constraints.add( 1857 Ne( 1858 60 1859 * (FloorDiv(s0, 2)) 1860 * (FloorDiv(s0, 
(FloorDiv(s0, 2)))) 1861 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1862 - 120 1863 * FloorDiv(s0, 2) 1864 * FloorDiv(s0, (FloorDiv(s0, 2))) 1865 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1866 + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))), 1867 0, 1868 ) 1869 ) 1870 dim_constraints.add(Ne(FloorDiv(s0, 2), 1)) 1871 dim_constraints.add( 1872 Ne( 1873 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1874 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1875 + 60, 1876 0, 1877 ) 1878 ) 1879 dim_constraints.add( 1880 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1881 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1882 + 60 1883 >= 0 1884 ) 1885 dim_constraints.add( 1886 1 1887 < 60 1888 * (FloorDiv(s0, (FloorDiv(s0, 2)))) 1889 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1890 - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1891 + 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) 1892 ) 1893 dim_constraints.add(Ne(16 * s0, 32)) 1894 dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0)) 1895 dim_constraints.add(Ne(16 * s0, 32)) 1896 dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0)) 1897 dim_constraints.add(FloorDiv(s0, 2) >= 2) 1898 dim_constraints.add(Ne(FloorDiv(s0, 2), 1)) 1899 dim_constraints.add(1 < FloorDiv(s0, 2)) 1900 dim_constraints.add(Ne(s0, 2)) 1901 dim_constraints.add( 1902 60 1903 * (FloorDiv(s0, (FloorDiv(s0, 2)))) 1904 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1905 - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1906 + 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) 1907 >= 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1908 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1909 + 60 1910 ) 1911 dim_constraints.add( 1912 60 1913 * (FloorDiv(s0, 2)) 1914 * (FloorDiv(s0, (FloorDiv(s0, 2)))) 1915 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1916 - 120 1917 * FloorDiv(s0, 2) 1918 * FloorDiv(s0, (FloorDiv(s0, 2))) 1919 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1920 + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))) 1921 > 0 1922 ) 1923 
dim_constraints.add( 1924 Ne( 1925 60 1926 * (FloorDiv(s0, 2)) 1927 * (FloorDiv(s0, (FloorDiv(s0, 2)))) 1928 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1929 - 120 1930 * FloorDiv(s0, 2) 1931 * FloorDiv(s0, (FloorDiv(s0, 2))) 1932 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1933 + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))), 1934 3 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))), 1935 ) 1936 ) 1937 dim_constraints.add( 1938 Ne( 1939 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1940 - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1941 + 20, 1942 0, 1943 ) 1944 ) 1945 dim_constraints.add( 1946 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1947 - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1948 + 20 1949 >= 0 1950 ) 1951 dim_constraints.add( 1952 Ne( 1953 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1954 - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1955 + 20, 1956 20, 1957 ) 1958 ) 1959 dim_constraints.add( 1960 Ne( 1961 20 1962 * ( 1963 Mod( 1964 1, 1965 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1966 - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1967 + 1, 1968 ) 1969 ), 1970 0, 1971 ) 1972 ) 1973 dim_constraints.add( 1974 Ne( 1975 20 1976 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) 1977 * ( 1978 Mod( 1979 1, 1980 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1981 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) 1982 - 2 1983 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1984 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) 1985 + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1), 1986 ) 1987 ) 1988 - 20 1989 * Mod( 1990 1, 1991 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 1992 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) 1993 - 2 1994 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 1995 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) 1996 + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1), 1997 ), 1998 0, 1999 ) 2000 ) 2001 dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1)) 2002 dim_constraints.add( 2003 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2004 - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2005 + 1 2006 >= 
1 2007 ) 2008 dim_constraints.add( 2009 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2010 - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2011 + 20 2012 >= 0 2013 ) 2014 dim_constraints.add( 2015 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2016 - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2017 + 20 2018 >= 1 2019 ) 2020 dim_constraints.add( 2021 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2022 - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2023 + 20 2024 >= 2 2025 ) 2026 dim_constraints.add( 2027 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2028 - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2029 + 20 2030 > 1 2031 ) 2032 dim_constraints.add( 2033 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2034 - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2035 + 20 2036 < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2037 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2038 + 60 2039 ) 2040 dim_constraints.add( 2041 Ne( 2042 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2043 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2044 + 60, 2045 60, 2046 ) 2047 ) 2048 dim_constraints.add( 2049 Ne( 2050 FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 2051 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2052 - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2053 + 1, 2054 ) 2055 ) 2056 dim_constraints.add( 2057 Eq( 2058 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) 2059 * ( 2060 Mod( 2061 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2062 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) 2063 - 2 2064 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2065 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) 2066 + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1), 2067 1, 2068 ) 2069 ) 2070 - Mod( 2071 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2072 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) 2073 - 2 2074 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2075 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) 2076 + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1), 2077 1, 2078 ), 2079 0, 2080 ) 2081 ) 2082 dim_constraints.add( 2083 Ne( 2084 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2085 - 2 * 
FloorDiv((FloorDiv(s1, 2) - 1), 8) 2086 + 1, 2087 FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 2088 ) 2089 ) 2090 dim_constraints.add(Ne(8 * s0, 16)) 2091 dim_constraints.add( 2092 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2093 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2094 + 60 2095 >= (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2096 - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2097 + 1 2098 ) 2099 dim_constraints.add( 2100 60 2101 * (FloorDiv(s0, (FloorDiv(s0, 2)))) 2102 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2103 - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2104 + 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) 2105 <= 2147483647 2106 ) 2107 dim_constraints.add( 2108 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2109 - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2110 + 90 2111 <= 2147483647 2112 ) 2113 dim_constraints.add(FloorDiv(s0, 2) < 16) 2114 dim_constraints.add(FloorDiv(s0, 2) > 1) 2115 dim_constraints.add( 2116 Ne( 2117 90 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2118 - 180 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2119 + 90 * (FloorDiv(s0, 2)), 2120 0, 2121 ) 2122 ) 2123 dim_constraints.add( 2124 1 2125 < 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2126 - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2127 + 90 2128 ) 2129 dim_constraints.add( 2130 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2131 - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2132 + 1 2133 > 1 2134 ) 2135 dim_constraints.add( 2136 60 2137 * (FloorDiv(s0, (FloorDiv(s0, 2)))) 2138 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2139 - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2140 + 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) 2141 > 1 2142 ) 2143 dim_constraints.add( 2144 Ne( 2145 60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2146 - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2147 + 60 * (FloorDiv(s0, 2)), 2148 0, 2149 ) 2150 ) 2151 dim_constraints.add( 2152 90 * 
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2153 - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2154 + 90 2155 > 1 2156 ) 2157 dim_constraints.add( 2158 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2159 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2160 + 60 2161 > 1 2162 ) 2163 dim_constraints.add( 2164 Ne( 2165 60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2166 - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2167 + 60 * (FloorDiv(s0, 2)), 2168 3 * (FloorDiv(s0, 2)), 2169 ) 2170 ) 2171 dim_constraints.add( 2172 60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2173 - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2174 + 60 * (FloorDiv(s0, 2)) 2175 > 0 2176 ) 2177 dim_constraints.add( 2178 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2179 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2180 + 60 2181 > 0 2182 ) 2183 dim_constraints.add( 2184 Ne( 2185 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2186 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2187 + 120, 2188 0, 2189 ) 2190 ) 2191 dim_constraints.add( 2192 1 2193 < 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2194 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2195 + 120 2196 ) 2197 dim_constraints.add( 2198 Ne( 2199 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2200 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2201 + 120, 2202 6, 2203 ) 2204 ) 2205 dim_constraints.add( 2206 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2207 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2208 + 120 2209 > 0 2210 ) 2211 dim_constraints.add( 2212 Ne( 2213 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2214 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2215 + 120, 2216 0, 2217 ) 2218 ) 2219 dim_constraints.add( 2220 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2221 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2222 + 120 2223 <= 2147483647 2224 ) 2225 dim_constraints.add( 2226 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2227 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2228 + 120 
2229 <= 20480 2230 ) 2231 dim_constraints.add( 2232 Ne( 2233 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2234 - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2235 + 90, 2236 0, 2237 ) 2238 ) 2239 dim_constraints.add( 2240 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2241 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2242 + 120 2243 > 1 2244 ) 2245 dim_constraints.add( 2246 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2247 - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2248 + 90 2249 <= 20480 2250 ) 2251 dim_constraints.add( 2252 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2253 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2254 + 60 2255 <= 20480 2256 ) 2257 dim_constraints.add( 2258 Ne( 2259 240 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2260 - 480 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2261 + 240, 2262 0, 2263 ) 2264 ) 2265 dim_constraints.add(Eq(6 * s5, 132)) 2266 dim_constraints.add(Eq(4, FloorDiv(s0, 2))) 2267 dim_constraints.add(Eq(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 4)) 2268 dim_constraints.add( 2269 Ne( 2270 64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2271 - 128 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2272 + 64 * (FloorDiv(s0, 2)), 2273 0, 2274 ) 2275 ) 2276 dim_constraints.add( 2277 1 2278 < 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2279 - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2280 + 64 2281 ) 2282 dim_constraints.add( 2283 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2284 - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2285 + 64 2286 <= 2147483647 2287 ) 2288 dim_constraints.add( 2289 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2290 - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2291 + 64 2292 > 1 2293 ) 2294 dim_constraints.add( 2295 62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2296 - 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2297 + 62 2298 <= 2147483647 2299 ) 2300 dim_constraints.add( 2301 Ne( 2302 62 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2303 - 124 * FloorDiv(s0, 2) * 
FloorDiv((FloorDiv(s1, 2) - 1), 8) 2304 + 62 * (FloorDiv(s0, 2)), 2305 0, 2306 ) 2307 ) 2308 dim_constraints.add( 2309 1 2310 < 62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2311 - 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2312 + 62 2313 ) 2314 dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3)) 2315 dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3)) 2316 dim_constraints.add(Eq(FloorDiv(s0, 2), 4)) 2317 dim_constraints.add(Eq(4, FloorDiv(s0, 2))) 2318 dim_constraints.add(Eq(FloorDiv(s0, 2), 4)) 2319 dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 3) 2320 dim_constraints.add( 2321 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2322 - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2323 + 576 2324 <= 2147483647 2325 ) 2326 dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 0) 2327 dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 1) 2328 dim_constraints.add( 2329 Ne( 2330 64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2331 - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2332 + 576 * (FloorDiv(s0, 2)), 2333 0, 2334 ) 2335 ) 2336 dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1)) 2337 dim_constraints.add( 2338 Ne( 2339 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2340 - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2341 + 9, 2342 1, 2343 ) 2344 ) 2345 dim_constraints.add( 2346 Ne( 2347 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2348 - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2349 + 9, 2350 0, 2351 ) 2352 ) 2353 dim_constraints.add( 2354 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2355 - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2356 + 9 2357 >= 0 2358 ) 2359 dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 0)) 2360 dim_constraints.add( 2361 1 2362 < 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2363 - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2364 + 576 2365 ) 2366 dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1)) 2367 dim_constraints.add( 2368 Ne( 2369 64 * (FloorDiv(s0, 2)) * 
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2370 - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2371 + 576 * (FloorDiv(s0, 2)), 2372 256, 2373 ) 2374 ) 2375 dim_constraints.add( 2376 Eq( 2377 64 2378 * ( 2379 Mod( 2380 (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2381 - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2382 + 9 * (FloorDiv(s0, 2)), 2383 4, 2384 ) 2385 ), 2386 0, 2387 ) 2388 ) 2389 dim_constraints.add( 2390 Eq( 2391 FloorDiv(s0, 2), 2392 FloorDiv( 2393 ( 2394 (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2395 - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2396 + 9 * (FloorDiv(s0, 2)) 2397 ), 2398 4, 2399 ), 2400 ) 2401 ) 2402 dim_constraints.add( 2403 Eq( 2404 FloorDiv( 2405 ( 2406 (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2407 - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2408 + 9 * (FloorDiv(s0, 2)) 2409 ), 2410 4, 2411 ), 2412 FloorDiv(s0, 2), 2413 ) 2414 ) 2415 dim_constraints.add( 2416 Ne(64 * (Mod(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 4)), 0) 2417 ) 2418 dim_constraints.add( 2419 Eq( 2420 64 2421 * ( 2422 Mod( 2423 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2424 - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2425 + 1, 2426 4, 2427 ) 2428 ), 2429 0, 2430 ) 2431 ) 2432 dim_constraints.add( 2433 64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2434 - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2435 + 576 * (FloorDiv(s0, 2)) 2436 > 0 2437 ) 2438 dim_constraints.add( 2439 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2440 - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2441 + 9 2442 >= 1 2443 ) 2444 dim_constraints.add( 2445 Eq( 2446 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2447 - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2448 + 576, 2449 256, 2450 ) 2451 ) 2452 dim_constraints.add( 2453 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2454 - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2455 + 540 2456 <= 2147483647 2457 ) 2458 
dim_constraints.add( 2459 Ne( 2460 60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2461 - 360 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2462 + 540 * (FloorDiv(s0, 2)), 2463 0, 2464 ) 2465 ) 2466 dim_constraints.add( 2467 1 2468 < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2469 - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2470 + 540 2471 ) 2472 dim_constraints.add( 2473 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2474 - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2475 + 9 2476 <= 2147483647 2477 ) 2478 dim_constraints.add( 2479 Ne( 2480 (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2481 - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2482 + 9 * (FloorDiv(s0, 2)), 2483 0, 2484 ) 2485 ) 2486 dim_constraints.add( 2487 1 2488 < (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2489 - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2490 + 9 2491 ) 2492 dim_constraints.add( 2493 (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2494 - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2495 + 9 2496 > 1 2497 ) 2498 dim_constraints.add( 2499 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 2500 - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8) 2501 + 540 2502 > 1 2503 ) 2504 dim_constraints.add(s0 >= 2) 2505 dim_constraints.add(s1 >= 2) 2506 dim_constraints.add(s6 >= 2) 2507 dim_constraints.add(s5 >= 2) 2508 2509 dim_constraints.solve() 2510 self.assertEqual( 2511 dim_constraints._static_results, 2512 { 2513 "L['c'].size()[0] == 8", 2514 "L['d'].size()[0] == 8", 2515 "L['a'].size()[2] == 96", 2516 "L['f'].size()[1] == 1", 2517 "L['a'].size()[3] == 96", 2518 "L['b'].size()[2] == 3", 2519 "L['b'].size()[1] == 22", 2520 "L['b'].size()[0] == 8", 2521 "L['a'].size()[1] == 22", 2522 "L['a'].size()[0] == 8", 2523 }, 2524 ) 2525 self.assertEqual( 2526 dim_constraints._dynamic_results, 2527 { 2528 "2 <= L['c'].size()[1]", 2529 "L['d'].size()[1] == L['c'].size()[1]", 2530 "L['e'].size()[1] == L['c'].size()[1]", 2531 }, 2532 ) 2533 2534 2535class 
TestGuardsExpressions(TestCase): 2536 """ 2537 Tests the guards-related methods used by the inductor FX graph cache. 2538 """ 2539 2540 def test_guards_gt_lt(self): 2541 shape_env = ShapeEnv() 2542 s0 = create_symint(shape_env, 6) 2543 s1 = create_symint(shape_env, 7) 2544 s2 = create_symint(shape_env, 5) 2545 2546 guard_int(sym_int(s0 > 5)) 2547 guard_int(sym_int(s0 < 7)) 2548 2549 guards = shape_env.produce_guards_expression([s0]) 2550 2551 self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)])) 2552 self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)])) 2553 self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s2)])) 2554 2555 def test_guards_float_print(self): 2556 shape_env = ShapeEnv() 2557 s0 = create_symint(shape_env, 3) 2558 guard_bool(2 / s0 == 2 / 3) 2559 guards = shape_env.produce_guards_expression([s0]) 2560 self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)])) 2561 2562 def test_guards_float_div(self): 2563 shape_env = ShapeEnv() 2564 s0 = create_symint(shape_env, 8) 2565 s1 = create_symint(shape_env, 7) 2566 2567 guard_int(sym_int(s0 / 2.0)) 2568 guards = shape_env.produce_guards_expression([s0]) 2569 2570 self.assertIn("ToFloat", guards) 2571 self.assertIn("FloatTrueDiv", guards) 2572 self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)])) 2573 self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)])) 2574 2575 2576if __name__ == "__main__": 2577 run_tests() 2578