xref: /aosp_15_r20/external/pytorch/test/torch_np/numpy_tests/core/test_scalar_ctors.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2
3"""
4Test the scalar constructors, which also do type-coercion
5"""
6import functools
7from unittest import skipIf as skipif
8
9import pytest
10
11from torch.testing._internal.common_utils import (
12    instantiate_parametrized_tests,
13    parametrize,
14    run_tests,
15    subtest,
16    TEST_WITH_TORCHDYNAMO,
17    TestCase,
18    xpassIfTorchDynamo,
19)
20
21
22if TEST_WITH_TORCHDYNAMO:
23    import numpy as np
24    from numpy.testing import assert_almost_equal, assert_equal
25else:
26    import torch._numpy as np
27    from torch._numpy.testing import assert_almost_equal, assert_equal
28
29
def skip(*args, **kwargs):
    """Unconditional form of ``unittest.skipIf``: ``@skip(reason)`` always skips."""
    return skipif(True, *args, **kwargs)
31
32
class TestFromString(TestCase):
    # Scalar constructors fed string arguments, plus rejection of an unknown
    # keyword argument by np.bool_.

    @xpassIfTorchDynamo  # (reason="XXX: floats from strings")
    def test_floating(self):
        # Ticket #640, floats from string
        # Parse with every constructor first, then check each parsed value.
        parsed = [ctor("1.234") for ctor in (np.single, np.double)]
        for value in parsed:
            assert_almost_equal(value, 1.234)

    @xpassIfTorchDynamo  # (reason="XXX: floats from strings")
    def test_floating_overflow(self):
        """Strings containing an unrepresentable float overflow"""
        # A literal too large for any float width must come out as +inf for
        # half/single/double, and its negation as -inf.
        for text, expected in (("1e10000", np.inf), ("-1e10000", -np.inf)):
            for ctor in (np.half, np.single, np.double):
                assert_equal(ctor(text), expected)

    def test_bool(self):
        # An unrecognised keyword argument must raise TypeError.
        with pytest.raises(TypeError):
            np.bool_(False, garbage=True)
62
63
64class TestFromInt(TestCase):
65    def test_intp(self):
66        # Ticket #99
67        assert_equal(1024, np.intp(1024))
68
69    def test_uint64_from_negative(self):
70        # NumPy test was asserting a DeprecationWarning
71        assert_equal(np.uint8(-2), np.uint8(254))
72
73
74int_types = [
75    subtest(np.byte, name="np_byte"),
76    subtest(np.short, name="np_short"),
77    subtest(np.intc, name="np_intc"),
78    subtest(np.int_, name="np_int_"),
79    subtest(np.longlong, name="np_longlong"),
80]
81uint_types = [np.ubyte]
82float_types = [np.half, np.single, np.double]
83cfloat_types = [np.csingle, np.cdouble]
84
85
@instantiate_parametrized_tests
class TestArrayFromScalar(TestCase):
    """gh-15467"""

    def _do_test(self, t1, t2):
        # Wrap a scalar of type ``t1`` with ``np.array`` and then ``np.asarray``.
        # The result's dtype scalar type must be exactly ``t2`` when a dtype is
        # requested, and exactly ``t1`` (type preserved) when ``t2`` is None.
        expected = t1 if t2 is None else t2
        scalar = t1(2)
        for to_array in (np.array, np.asarray):
            converted = to_array(scalar, dtype=t2)
            assert converted.dtype.type is expected

    @parametrize("t1", int_types + uint_types)
    @parametrize("t2", int_types + uint_types + [None])
    def test_integers(self, t1, t2):
        self._do_test(t1, t2)

    @parametrize("t1", float_types)
    @parametrize("t2", float_types + [None])
    def test_reals(self, t1, t2):
        self._do_test(t1, t2)

    @parametrize("t1", cfloat_types)
    @parametrize("t2", cfloat_types + [None])
    def test_complex(self, t1, t2):
        self._do_test(t1, t2)
119
120
if __name__ == "__main__":
    # When executed directly, hand off to run_tests from
    # torch.testing._internal.common_utils to run this module's tests.
    run_tests()
123