xref: /aosp_15_r20/external/pytorch/test/test_numpy_interop.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: ignore-errors
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: numpy"]
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport sys
6*da0073e9SAndroid Build Coastguard Workerfrom itertools import product
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport numpy as np
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerimport torch
11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor
12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import (
13*da0073e9SAndroid Build Coastguard Worker    dtypes,
14*da0073e9SAndroid Build Coastguard Worker    instantiate_device_type_tests,
15*da0073e9SAndroid Build Coastguard Worker    onlyCPU,
16*da0073e9SAndroid Build Coastguard Worker    skipMeta,
17*da0073e9SAndroid Build Coastguard Worker)
18*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_dtype import all_types_and_complex_and
19*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker# For testing handling NumPy objects and sending tensors to / accepting
23*da0073e9SAndroid Build Coastguard Worker#   arrays from NumPy.
24*da0073e9SAndroid Build Coastguard Workerclass TestNumPyInterop(TestCase):
25*da0073e9SAndroid Build Coastguard Worker    # Note: the warning this tests for only appears once per program, so
26*da0073e9SAndroid Build Coastguard Worker    # other instances of this warning should be addressed to avoid
27*da0073e9SAndroid Build Coastguard Worker    # the tests depending on the order in which they're run.
28*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
29*da0073e9SAndroid Build Coastguard Worker    def test_numpy_non_writeable(self, device):
30*da0073e9SAndroid Build Coastguard Worker        arr = np.zeros(5)
31*da0073e9SAndroid Build Coastguard Worker        arr.flags["WRITEABLE"] = False
32*da0073e9SAndroid Build Coastguard Worker        self.assertWarns(UserWarning, lambda: torch.from_numpy(arr))
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
35*da0073e9SAndroid Build Coastguard Worker    def test_numpy_unresizable(self, device) -> None:
36*da0073e9SAndroid Build Coastguard Worker        x = np.zeros((2, 2))
37*da0073e9SAndroid Build Coastguard Worker        y = torch.from_numpy(x)
38*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
39*da0073e9SAndroid Build Coastguard Worker            x.resize((5, 5))
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker        z = torch.randn(5, 5)
42*da0073e9SAndroid Build Coastguard Worker        w = z.numpy()
43*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
44*da0073e9SAndroid Build Coastguard Worker            z.resize_(10, 10)
45*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
46*da0073e9SAndroid Build Coastguard Worker            w.resize((10, 10))
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
49*da0073e9SAndroid Build Coastguard Worker    def test_to_numpy(self, device) -> None:
50*da0073e9SAndroid Build Coastguard Worker        def get_castable_tensor(shape, dtype):
51*da0073e9SAndroid Build Coastguard Worker            if dtype.is_floating_point:
52*da0073e9SAndroid Build Coastguard Worker                dtype_info = torch.finfo(dtype)
53*da0073e9SAndroid Build Coastguard Worker                # can't directly use min and max, because for double, max - min
54*da0073e9SAndroid Build Coastguard Worker                # is greater than double range and sampling always gives inf.
55*da0073e9SAndroid Build Coastguard Worker                low = max(dtype_info.min, -1e10)
56*da0073e9SAndroid Build Coastguard Worker                high = min(dtype_info.max, 1e10)
57*da0073e9SAndroid Build Coastguard Worker                t = torch.empty(shape, dtype=torch.float64).uniform_(low, high)
58*da0073e9SAndroid Build Coastguard Worker            else:
59*da0073e9SAndroid Build Coastguard Worker                # can't directly use min and max, because for int64_t, max - min
60*da0073e9SAndroid Build Coastguard Worker                # is greater than int64_t range and triggers UB.
61*da0073e9SAndroid Build Coastguard Worker                low = max(torch.iinfo(dtype).min, int(-1e10))
62*da0073e9SAndroid Build Coastguard Worker                high = min(torch.iinfo(dtype).max, int(1e10))
63*da0073e9SAndroid Build Coastguard Worker                t = torch.empty(shape, dtype=torch.int64).random_(low, high)
64*da0073e9SAndroid Build Coastguard Worker            return t.to(dtype)
65*da0073e9SAndroid Build Coastguard Worker
66*da0073e9SAndroid Build Coastguard Worker        dtypes = [
67*da0073e9SAndroid Build Coastguard Worker            torch.uint8,
68*da0073e9SAndroid Build Coastguard Worker            torch.int8,
69*da0073e9SAndroid Build Coastguard Worker            torch.short,
70*da0073e9SAndroid Build Coastguard Worker            torch.int,
71*da0073e9SAndroid Build Coastguard Worker            torch.half,
72*da0073e9SAndroid Build Coastguard Worker            torch.float,
73*da0073e9SAndroid Build Coastguard Worker            torch.double,
74*da0073e9SAndroid Build Coastguard Worker            torch.long,
75*da0073e9SAndroid Build Coastguard Worker        ]
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker        for dtp in dtypes:
78*da0073e9SAndroid Build Coastguard Worker            # 1D
79*da0073e9SAndroid Build Coastguard Worker            sz = 10
80*da0073e9SAndroid Build Coastguard Worker            x = get_castable_tensor(sz, dtp)
81*da0073e9SAndroid Build Coastguard Worker            y = x.numpy()
82*da0073e9SAndroid Build Coastguard Worker            for i in range(sz):
83*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x[i], y[i])
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker            # 1D > 0 storage offset
86*da0073e9SAndroid Build Coastguard Worker            xm = get_castable_tensor(sz * 2, dtp)
87*da0073e9SAndroid Build Coastguard Worker            x = xm.narrow(0, sz - 1, sz)
88*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.storage_offset() > 0)
89*da0073e9SAndroid Build Coastguard Worker            y = x.numpy()
90*da0073e9SAndroid Build Coastguard Worker            for i in range(sz):
91*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x[i], y[i])
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker            def check2d(x, y):
94*da0073e9SAndroid Build Coastguard Worker                for i in range(sz1):
95*da0073e9SAndroid Build Coastguard Worker                    for j in range(sz2):
96*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(x[i][j], y[i][j])
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker            # empty
99*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([]).to(dtp)
100*da0073e9SAndroid Build Coastguard Worker            y = x.numpy()
101*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y.size, 0)
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker            # contiguous 2D
104*da0073e9SAndroid Build Coastguard Worker            sz1 = 3
105*da0073e9SAndroid Build Coastguard Worker            sz2 = 5
106*da0073e9SAndroid Build Coastguard Worker            x = get_castable_tensor((sz1, sz2), dtp)
107*da0073e9SAndroid Build Coastguard Worker            y = x.numpy()
108*da0073e9SAndroid Build Coastguard Worker            check2d(x, y)
109*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(y.flags["C_CONTIGUOUS"])
110*da0073e9SAndroid Build Coastguard Worker
111*da0073e9SAndroid Build Coastguard Worker            # with storage offset
112*da0073e9SAndroid Build Coastguard Worker            xm = get_castable_tensor((sz1 * 2, sz2), dtp)
113*da0073e9SAndroid Build Coastguard Worker            x = xm.narrow(0, sz1 - 1, sz1)
114*da0073e9SAndroid Build Coastguard Worker            y = x.numpy()
115*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.storage_offset() > 0)
116*da0073e9SAndroid Build Coastguard Worker            check2d(x, y)
117*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(y.flags["C_CONTIGUOUS"])
118*da0073e9SAndroid Build Coastguard Worker
119*da0073e9SAndroid Build Coastguard Worker            # non-contiguous 2D
120*da0073e9SAndroid Build Coastguard Worker            x = get_castable_tensor((sz2, sz1), dtp).t()
121*da0073e9SAndroid Build Coastguard Worker            y = x.numpy()
122*da0073e9SAndroid Build Coastguard Worker            check2d(x, y)
123*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(y.flags["C_CONTIGUOUS"])
124*da0073e9SAndroid Build Coastguard Worker
125*da0073e9SAndroid Build Coastguard Worker            # with storage offset
126*da0073e9SAndroid Build Coastguard Worker            xm = get_castable_tensor((sz2 * 2, sz1), dtp)
127*da0073e9SAndroid Build Coastguard Worker            x = xm.narrow(0, sz2 - 1, sz2).t()
128*da0073e9SAndroid Build Coastguard Worker            y = x.numpy()
129*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.storage_offset() > 0)
130*da0073e9SAndroid Build Coastguard Worker            check2d(x, y)
131*da0073e9SAndroid Build Coastguard Worker
132*da0073e9SAndroid Build Coastguard Worker            # non-contiguous 2D with holes
133*da0073e9SAndroid Build Coastguard Worker            xm = get_castable_tensor((sz2 * 2, sz1 * 2), dtp)
134*da0073e9SAndroid Build Coastguard Worker            x = xm.narrow(0, sz2 - 1, sz2).narrow(1, sz1 - 1, sz1).t()
135*da0073e9SAndroid Build Coastguard Worker            y = x.numpy()
136*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.storage_offset() > 0)
137*da0073e9SAndroid Build Coastguard Worker            check2d(x, y)
138*da0073e9SAndroid Build Coastguard Worker
139*da0073e9SAndroid Build Coastguard Worker            if dtp != torch.half:
140*da0073e9SAndroid Build Coastguard Worker                # check writeable
141*da0073e9SAndroid Build Coastguard Worker                x = get_castable_tensor((3, 4), dtp)
142*da0073e9SAndroid Build Coastguard Worker                y = x.numpy()
143*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(y.flags.writeable)
144*da0073e9SAndroid Build Coastguard Worker                y[0][1] = 3
145*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(x[0][1] == 3)
146*da0073e9SAndroid Build Coastguard Worker                y = x.t().numpy()
147*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(y.flags.writeable)
148*da0073e9SAndroid Build Coastguard Worker                y[0][1] = 3
149*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(x[0][1] == 3)
150*da0073e9SAndroid Build Coastguard Worker
151*da0073e9SAndroid Build Coastguard Worker    def test_to_numpy_bool(self, device) -> None:
152*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([True, False], dtype=torch.bool)
153*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.dtype, torch.bool)
154*da0073e9SAndroid Build Coastguard Worker
155*da0073e9SAndroid Build Coastguard Worker        y = x.numpy()
156*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.dtype, np.bool_)
157*da0073e9SAndroid Build Coastguard Worker        for i in range(len(x)):
158*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x[i], y[i])
159*da0073e9SAndroid Build Coastguard Worker
160*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([True], dtype=torch.bool)
161*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.dtype, torch.bool)
162*da0073e9SAndroid Build Coastguard Worker
163*da0073e9SAndroid Build Coastguard Worker        y = x.numpy()
164*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.dtype, np.bool_)
165*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[0], y[0])
166*da0073e9SAndroid Build Coastguard Worker
167*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("conj bit not implemented in TensorVariable yet")
168*da0073e9SAndroid Build Coastguard Worker    def test_to_numpy_force_argument(self, device) -> None:
169*da0073e9SAndroid Build Coastguard Worker        for force in [False, True]:
170*da0073e9SAndroid Build Coastguard Worker            for requires_grad in [False, True]:
171*da0073e9SAndroid Build Coastguard Worker                for sparse in [False, True]:
172*da0073e9SAndroid Build Coastguard Worker                    for conj in [False, True]:
173*da0073e9SAndroid Build Coastguard Worker                        data = [[1 + 2j, -2 + 3j], [-1 - 2j, 3 - 2j]]
174*da0073e9SAndroid Build Coastguard Worker                        x = torch.tensor(
175*da0073e9SAndroid Build Coastguard Worker                            data, requires_grad=requires_grad, device=device
176*da0073e9SAndroid Build Coastguard Worker                        )
177*da0073e9SAndroid Build Coastguard Worker                        y = x
178*da0073e9SAndroid Build Coastguard Worker                        if sparse:
179*da0073e9SAndroid Build Coastguard Worker                            if requires_grad:
180*da0073e9SAndroid Build Coastguard Worker                                continue
181*da0073e9SAndroid Build Coastguard Worker                            x = x.to_sparse()
182*da0073e9SAndroid Build Coastguard Worker                        if conj:
183*da0073e9SAndroid Build Coastguard Worker                            x = x.conj()
184*da0073e9SAndroid Build Coastguard Worker                            y = x.resolve_conj()
185*da0073e9SAndroid Build Coastguard Worker                        expect_error = (
186*da0073e9SAndroid Build Coastguard Worker                            requires_grad or sparse or conj or not device == "cpu"
187*da0073e9SAndroid Build Coastguard Worker                        )
188*da0073e9SAndroid Build Coastguard Worker                        error_msg = r"Use (t|T)ensor\..*(\.numpy\(\))?"
189*da0073e9SAndroid Build Coastguard Worker                        if not force and expect_error:
190*da0073e9SAndroid Build Coastguard Worker                            self.assertRaisesRegex(
191*da0073e9SAndroid Build Coastguard Worker                                (RuntimeError, TypeError), error_msg, lambda: x.numpy()
192*da0073e9SAndroid Build Coastguard Worker                            )
193*da0073e9SAndroid Build Coastguard Worker                            self.assertRaisesRegex(
194*da0073e9SAndroid Build Coastguard Worker                                (RuntimeError, TypeError),
195*da0073e9SAndroid Build Coastguard Worker                                error_msg,
196*da0073e9SAndroid Build Coastguard Worker                                lambda: x.numpy(force=False),
197*da0073e9SAndroid Build Coastguard Worker                            )
198*da0073e9SAndroid Build Coastguard Worker                        elif force and sparse:
199*da0073e9SAndroid Build Coastguard Worker                            self.assertRaisesRegex(
200*da0073e9SAndroid Build Coastguard Worker                                TypeError, error_msg, lambda: x.numpy(force=True)
201*da0073e9SAndroid Build Coastguard Worker                            )
202*da0073e9SAndroid Build Coastguard Worker                        else:
203*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(x.numpy(force=force), y)
204*da0073e9SAndroid Build Coastguard Worker
205*da0073e9SAndroid Build Coastguard Worker    def test_from_numpy(self, device) -> None:
206*da0073e9SAndroid Build Coastguard Worker        dtypes = [
207*da0073e9SAndroid Build Coastguard Worker            np.double,
208*da0073e9SAndroid Build Coastguard Worker            np.float64,
209*da0073e9SAndroid Build Coastguard Worker            np.float16,
210*da0073e9SAndroid Build Coastguard Worker            np.complex64,
211*da0073e9SAndroid Build Coastguard Worker            np.complex128,
212*da0073e9SAndroid Build Coastguard Worker            np.int64,
213*da0073e9SAndroid Build Coastguard Worker            np.int32,
214*da0073e9SAndroid Build Coastguard Worker            np.int16,
215*da0073e9SAndroid Build Coastguard Worker            np.int8,
216*da0073e9SAndroid Build Coastguard Worker            np.uint8,
217*da0073e9SAndroid Build Coastguard Worker            np.longlong,
218*da0073e9SAndroid Build Coastguard Worker            np.bool_,
219*da0073e9SAndroid Build Coastguard Worker        ]
220*da0073e9SAndroid Build Coastguard Worker        complex_dtypes = [
221*da0073e9SAndroid Build Coastguard Worker            np.complex64,
222*da0073e9SAndroid Build Coastguard Worker            np.complex128,
223*da0073e9SAndroid Build Coastguard Worker        ]
224*da0073e9SAndroid Build Coastguard Worker
225*da0073e9SAndroid Build Coastguard Worker        for dtype in dtypes:
226*da0073e9SAndroid Build Coastguard Worker            array = np.array([1, 2, 3, 4], dtype=dtype)
227*da0073e9SAndroid Build Coastguard Worker            tensor_from_array = torch.from_numpy(array)
228*da0073e9SAndroid Build Coastguard Worker            # TODO: change to tensor equality check once HalfTensor
229*da0073e9SAndroid Build Coastguard Worker            # implements `==`
230*da0073e9SAndroid Build Coastguard Worker            for i in range(len(array)):
231*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(tensor_from_array[i], array[i])
232*da0073e9SAndroid Build Coastguard Worker            # ufunc 'remainder' not supported for complex dtypes
233*da0073e9SAndroid Build Coastguard Worker            if dtype not in complex_dtypes:
234*da0073e9SAndroid Build Coastguard Worker                # This is a special test case for Windows
235*da0073e9SAndroid Build Coastguard Worker                # https://github.com/pytorch/pytorch/issues/22615
236*da0073e9SAndroid Build Coastguard Worker                array2 = array % 2
237*da0073e9SAndroid Build Coastguard Worker                tensor_from_array2 = torch.from_numpy(array2)
238*da0073e9SAndroid Build Coastguard Worker                for i in range(len(array2)):
239*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(tensor_from_array2[i], array2[i])
240*da0073e9SAndroid Build Coastguard Worker
241*da0073e9SAndroid Build Coastguard Worker        # Test unsupported type
242*da0073e9SAndroid Build Coastguard Worker        array = np.array(["foo", "bar"], dtype=np.dtype(np.str_))
243*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
244*da0073e9SAndroid Build Coastguard Worker            tensor_from_array = torch.from_numpy(array)
245*da0073e9SAndroid Build Coastguard Worker
246*da0073e9SAndroid Build Coastguard Worker        # check storage offset
247*da0073e9SAndroid Build Coastguard Worker        x = np.linspace(1, 125, 125)
248*da0073e9SAndroid Build Coastguard Worker        x.shape = (5, 5, 5)
249*da0073e9SAndroid Build Coastguard Worker        x = x[1]
250*da0073e9SAndroid Build Coastguard Worker        expected = torch.arange(1, 126, dtype=torch.float64).view(5, 5, 5)[1]
251*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.from_numpy(x), expected)
252*da0073e9SAndroid Build Coastguard Worker
253*da0073e9SAndroid Build Coastguard Worker        # check noncontiguous
254*da0073e9SAndroid Build Coastguard Worker        x = np.linspace(1, 25, 25)
255*da0073e9SAndroid Build Coastguard Worker        x.shape = (5, 5)
256*da0073e9SAndroid Build Coastguard Worker        expected = torch.arange(1, 26, dtype=torch.float64).view(5, 5).t()
257*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.from_numpy(x.T), expected)
258*da0073e9SAndroid Build Coastguard Worker
259*da0073e9SAndroid Build Coastguard Worker        # check noncontiguous with holes
260*da0073e9SAndroid Build Coastguard Worker        x = np.linspace(1, 125, 125)
261*da0073e9SAndroid Build Coastguard Worker        x.shape = (5, 5, 5)
262*da0073e9SAndroid Build Coastguard Worker        x = x[:, 1]
263*da0073e9SAndroid Build Coastguard Worker        expected = torch.arange(1, 126, dtype=torch.float64).view(5, 5, 5)[:, 1]
264*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.from_numpy(x), expected)
265*da0073e9SAndroid Build Coastguard Worker
266*da0073e9SAndroid Build Coastguard Worker        # check zero dimensional
267*da0073e9SAndroid Build Coastguard Worker        x = np.zeros((0, 2))
268*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.from_numpy(x).shape, (0, 2))
269*da0073e9SAndroid Build Coastguard Worker        x = np.zeros((2, 0))
270*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.from_numpy(x).shape, (2, 0))
271*da0073e9SAndroid Build Coastguard Worker
272*da0073e9SAndroid Build Coastguard Worker        # check ill-sized strides raise exception
273*da0073e9SAndroid Build Coastguard Worker        x = np.array([3.0, 5.0, 8.0])
274*da0073e9SAndroid Build Coastguard Worker        x.strides = (3,)
275*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(ValueError, lambda: torch.from_numpy(x))
276*da0073e9SAndroid Build Coastguard Worker
277*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("No need to test invalid dtypes that should fail by design.")
278*da0073e9SAndroid Build Coastguard Worker    def test_from_numpy_no_leak_on_invalid_dtype(self):
279*da0073e9SAndroid Build Coastguard Worker        # This used to leak memory as the `from_numpy` call raised an exception and didn't decref the temporary
280*da0073e9SAndroid Build Coastguard Worker        # object. See https://github.com/pytorch/pytorch/issues/121138
281*da0073e9SAndroid Build Coastguard Worker        x = np.array("value".encode("ascii"))
282*da0073e9SAndroid Build Coastguard Worker        for _ in range(1000):
283*da0073e9SAndroid Build Coastguard Worker            try:
284*da0073e9SAndroid Build Coastguard Worker                torch.from_numpy(x)
285*da0073e9SAndroid Build Coastguard Worker            except TypeError:
286*da0073e9SAndroid Build Coastguard Worker                pass
287*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(sys.getrefcount(x) == 2)
288*da0073e9SAndroid Build Coastguard Worker
289*da0073e9SAndroid Build Coastguard Worker    @skipMeta
290*da0073e9SAndroid Build Coastguard Worker    def test_from_list_of_ndarray_warning(self, device):
291*da0073e9SAndroid Build Coastguard Worker        warning_msg = (
292*da0073e9SAndroid Build Coastguard Worker            r"Creating a tensor from a list of numpy.ndarrays is extremely slow"
293*da0073e9SAndroid Build Coastguard Worker        )
294*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(UserWarning, warning_msg):
295*da0073e9SAndroid Build Coastguard Worker            torch.tensor([np.array([0]), np.array([1])], device=device)
296*da0073e9SAndroid Build Coastguard Worker
297*da0073e9SAndroid Build Coastguard Worker    def test_ctor_with_invalid_numpy_array_sequence(self, device):
298*da0073e9SAndroid Build Coastguard Worker        # Invalid list of numpy array
299*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "expected sequence of length"):
300*da0073e9SAndroid Build Coastguard Worker            torch.tensor(
301*da0073e9SAndroid Build Coastguard Worker                [np.random.random(size=(3, 3)), np.random.random(size=(3, 0))],
302*da0073e9SAndroid Build Coastguard Worker                device=device,
303*da0073e9SAndroid Build Coastguard Worker            )
304*da0073e9SAndroid Build Coastguard Worker
305*da0073e9SAndroid Build Coastguard Worker        # Invalid list of list of numpy array
306*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "expected sequence of length"):
307*da0073e9SAndroid Build Coastguard Worker            torch.tensor(
308*da0073e9SAndroid Build Coastguard Worker                [[np.random.random(size=(3, 3)), np.random.random(size=(3, 2))]],
309*da0073e9SAndroid Build Coastguard Worker                device=device,
310*da0073e9SAndroid Build Coastguard Worker            )
311*da0073e9SAndroid Build Coastguard Worker
312*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "expected sequence of length"):
313*da0073e9SAndroid Build Coastguard Worker            torch.tensor(
314*da0073e9SAndroid Build Coastguard Worker                [
315*da0073e9SAndroid Build Coastguard Worker                    [np.random.random(size=(3, 3)), np.random.random(size=(3, 3))],
316*da0073e9SAndroid Build Coastguard Worker                    [np.random.random(size=(3, 3)), np.random.random(size=(3, 2))],
317*da0073e9SAndroid Build Coastguard Worker                ],
318*da0073e9SAndroid Build Coastguard Worker                device=device,
319*da0073e9SAndroid Build Coastguard Worker            )
320*da0073e9SAndroid Build Coastguard Worker
321*da0073e9SAndroid Build Coastguard Worker        # expected shape is `[1, 2, 3]`, hence we try to iterate over 0-D array
322*da0073e9SAndroid Build Coastguard Worker        # leading to type error : not a sequence.
323*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "not a sequence"):
324*da0073e9SAndroid Build Coastguard Worker            torch.tensor(
325*da0073e9SAndroid Build Coastguard Worker                [[np.random.random(size=(3)), np.random.random()]], device=device
326*da0073e9SAndroid Build Coastguard Worker            )
327*da0073e9SAndroid Build Coastguard Worker
328*da0073e9SAndroid Build Coastguard Worker        # list of list or numpy array.
329*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "expected sequence of length"):
330*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[1, 2, 3], np.random.random(size=(2,))], device=device)
331*da0073e9SAndroid Build Coastguard Worker
332*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
333*da0073e9SAndroid Build Coastguard Worker    def test_ctor_with_numpy_scalar_ctor(self, device) -> None:
334*da0073e9SAndroid Build Coastguard Worker        dtypes = [
335*da0073e9SAndroid Build Coastguard Worker            np.double,
336*da0073e9SAndroid Build Coastguard Worker            np.float64,
337*da0073e9SAndroid Build Coastguard Worker            np.float16,
338*da0073e9SAndroid Build Coastguard Worker            np.int64,
339*da0073e9SAndroid Build Coastguard Worker            np.int32,
340*da0073e9SAndroid Build Coastguard Worker            np.int16,
341*da0073e9SAndroid Build Coastguard Worker            np.uint8,
342*da0073e9SAndroid Build Coastguard Worker            np.bool_,
343*da0073e9SAndroid Build Coastguard Worker        ]
344*da0073e9SAndroid Build Coastguard Worker        for dtype in dtypes:
345*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dtype(42), torch.tensor(dtype(42)).item())
346*da0073e9SAndroid Build Coastguard Worker
347*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
348*da0073e9SAndroid Build Coastguard Worker    def test_numpy_index(self, device):
349*da0073e9SAndroid Build Coastguard Worker        i = np.array([0, 1, 2], dtype=np.int32)
350*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5)
351*da0073e9SAndroid Build Coastguard Worker        for idx in i:
352*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(isinstance(idx, int))
353*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x[idx], x[int(idx)])
354*da0073e9SAndroid Build Coastguard Worker
355*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
356*da0073e9SAndroid Build Coastguard Worker    def test_numpy_index_multi(self, device):
357*da0073e9SAndroid Build Coastguard Worker        for dim_sz in [2, 8, 16, 32]:
358*da0073e9SAndroid Build Coastguard Worker            i = np.zeros((dim_sz, dim_sz, dim_sz), dtype=np.int32)
359*da0073e9SAndroid Build Coastguard Worker            i[: dim_sz // 2, :, :] = 1
360*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(dim_sz, dim_sz, dim_sz)
361*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x[i == 1].numel() == np.sum(i))
362*da0073e9SAndroid Build Coastguard Worker
    @onlyCPU
    def test_numpy_array_interface(self, device):
        # Verifies the NumPy array protocol on CPU tensors:
        #   * np.asarray(tensor) (__array__) yields a matching ndarray
        #   * a ufunc applied to a tensor (np.abs) wraps back into a tensor
        #     of the same concrete class (__array_wrap__)
        #   * np.asarray(tensor, dtype=...) honours the requested dtype
        types = [
            torch.DoubleTensor,
            torch.FloatTensor,
            torch.HalfTensor,
            torch.LongTensor,
            torch.IntTensor,
            torch.ShortTensor,
            torch.ByteTensor,
        ]
        dtypes = [
            np.float64,
            np.float32,
            np.float16,
            np.int64,
            np.int32,
            np.int16,
            np.uint8,
        ]
        # `types` and `dtypes` are parallel lists: each pair describes the
        # same element type on the torch and numpy side respectively.
        for tp, dtype in zip(types, dtypes):
            # Only concrete class can be given where "Type[number[_64Bit]]" is expected
            if np.dtype(dtype).kind == "u":  # type: ignore[misc]
                # unsigned dtype: keep the sample values non-negative
                # .type expects a XxxTensor, which have no type hints on
                # purpose, so ignore during mypy type checking
                x = torch.tensor([1, 2, 3, 4]).type(tp)  # type: ignore[call-overload]
                array = np.array([1, 2, 3, 4], dtype=dtype)
            else:
                x = torch.tensor([1, -2, 3, -4]).type(tp)  # type: ignore[call-overload]
                array = np.array([1, -2, 3, -4], dtype=dtype)

            # Test __array__ w/o dtype argument
            asarray = np.asarray(x)
            self.assertIsInstance(asarray, np.ndarray)
            self.assertEqual(asarray.dtype, dtype)
            for i in range(len(x)):
                self.assertEqual(asarray[i], x[i])

            # Test __array_wrap__, same dtype
            abs_x = np.abs(x)
            abs_array = np.abs(array)
            self.assertIsInstance(abs_x, tp)
            for i in range(len(x)):
                self.assertEqual(abs_x[i], abs_array[i])

        # Test __array__ with dtype argument
        for dtype in dtypes:
            x = torch.IntTensor([1, -2, 3, -4])
            asarray = np.asarray(x, dtype=dtype)
            self.assertEqual(asarray.dtype, dtype)
            # Only concrete class can be given where "Type[number[_64Bit]]" is expected
            if np.dtype(dtype).kind == "u":  # type: ignore[misc]
                # For unsigned targets, compare against an ndarray built with
                # the same dtype instead of the signed tensor values.
                # NOTE(review): presumably the negatives wrap modulo 2**8 here;
                # NumPy 2 may reject out-of-range Python ints — confirm.
                wrapped_x = np.array([1, -2, 3, -4], dtype=dtype)
                for i in range(len(x)):
                    self.assertEqual(asarray[i], wrapped_x[i])
            else:
                for i in range(len(x)):
                    self.assertEqual(asarray[i], x[i])
421*da0073e9SAndroid Build Coastguard Worker
422*da0073e9SAndroid Build Coastguard Worker        # Test some math functions with float types
423*da0073e9SAndroid Build Coastguard Worker        float_types = [torch.DoubleTensor, torch.FloatTensor]
424*da0073e9SAndroid Build Coastguard Worker        float_dtypes = [np.float64, np.float32]
425*da0073e9SAndroid Build Coastguard Worker        for tp, dtype in zip(float_types, float_dtypes):
426*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([1, 2, 3, 4]).type(tp)  # type: ignore[call-overload]
427*da0073e9SAndroid Build Coastguard Worker            array = np.array([1, 2, 3, 4], dtype=dtype)
428*da0073e9SAndroid Build Coastguard Worker            for func in ["sin", "sqrt", "ceil"]:
429*da0073e9SAndroid Build Coastguard Worker                ufunc = getattr(np, func)
430*da0073e9SAndroid Build Coastguard Worker                res_x = ufunc(x)
431*da0073e9SAndroid Build Coastguard Worker                res_array = ufunc(array)
432*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(res_x, tp)
433*da0073e9SAndroid Build Coastguard Worker                for i in range(len(x)):
434*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(res_x[i], res_array[i])
435*da0073e9SAndroid Build Coastguard Worker
436*da0073e9SAndroid Build Coastguard Worker        # Test functions with boolean return value
437*da0073e9SAndroid Build Coastguard Worker        for tp, dtype in zip(types, dtypes):
438*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([1, 2, 3, 4]).type(tp)  # type: ignore[call-overload]
439*da0073e9SAndroid Build Coastguard Worker            array = np.array([1, 2, 3, 4], dtype=dtype)
440*da0073e9SAndroid Build Coastguard Worker            geq2_x = np.greater_equal(x, 2)
441*da0073e9SAndroid Build Coastguard Worker            geq2_array = np.greater_equal(array, 2).astype("uint8")
442*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(geq2_x, torch.ByteTensor)
443*da0073e9SAndroid Build Coastguard Worker            for i in range(len(x)):
444*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(geq2_x[i], geq2_array[i])
445*da0073e9SAndroid Build Coastguard Worker
446*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
447*da0073e9SAndroid Build Coastguard Worker    def test_multiplication_numpy_scalar(self, device) -> None:
448*da0073e9SAndroid Build Coastguard Worker        for np_dtype in [
449*da0073e9SAndroid Build Coastguard Worker            np.float32,
450*da0073e9SAndroid Build Coastguard Worker            np.float64,
451*da0073e9SAndroid Build Coastguard Worker            np.int32,
452*da0073e9SAndroid Build Coastguard Worker            np.int64,
453*da0073e9SAndroid Build Coastguard Worker            np.int16,
454*da0073e9SAndroid Build Coastguard Worker            np.uint8,
455*da0073e9SAndroid Build Coastguard Worker        ]:
456*da0073e9SAndroid Build Coastguard Worker            for t_dtype in [torch.float, torch.double]:
457*da0073e9SAndroid Build Coastguard Worker                # mypy raises an error when np.floatXY(2.0) is called
458*da0073e9SAndroid Build Coastguard Worker                # even though this is valid code
459*da0073e9SAndroid Build Coastguard Worker                np_sc = np_dtype(2.0)  # type: ignore[abstract, arg-type]
460*da0073e9SAndroid Build Coastguard Worker                t = torch.ones(2, requires_grad=True, dtype=t_dtype)
461*da0073e9SAndroid Build Coastguard Worker                r1 = t * np_sc
462*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(r1, torch.Tensor)
463*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(r1.dtype == t_dtype)
464*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(r1.requires_grad)
465*da0073e9SAndroid Build Coastguard Worker                r2 = np_sc * t
466*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(r2, torch.Tensor)
467*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(r2.dtype == t_dtype)
468*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(r2.requires_grad)
469*da0073e9SAndroid Build Coastguard Worker
470*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
471*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo()
472*da0073e9SAndroid Build Coastguard Worker    def test_parse_numpy_int_overflow(self, device):
473*da0073e9SAndroid Build Coastguard Worker        # assertRaises uses a try-except which dynamo has issues with
474*da0073e9SAndroid Build Coastguard Worker        # Only concrete class can be given where "Type[number[_64Bit]]" is expected
475*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
476*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
477*da0073e9SAndroid Build Coastguard Worker            "(Overflow|an integer is required)",
478*da0073e9SAndroid Build Coastguard Worker            lambda: torch.mean(torch.randn(1, 1), np.uint64(-1)),
479*da0073e9SAndroid Build Coastguard Worker        )  # type: ignore[call-overload]
480*da0073e9SAndroid Build Coastguard Worker
481*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
482*da0073e9SAndroid Build Coastguard Worker    def test_parse_numpy_int(self, device):
483*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/29252
484*da0073e9SAndroid Build Coastguard Worker        for nptype in [np.int16, np.int8, np.uint8, np.int32, np.int64]:
485*da0073e9SAndroid Build Coastguard Worker            scalar = 3
486*da0073e9SAndroid Build Coastguard Worker            np_arr = np.array([scalar], dtype=nptype)
487*da0073e9SAndroid Build Coastguard Worker            np_val = np_arr[0]
488*da0073e9SAndroid Build Coastguard Worker
489*da0073e9SAndroid Build Coastguard Worker            # np integral type can be treated as a python int in native functions with
490*da0073e9SAndroid Build Coastguard Worker            # int parameters:
491*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.ones(5).diag(scalar), torch.ones(5).diag(np_val))
492*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
493*da0073e9SAndroid Build Coastguard Worker                torch.ones([2, 2, 2, 2]).mean(scalar),
494*da0073e9SAndroid Build Coastguard Worker                torch.ones([2, 2, 2, 2]).mean(np_val),
495*da0073e9SAndroid Build Coastguard Worker            )
496*da0073e9SAndroid Build Coastguard Worker
497*da0073e9SAndroid Build Coastguard Worker            # numpy integral type parses like a python int in custom python bindings:
498*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.Storage(np_val).size(), scalar)  # type: ignore[attr-defined]
499*da0073e9SAndroid Build Coastguard Worker
500*da0073e9SAndroid Build Coastguard Worker            tensor = torch.tensor([2], dtype=torch.int)
501*da0073e9SAndroid Build Coastguard Worker            tensor[0] = np_val
502*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor[0], np_val)
503*da0073e9SAndroid Build Coastguard Worker
504*da0073e9SAndroid Build Coastguard Worker            # Original reported issue, np integral type parses to the correct
505*da0073e9SAndroid Build Coastguard Worker            # PyTorch integral type when passed for a `Scalar` parameter in
506*da0073e9SAndroid Build Coastguard Worker            # arithmetic operations:
507*da0073e9SAndroid Build Coastguard Worker            t = torch.from_numpy(np_arr)
508*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((t + np_val).dtype, t.dtype)
509*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((np_val + t).dtype, t.dtype)
510*da0073e9SAndroid Build Coastguard Worker
511*da0073e9SAndroid Build Coastguard Worker    def test_has_storage_numpy(self, device):
512*da0073e9SAndroid Build Coastguard Worker        for dtype in [np.float32, np.float64, np.int64, np.int32, np.int16, np.uint8]:
513*da0073e9SAndroid Build Coastguard Worker            arr = np.array([1], dtype=dtype)
514*da0073e9SAndroid Build Coastguard Worker            self.assertIsNotNone(
515*da0073e9SAndroid Build Coastguard Worker                torch.tensor(arr, device=device, dtype=torch.float32).storage()
516*da0073e9SAndroid Build Coastguard Worker            )
517*da0073e9SAndroid Build Coastguard Worker            self.assertIsNotNone(
518*da0073e9SAndroid Build Coastguard Worker                torch.tensor(arr, device=device, dtype=torch.double).storage()
519*da0073e9SAndroid Build Coastguard Worker            )
520*da0073e9SAndroid Build Coastguard Worker            self.assertIsNotNone(
521*da0073e9SAndroid Build Coastguard Worker                torch.tensor(arr, device=device, dtype=torch.int).storage()
522*da0073e9SAndroid Build Coastguard Worker            )
523*da0073e9SAndroid Build Coastguard Worker            self.assertIsNotNone(
524*da0073e9SAndroid Build Coastguard Worker                torch.tensor(arr, device=device, dtype=torch.long).storage()
525*da0073e9SAndroid Build Coastguard Worker            )
526*da0073e9SAndroid Build Coastguard Worker            self.assertIsNotNone(
527*da0073e9SAndroid Build Coastguard Worker                torch.tensor(arr, device=device, dtype=torch.uint8).storage()
528*da0073e9SAndroid Build Coastguard Worker            )
529*da0073e9SAndroid Build Coastguard Worker
530*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
531*da0073e9SAndroid Build Coastguard Worker    def test_numpy_scalar_cmp(self, device, dtype):
532*da0073e9SAndroid Build Coastguard Worker        if dtype.is_complex:
533*da0073e9SAndroid Build Coastguard Worker            tensors = (
534*da0073e9SAndroid Build Coastguard Worker                torch.tensor(complex(1, 3), dtype=dtype, device=device),
535*da0073e9SAndroid Build Coastguard Worker                torch.tensor([complex(1, 3), 0, 2j], dtype=dtype, device=device),
536*da0073e9SAndroid Build Coastguard Worker                torch.tensor(
537*da0073e9SAndroid Build Coastguard Worker                    [[complex(3, 1), 0], [-1j, 5]], dtype=dtype, device=device
538*da0073e9SAndroid Build Coastguard Worker                ),
539*da0073e9SAndroid Build Coastguard Worker            )
540*da0073e9SAndroid Build Coastguard Worker        else:
541*da0073e9SAndroid Build Coastguard Worker            tensors = (
542*da0073e9SAndroid Build Coastguard Worker                torch.tensor(3, dtype=dtype, device=device),
543*da0073e9SAndroid Build Coastguard Worker                torch.tensor([1, 0, -3], dtype=dtype, device=device),
544*da0073e9SAndroid Build Coastguard Worker                torch.tensor([[3, 0, -1], [3, 5, 4]], dtype=dtype, device=device),
545*da0073e9SAndroid Build Coastguard Worker            )
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Worker        for tensor in tensors:
548*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.bfloat16:
549*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(TypeError):
550*da0073e9SAndroid Build Coastguard Worker                    np_array = tensor.cpu().numpy()
551*da0073e9SAndroid Build Coastguard Worker                continue
552*da0073e9SAndroid Build Coastguard Worker
553*da0073e9SAndroid Build Coastguard Worker            np_array = tensor.cpu().numpy()
554*da0073e9SAndroid Build Coastguard Worker            for t, a in product(
555*da0073e9SAndroid Build Coastguard Worker                (tensor.flatten()[0], tensor.flatten()[0].item()),
556*da0073e9SAndroid Build Coastguard Worker                (np_array.flatten()[0], np_array.flatten()[0].item()),
557*da0073e9SAndroid Build Coastguard Worker            ):
558*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t, a)
559*da0073e9SAndroid Build Coastguard Worker                if (
560*da0073e9SAndroid Build Coastguard Worker                    dtype == torch.complex64
561*da0073e9SAndroid Build Coastguard Worker                    and torch.is_tensor(t)
562*da0073e9SAndroid Build Coastguard Worker                    and type(a) == np.complex64
563*da0073e9SAndroid Build Coastguard Worker                ):
564*da0073e9SAndroid Build Coastguard Worker                    # TODO: Imaginary part is dropped in this case. Need fix.
565*da0073e9SAndroid Build Coastguard Worker                    # https://github.com/pytorch/pytorch/issues/43579
566*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(t == a)
567*da0073e9SAndroid Build Coastguard Worker                else:
568*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(t == a)
569*da0073e9SAndroid Build Coastguard Worker
570*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
571*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
572*da0073e9SAndroid Build Coastguard Worker    def test___eq__(self, device, dtype):
573*da0073e9SAndroid Build Coastguard Worker        a = make_tensor((5, 7), dtype=dtype, device=device, low=-9, high=9)
574*da0073e9SAndroid Build Coastguard Worker        b = a.clone().detach()
575*da0073e9SAndroid Build Coastguard Worker        b_np = b.numpy()
576*da0073e9SAndroid Build Coastguard Worker
577*da0073e9SAndroid Build Coastguard Worker        # Check all elements equal
578*da0073e9SAndroid Build Coastguard Worker        res_check = torch.ones_like(a, dtype=torch.bool)
579*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a == b_np, res_check)
580*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b_np == a, res_check)
581*da0073e9SAndroid Build Coastguard Worker
582*da0073e9SAndroid Build Coastguard Worker        # Check one element unequal
583*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.bool:
584*da0073e9SAndroid Build Coastguard Worker            b[1][3] = not b[1][3]
585*da0073e9SAndroid Build Coastguard Worker        else:
586*da0073e9SAndroid Build Coastguard Worker            b[1][3] += 1
587*da0073e9SAndroid Build Coastguard Worker        res_check[1][3] = False
588*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a == b_np, res_check)
589*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b_np == a, res_check)
590*da0073e9SAndroid Build Coastguard Worker
591*da0073e9SAndroid Build Coastguard Worker        # Check random elements unequal
592*da0073e9SAndroid Build Coastguard Worker        rand = torch.randint(0, 2, a.shape, dtype=torch.bool)
593*da0073e9SAndroid Build Coastguard Worker        res_check = rand.logical_not()
594*da0073e9SAndroid Build Coastguard Worker        b.copy_(a)
595*da0073e9SAndroid Build Coastguard Worker
596*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.bool:
597*da0073e9SAndroid Build Coastguard Worker            b[rand] = b[rand].logical_not()
598*da0073e9SAndroid Build Coastguard Worker        else:
599*da0073e9SAndroid Build Coastguard Worker            b[rand] += 1
600*da0073e9SAndroid Build Coastguard Worker
601*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a == b_np, res_check)
602*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b_np == a, res_check)
603*da0073e9SAndroid Build Coastguard Worker
604*da0073e9SAndroid Build Coastguard Worker        # Check all elements unequal
605*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.bool:
606*da0073e9SAndroid Build Coastguard Worker            b.copy_(a.logical_not())
607*da0073e9SAndroid Build Coastguard Worker        else:
608*da0073e9SAndroid Build Coastguard Worker            b.copy_(a + 1)
609*da0073e9SAndroid Build Coastguard Worker        res_check.fill_(False)
610*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a == b_np, res_check)
611*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b_np == a, res_check)
612*da0073e9SAndroid Build Coastguard Worker
613*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
614*da0073e9SAndroid Build Coastguard Worker    def test_empty_tensors_interop(self, device):
615*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((), dtype=torch.float16)
616*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(np.random.rand(0), dtype=torch.float16)
617*da0073e9SAndroid Build Coastguard Worker        # Same can be achieved by running
618*da0073e9SAndroid Build Coastguard Worker        # y = torch.empty_strided((0,), (0,), dtype=torch.float16)
619*da0073e9SAndroid Build Coastguard Worker
620*da0073e9SAndroid Build Coastguard Worker        # Regression test for https://github.com/pytorch/pytorch/issues/115068
621*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.true_divide(x, y).shape, y.shape)
622*da0073e9SAndroid Build Coastguard Worker        # Regression test for https://github.com/pytorch/pytorch/issues/115066
623*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mul(x, y).shape, y.shape)
624*da0073e9SAndroid Build Coastguard Worker        # Regression test for https://github.com/pytorch/pytorch/issues/113037
625*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.div(x, y, rounding_mode="floor").shape, y.shape)
626*da0073e9SAndroid Build Coastguard Worker
627*da0073e9SAndroid Build Coastguard Worker
628*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestNumPyInterop, globals())
629*da0073e9SAndroid Build Coastguard Worker
630*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
631*da0073e9SAndroid Build Coastguard Worker    run_tests()
632