1# Owner(s): ["module: dynamo"] 2 3""" Test printing of scalar types. 4 5""" 6import functools 7from unittest import skipIf as skipif 8 9import pytest 10 11import torch._numpy as np 12from torch._numpy.testing import assert_ 13from torch.testing._internal.common_utils import run_tests, TestCase 14 15 16skip = functools.partial(skipif, True) 17 18 19class A: 20 pass 21 22 23class B(A, np.float64): 24 pass 25 26 27class C(B): 28 pass 29 30 31class D(C, B): 32 pass 33 34 35class B0(np.float64, A): 36 pass 37 38 39class C0(B0): 40 pass 41 42 43class HasNew: 44 def __new__(cls, *args, **kwargs): 45 return cls, args, kwargs 46 47 48class B1(np.float64, HasNew): 49 pass 50 51 52@skip(reason="scalar repr: numpy plans to make it more explicit") 53class TestInherit(TestCase): 54 def test_init(self): 55 x = B(1.0) 56 assert_(str(x) == "1.0") 57 y = C(2.0) 58 assert_(str(y) == "2.0") 59 z = D(3.0) 60 assert_(str(z) == "3.0") 61 62 def test_init2(self): 63 x = B0(1.0) 64 assert_(str(x) == "1.0") 65 y = C0(2.0) 66 assert_(str(y) == "2.0") 67 68 def test_gh_15395(self): 69 # HasNew is the second base, so `np.float64` should have priority 70 x = B1(1.0) 71 assert_(str(x) == "1.0") 72 73 # previously caused RecursionError!? 74 with pytest.raises(TypeError): 75 B1(1.0, 2.0) 76 77 78if __name__ == "__main__": 79 run_tests() 80