xref: /aosp_15_r20/external/pytorch/test/torch_np/numpy_tests/core/test_scalarinherit.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2
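"""Test printing of scalar subclasses.

Covers ``np.float64`` subclasses that also inherit from plain Python classes,
with the extra base placed either before or after ``np.float64``.
"""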
6import functools
7from unittest import skipIf as skipif
8
9import pytest
10
11import torch._numpy as np
12from torch._numpy.testing import assert_
13from torch.testing._internal.common_utils import run_tests, TestCase
14
15
def skip(*args, **kwargs):
    """Unconditionally skip a test or test class.

    Shorthand for ``skipif(True, ...)``: every positional and keyword
    argument (typically ``reason=``) is forwarded to ``skipif`` after the
    always-true condition.
    """
    return skipif(True, *args, **kwargs)
17
18
class A:
    """Plain Python class with no behavior, used as an extra base next to ``np.float64``."""
21
22
class B(A, np.float64):
    """``np.float64`` subclass whose first base is the plain class ``A``."""
25
26
class C(B):
    """Single-inheritance subclass of ``B`` (one level deeper in the hierarchy)."""
29
30
class D(C, B):
    """Subclass listing both ``C`` and ``C``'s own base ``B`` as direct bases."""
33
34
class B0(np.float64, A):
    """``np.float64`` subclass with the plain class ``A`` as its *second* base.

    This is the reverse base order of ``B``.
    """
37
38
class C0(B0):
    """Single-inheritance subclass of ``B0``."""
41
42
class HasNew:
    """Class whose ``__new__`` returns a ``(cls, args, kwargs)`` tuple instead of an instance.

    It is used as the second base of ``B1``. This checks that
    ``np.float64``'s constructor, not this one, is the one that runs.
    """

    def __new__(cls, *positional, **keyword):
        # Echo the call back as a tuple; no HasNew instance is ever created.
        return (cls, positional, keyword)
46
47
class B1(np.float64, HasNew):
    """``np.float64`` subclass with ``HasNew`` as its second base.

    Because ``np.float64`` is listed first, its ``__new__`` is expected to take
    priority over ``HasNew.__new__`` (exercised by ``test_gh_15395``).
    """
50
51
@skip(reason="scalar repr: numpy plans to make it more explicit")
class TestInherit(TestCase):
    """Check ``str()`` of ``np.float64`` subclasses that also have plain Python bases.

    The whole class is skipped; see the decorator's reason.
    """

    def test_init(self):
        """Subclasses with ``A`` before ``np.float64`` stringify like the float value."""
        cases = ((B, 1.0, "1.0"), (C, 2.0, "2.0"), (D, 3.0, "3.0"))
        for klass, value, expected in cases:
            assert_(str(klass(value)) == expected)

    def test_init2(self):
        """Subclasses with ``np.float64`` before ``A`` stringify like the float value."""
        for klass, value, expected in ((B0, 1.0, "1.0"), (C0, 2.0, "2.0")):
            assert_(str(klass(value)) == expected)

    def test_gh_15395(self):
        """``HasNew.__new__`` must not take over construction of ``B1`` (gh-15395)."""
        # HasNew is the second base, so `np.float64` should have priority
        assert_(str(B1(1.0)) == "1.0")

        # An extra positional argument must be rejected with TypeError;
        # previously this call caused a RecursionError.
        with pytest.raises(TypeError):
            B1(1.0, 2.0)
76
77
if __name__ == "__main__":
    # When executed directly, run this module's tests through the shared
    # runner from torch.testing._internal.common_utils.
    run_tests()
80