xref: /aosp_15_r20/external/pytorch/test/torch_np/numpy_tests/lib/test_shape_base_.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2
3import functools
4import sys
5from unittest import expectedFailure as xfail, skipIf as skipif
6
7from pytest import raises as assert_raises
8
9from torch.testing._internal.common_utils import (
10    instantiate_parametrized_tests,
11    parametrize,
12    run_tests,
13    TEST_WITH_TORCHDYNAMO,
14    TestCase,
15    xfailIfTorchDynamo,
16    xpassIfTorchDynamo,
17)
18
19
20# If we are going to trace through these, we should use NumPy
21# If testing on eager mode, we use torch._numpy
22if TEST_WITH_TORCHDYNAMO:
23    import numpy as np
24    from numpy import (
25        apply_along_axis,
26        array_split,
27        column_stack,
28        dsplit,
29        dstack,
30        expand_dims,
31        hsplit,
32        kron,
33        put_along_axis,
34        split,
35        take_along_axis,
36        tile,
37        vsplit,
38    )
39    from numpy.random import rand, randint
40    from numpy.testing import assert_, assert_array_equal, assert_equal
41
42else:
43    import torch._numpy as np
44    from torch._numpy import (
45        array_split,
46        column_stack,
47        dsplit,
48        dstack,
49        expand_dims,
50        hsplit,
51        kron,
52        put_along_axis,
53        split,
54        take_along_axis,
55        tile,
56        vsplit,
57    )
58    from torch._numpy.random import rand, randint
59    from torch._numpy.testing import assert_, assert_array_equal, assert_equal
60
61
# Unconditional skip decorator built on ``skipIf``; used both bare (``@skip``)
# and with a message (``@skip(reason="...")``) in this file.
skip = functools.partial(skipif, True)

# sys.maxsize is 2**63 - 1 on 64-bit builds and 2**31 - 1 on 32-bit builds.
IS_64BIT = sys.maxsize > (1 << 32)
66
67
68def _add_keepdims(func):
69    """hack in keepdims behavior into a function taking an axis"""
70
71    @functools.wraps(func)
72    def wrapped(a, axis, **kwargs):
73        res = func(a, axis=axis, **kwargs)
74        if axis is None:
75            axis = 0  # res is now a scalar, so we can insert this anywhere
76        return np.expand_dims(res, axis=axis)
77
78    return wrapped
79
80
class TestTakeAlongAxis(TestCase):
    def test_argequivalent(self):
        """take_along_axis applied to arg<func> indices reproduces <func>."""
        arr = rand(3, 4, 5)
        axes = [*range(arr.ndim), None]

        # (value-producing function, matching index-producing function, kwargs)
        pairs = (
            (np.sort, np.argsort, {}),
            (_add_keepdims(np.min), _add_keepdims(np.argmin), {}),
            (_add_keepdims(np.max), _add_keepdims(np.argmax), {}),
            #  FIXME           (np.partition, np.argpartition, dict(kth=2)),
        )

        for value_fn, index_fn, extra in pairs:
            for ax in axes:
                expected = value_fn(arr, axis=ax, **extra)
                indices = index_fn(arr, axis=ax, **extra)
                assert_equal(expected, take_along_axis(arr, indices, axis=ax))

    def test_invalid(self):
        """Malformed index arrays and out-of-range axes are rejected."""
        arr = np.ones((10, 10))
        idx = np.ones((10, 2), dtype=np.intp)

        # sanity check: well-formed indices are accepted
        take_along_axis(arr, idx, axis=1)

        # indices with fewer dimensions than the array
        assert_raises(
            (ValueError, RuntimeError), take_along_axis, arr, np.array(1), axis=1
        )
        # boolean, then floating-point, index arrays are both refused
        for bad_dtype in (bool, float):
            assert_raises(
                (IndexError, RuntimeError),
                take_along_axis,
                arr,
                idx.astype(bad_dtype),
                axis=1,
            )
        # axis beyond the array's rank
        assert_raises(np.AxisError, take_along_axis, arr, idx, axis=10)

    def test_empty(self):
        """An empty index array yields an empty result shaped like the indices."""
        arr = np.ones((3, 4, 5))
        idx = np.ones((3, 0, 5), dtype=np.intp)
        assert_equal(take_along_axis(arr, idx, axis=1).shape, idx.shape)

    def test_broadcast(self):
        """Non-indexing dimensions broadcast in both directions."""
        out = take_along_axis(
            np.ones((3, 4, 1)), np.ones((1, 2, 5), dtype=np.intp), axis=1
        )
        assert_equal(out.shape, (3, 2, 5))
136
137
class TestPutAlongAxis(TestCase):
    def test_replace_max(self):
        base = np.array([[10, 30, 20], [60, 40, 50]])
        argmax_keep = _add_keepdims(np.argmax)
        argmin_keep = _add_keepdims(np.argmin)

        for ax in [*range(base.ndim), None]:
            # put_along_axis writes into its argument, so start from a fresh copy
            arr = base.copy()

            # overwrite each maximum with a value below every other entry
            where_max = argmax_keep(arr, axis=ax)
            put_along_axis(arr, where_max, -99, axis=ax)

            # the new minimum must sit exactly where the old maximum was
            assert_equal(argmin_keep(arr, axis=ax), where_max)

    # Recorded reason for the expected failure outside TorchDynamo:
    # "RuntimeError: Expected index [1, 2, 5] to be smaller than self [3, 4, 1]
    # apart from dimension 1"
    @xpassIfTorchDynamo
    def test_broadcast(self):
        """Non-indexing dimensions broadcast in both directions for put_along_axis."""
        arr = np.ones((3, 4, 1))
        idx = np.arange(10, dtype=np.intp).reshape((1, 2, 5)) % 4
        put_along_axis(arr, idx, 20, axis=1)
        assert_equal(take_along_axis(arr, idx, axis=1), 20)
163
164
@xpassIfTorchDynamo  # apply_along_axis is not implemented in torch._numpy
class TestApplyAlongAxis(TestCase):
    def test_simple(self):
        arr = np.ones((20, 10), "d")
        out = apply_along_axis(len, 0, arr)
        assert_array_equal(out, len(arr) * np.ones(arr.shape[1]))

    def test_simple101(self):
        arr = np.ones((10, 101), "d")
        out = apply_along_axis(len, 0, arr)
        assert_array_equal(out, len(arr) * np.ones(arr.shape[1]))

    def test_3d(self):
        cube = np.arange(27).reshape((3, 3, 3))
        out = apply_along_axis(np.sum, 0, cube)
        assert_array_equal(out, [[27, 30, 33], [36, 39, 42], [45, 48, 51]])

    def test_scalar_array(self, cls=np.ndarray):
        arr = np.ones((6, 3)).view(cls)
        out = apply_along_axis(np.sum, 0, arr)
        assert_(isinstance(out, cls))
        assert_array_equal(out, np.array([6, 6, 6]).view(cls))

    def test_0d_array(self, cls=np.ndarray):
        def sum_to_0d(x):
            """Sum a 1-d ``x`` and hand the total back as a 0-d array."""
            assert_equal(x.ndim, 1)
            return np.squeeze(np.sum(x, keepdims=True))

        arr = np.ones((6, 3)).view(cls)
        for axis, expected in ((0, [6, 6, 6]), (1, [3, 3, 3, 3, 3, 3])):
            out = apply_along_axis(sum_to_0d, axis, arr)
            assert_(isinstance(out, cls))
            assert_array_equal(out, np.array(expected).view(cls))

    def test_axis_insertion(self, cls=np.ndarray):
        def f1to2(x):
            """Map a 1-d ``x`` to an asymmetric, non-square 2-d result."""
            assert_equal(x.ndim, 1)
            return (x[::-1] * x[1:, None]).view(cls)

        def check(actual, expected):
            assert_equal(type(actual), type(expected))
            assert_equal(actual, expected)

        a2d = np.arange(6 * 3).reshape((6, 3))

        # 2d insertion along first axis: per-column results stacked last
        actual = apply_along_axis(f1to2, 0, a2d)
        per_column = [f1to2(a2d[:, j]) for j in range(a2d.shape[1])]
        check(actual, np.stack(per_column, axis=-1).view(cls))

        # 2d insertion along last axis: per-row results stacked first
        actual = apply_along_axis(f1to2, 1, a2d)
        per_row = [f1to2(a2d[i, :]) for i in range(a2d.shape[0])]
        check(actual, np.stack(per_row, axis=0).view(cls))

        # 3d insertion along middle axis
        a3d = np.arange(6 * 5 * 3).reshape((6, 5, 3))
        actual = apply_along_axis(f1to2, 1, a3d)
        slabs = []
        for j in range(a3d.shape[2]):
            rows = [f1to2(a3d[i, :, j]) for i in range(a3d.shape[0])]
            slabs.append(np.stack(rows, axis=0))
        check(actual, np.stack(slabs, axis=-1).view(cls))

    def test_axis_insertion_ma(self):
        def f1to2(x):
            """Asymmetric 2-d result from a 1-d ``x``, masking multiples of 5."""
            assert_equal(x.ndim, 1)
            res = x[::-1] * x[1:, None]
            return np.ma.masked_where(res % 5 == 0, res)

        arr = np.arange(6 * 3).reshape((6, 3))
        out = apply_along_axis(f1to2, 0, arr)
        assert_(isinstance(out, np.ma.masked_array))
        assert_equal(out.ndim, 3)
        for k in range(3):
            assert_array_equal(out[:, :, k].mask, f1to2(arr[:, k]).mask)

    def test_tuple_func1d(self):
        def sample_1d(x):
            return x[1], x[0]

        out = np.apply_along_axis(sample_1d, 1, np.array([[1, 2], [3, 4]]))
        assert_array_equal(out, np.array([[2, 1], [4, 3]]))

    def test_empty(self):
        # With no 1-d slices at all there is nothing to call the function on,
        # so apply_along_axis must refuse.
        def never_call(x):
            assert_(False)  # should never be reached

        empty2d = np.empty((0, 0))
        for axis in (0, 1):
            assert_raises(ValueError, np.apply_along_axis, never_call, axis, empty2d)

        # Along axis 1 of a (10, 0) array the function sees 10 empty rows;
        # along axis 0 there are no slices, which is again an error.
        def empty_to_1(x):
            assert_(len(x) == 0)
            return 1

        tall = np.empty((10, 0))
        assert_equal(np.apply_along_axis(empty_to_1, 1, tall), np.ones(10))
        assert_raises(ValueError, np.apply_along_axis, empty_to_1, 0, tall)

    @skip  # TypeError: descriptor 'union' for 'set' objects doesn't apply to a 'numpy.int64' object
    def test_with_iterable_object(self):
        # from issue 5248
        sets = np.array([[{1, 11}, {2, 22}, {3, 33}], [{4, 44}, {5, 55}, {6, 66}]])
        actual = np.apply_along_axis(lambda column: set.union(*column), 0, sets)
        expected = np.array([{1, 11, 4, 44}, {2, 22, 5, 55}, {3, 33, 6, 66}])

        assert_equal(actual, expected)

        # issue 8642 - assert_equal doesn't detect this, so compare element types
        for pos in np.ndindex(actual.shape):
            assert_equal(type(actual[pos]), type(expected[pos]))
293
294
# apply_over_axes is not implemented in torch._numpy, so -- like
# TestApplyAlongAxis -- this is expected to fail unless running under
# TorchDynamo, where ``np`` is NumPy itself.
@xpassIfTorchDynamo  # (reason="apply_over_axes not implemented")
class TestApplyOverAxes(TestCase):
    def test_simple(self):
        a = np.arange(24).reshape(2, 3, 4)
        # ``apply_over_axes`` is not imported at module level; going through
        # ``np`` makes the test exercise the function instead of raising
        # NameError unconditionally.
        aoa_a = np.apply_over_axes(np.sum, a, [0, 2])
        # reducing axes 0 and 2 while keeping them: row j sums to 60 + 32 * j
        assert_array_equal(aoa_a, np.array([[[60], [92], [124]]]))
301
302
class TestExpandDims(TestCase):
    def test_functionality(self):
        shape = (2, 3, 4, 5)
        arr = np.empty(shape)
        # insertion positions -5 through 3 on a 4-d input
        for ax in range(-5, 4):
            expanded = expand_dims(arr, ax)
            assert_(expanded.shape[ax] == 1)
            assert_(np.squeeze(expanded).shape == shape)

    def test_axis_tuple(self):
        arr = np.empty((3, 3, 3))
        cases = [
            ((0, 1, 2), (1, 1, 1, 3, 3, 3)),
            ((0, -1, -2), (1, 3, 3, 3, 1, 1)),
            ((0, 3, 5), (1, 3, 3, 1, 3, 1)),
            ((0, -3, -5), (1, 1, 3, 1, 3, 3)),
        ]
        for axes, expected_shape in cases:
            assert np.expand_dims(arr, axis=axes).shape == expected_shape

    def test_axis_out_of_range(self):
        four_d = np.empty((2, 3, 4, 5))
        for bad in (-6, 5):
            assert_raises(np.AxisError, expand_dims, four_d, bad)

        three_d = np.empty((3, 3, 3))
        for bad in ((0, -6), (0, 5)):
            assert_raises(np.AxisError, expand_dims, three_d, bad)

    def test_repeated_axis(self):
        # the same position may not be requested twice
        arr = np.empty((3, 3, 3))
        assert_raises(ValueError, expand_dims, arr, axis=(1, 1))
332
333
class TestArraySplit(TestCase):
    def test_integer_0_split(self):
        a = np.arange(10)
        # zero sections cannot be produced
        assert_raises(ValueError, array_split, a, 0)

    def test_integer_split(self):
        a = np.arange(10)
        # Expected chunk boundaries for each section count n: chunk k of
        # ``array_split(a, n)`` must equal ``a[edges[k]:edges[k + 1]]``.
        # The first ``10 % n`` chunks are one element longer than the rest.
        expected_edges = {
            1: [0, 10],
            2: [0, 5, 10],
            3: [0, 4, 7, 10],
            4: [0, 3, 6, 8, 10],
            5: [0, 2, 4, 6, 8, 10],
            6: [0, 2, 4, 6, 8, 9, 10],
            7: [0, 2, 4, 6, 7, 8, 9, 10],
            8: [0, 2, 4, 5, 6, 7, 8, 9, 10],
            9: [0, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            10: list(range(11)),
            # more sections than elements: the surplus trailing chunk is empty
            11: list(range(11)) + [10],
        }
        for sections, edges in expected_edges.items():
            res = array_split(a, sections)
            desired = [np.arange(lo, hi) for lo, hi in zip(edges, edges[1:])]
            compare_results(res, desired)

    def test_integer_split_2D_rows(self):
        a = np.array([np.arange(10), np.arange(10)])
        res = array_split(a, 3, axis=0)
        tgt = [np.array([np.arange(10)]), np.array([np.arange(10)]), np.zeros((0, 10))]
        compare_results(res, tgt)
        # the empty trailing chunk must keep the input's dtype
        assert_(a.dtype.type is res[-1].dtype.type)

        # Same thing for manual splits:
        res = array_split(a, [0, 1], axis=0)
        tgt = [np.zeros((0, 10)), np.array([np.arange(10)]), np.array([np.arange(10)])]
        compare_results(res, tgt)
        assert_(a.dtype.type is res[-1].dtype.type)

    def test_integer_split_2D_cols(self):
        a = np.array([np.arange(10), np.arange(10)])
        res = array_split(a, 3, axis=-1)
        desired = [
            np.array([np.arange(4), np.arange(4)]),
            np.array([np.arange(4, 7), np.arange(4, 7)]),
            np.array([np.arange(7, 10), np.arange(7, 10)]),
        ]
        compare_results(res, desired)

    def test_integer_split_2D_default(self):
        """This will fail if we change default axis"""
        a = np.array([np.arange(10), np.arange(10)])
        res = array_split(a, 3)
        tgt = [np.array([np.arange(10)]), np.array([np.arange(10)]), np.zeros((0, 10))]
        compare_results(res, tgt)
        assert_(a.dtype.type is res[-1].dtype.type)
        # perhaps should check higher dimensions

    @skipif(not IS_64BIT, reason="Needs 64bit platform")
    def test_integer_split_2D_rows_greater_max_int32(self):
        # broadcast_to returns a view, so the 2**32-row input is not
        # materialized; only the chunk count and shapes are checked.
        a = np.broadcast_to([0], (1 << 32, 2))
        res = array_split(a, 4)
        assert_equal(len(res), 4)
        for chunk in res:
            assert_equal(chunk.shape, (1 << 30, 2))

    def test_index_split_simple(self):
        a = np.arange(10)
        indices = [1, 5, 7]
        res = array_split(a, indices, axis=-1)
        desired = [np.arange(0, 1), np.arange(1, 5), np.arange(5, 7), np.arange(7, 10)]
        compare_results(res, desired)

    def test_index_split_low_bound(self):
        a = np.arange(10)
        indices = [0, 5, 7]
        res = array_split(a, indices, axis=-1)
        desired = [np.array([]), np.arange(0, 5), np.arange(5, 7), np.arange(7, 10)]
        compare_results(res, desired)

    def test_index_split_high_bound(self):
        a = np.arange(10)
        # split points at or past the end produce empty trailing chunks
        indices = [0, 5, 7, 10, 12]
        res = array_split(a, indices, axis=-1)
        desired = [
            np.array([]),
            np.arange(0, 5),
            np.arange(5, 7),
            np.arange(7, 10),
            np.array([]),
            np.array([]),
        ]
        compare_results(res, desired)
516
517
class TestSplit(TestCase):
    # ``split`` behaves like ``array_split`` but insists that the division be
    # equal, so only that difference is exercised here.

    def test_equal_split(self):
        halves = split(np.arange(10), 2)
        compare_results(halves, [np.arange(5), np.arange(5, 10)])

    def test_unequal_split(self):
        # 10 elements cannot be divided into 3 equal sections
        assert_raises(ValueError, split, np.arange(10), 3)
532
533
class TestColumnStack(TestCase):
    def test_non_iterable(self):
        # a bare scalar is not a sequence of arrays
        assert_raises(TypeError, column_stack, 1)

    def test_1D_arrays(self):
        # example from docstring: 1-d inputs become the columns
        first, second = np.array((1, 2, 3)), np.array((2, 3, 4))
        expected = np.array([[1, 2], [2, 3], [3, 4]])
        assert_equal(np.column_stack((first, second)), expected)

    def test_2D_arrays(self):
        # same as hstack 2D docstring example
        first, second = np.array([[1], [2], [3]]), np.array([[2], [3], [4]])
        expected = np.array([[1, 2], [2, 3], [3, 4]])
        assert_equal(np.column_stack((first, second)), expected)

    def test_generator(self):
        # numpy 1.24 emits a FutureWarning but we don't, so the check
        # ``with assert_warns(FutureWarning):`` stays disabled.
        # NOTE(review): despite the test name, a list (not a generator) is passed.
        column_stack([np.arange(3) for _ in range(2)])
558
559
class TestDstack(TestCase):
    def test_non_iterable(self):
        assert_raises(TypeError, dstack, 1)

    def test_0D_array(self):
        stacked = dstack([np.array(1), np.array(2)])
        assert_array_equal(stacked, np.array([[[1, 2]]]))

    def test_1D_array(self):
        stacked = dstack([np.array([1]), np.array([2])])
        assert_array_equal(stacked, np.array([[[1, 2]]]))

    def test_2D_array(self):
        first = np.array([[1], [2]])
        second = np.array([[1], [2]])
        stacked = dstack([first, second])
        assert_array_equal(stacked, np.array([[[1, 1]], [[2, 2]]]))

    def test_2D_array2(self):
        first = np.array([1, 2])
        second = np.array([1, 2])
        stacked = dstack([first, second])
        assert_array_equal(stacked, np.array([[[1, 1], [2, 2]]]))

    def test_generator(self):
        # numpy 1.24 emits a FutureWarning but we don't, so the check
        # ``with assert_warns(FutureWarning):`` stays disabled.
        # NOTE(review): despite the test name, a list (not a generator) is passed.
        dstack([np.arange(3) for _ in range(2)])
606
607
# array_split has the comprehensive splitting tests;
# hsplit, vsplit and dsplit only get simple tests here.
class TestHsplit(TestCase):
    """Only testing for integer splits."""

    def test_non_iterable(self):
        assert_raises(ValueError, hsplit, 1, 1)

    def test_0D_array(self):
        # a 0-d array has no axis to split along
        a = np.array(1)
        assert_raises(ValueError, hsplit, a, 2)

    def test_1D_array(self):
        a = np.array([1, 2, 3, 4])
        res = hsplit(a, 2)
        desired = [np.array([1, 2]), np.array([3, 4])]
        compare_results(res, desired)

    def test_2D_array(self):
        # for 2-d input the columns are split
        a = np.array([[1, 2, 3, 4], [1, 2, 3, 4]])
        res = hsplit(a, 2)
        desired = [np.array([[1, 2], [1, 2]]), np.array([[3, 4], [3, 4]])]
        compare_results(res, desired)
635
636
class TestVsplit(TestCase):
    """Only testing for integer splits."""

    def test_non_iterable(self):
        assert_raises(ValueError, vsplit, 1, 1)

    def test_0D_array(self):
        a = np.array(1)
        assert_raises(ValueError, vsplit, a, 2)

    def test_1D_array(self):
        # vsplit splits rows, which a 1-d input does not have
        a = np.array([1, 2, 3, 4])
        assert_raises(ValueError, vsplit, a, 2)

    def test_2D_array(self):
        a = np.array([[1, 2, 3, 4], [1, 2, 3, 4]])
        res = vsplit(a, 2)
        desired = [np.array([[1, 2, 3, 4]]), np.array([[1, 2, 3, 4]])]
        compare_results(res, desired)
660
661
class TestDsplit(TestCase):
    """Only testing for integer splits."""

    def test_non_iterable(self):
        assert_raises(ValueError, dsplit, 1, 1)

    def test_0D_array(self):
        a = np.array(1)
        assert_raises(ValueError, dsplit, a, 2)

    def test_1D_array(self):
        a = np.array([1, 2, 3, 4])
        assert_raises(ValueError, dsplit, a, 2)

    def test_2D_array(self):
        # dsplit splits along the third axis, which a 2-d input does not have
        a = np.array([[1, 2, 3, 4], [1, 2, 3, 4]])
        assert_raises(ValueError, dsplit, a, 2)

    def test_3D_array(self):
        a = np.array([[[1, 2, 3, 4], [1, 2, 3, 4]], [[1, 2, 3, 4], [1, 2, 3, 4]]])
        res = dsplit(a, 2)
        desired = [
            np.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]]),
            np.array([[[3, 4], [3, 4]], [[3, 4], [3, 4]]]),
        ]
        compare_results(res, desired)
691
692
class TestSqueeze(TestCase):
    def test_basic(self):
        a = rand(20, 10, 10, 1, 1)
        b = rand(20, 1, 10, 1, 20)
        c = rand(1, 1, 20, 10)
        cases = ((a, (20, 10, 10)), (b, (20, 10, 20)), (c, (20, 10)))
        for arr, squeezed_shape in cases:
            assert_array_equal(np.squeeze(arr), np.reshape(arr, squeezed_shape))

        # Squeezing to 0-dim should still give an ndarray
        scalar_like = np.squeeze([[[1.5]]])
        assert_equal(scalar_like, 1.5)
        assert_equal(scalar_like.ndim, 0)
        assert type(scalar_like) is np.ndarray

    @xfailIfTorchDynamo
    def test_basic_2(self):
        # the squeezed result's underlying tensor must be based on the input's
        ones = np.ones((3, 1, 4, 1, 1))
        assert ones.squeeze().tensor._base is ones.tensor

    def test_squeeze_axis(self):
        A = [[[1, 1, 1], [2, 2, 2], [3, 3, 3]]]
        assert_equal(np.squeeze(A).shape, (3, 3))
        assert_equal(np.squeeze(A, axis=()), A)

        box = np.zeros((1, 3, 1))
        assert_equal(np.squeeze(box).shape, (3,))
        for ax, shape in ((0, (3, 1)), (-1, (1, 3)), (2, (1, 3))):
            assert_equal(np.squeeze(box, axis=ax).shape, shape)

        nested = [np.zeros((3, 1))]
        assert_equal(np.squeeze(nested).shape, (3,))
        for ax, shape in ((0, (3, 1)), (2, (1, 3)), (-1, (1, 3))):
            assert_equal(np.squeeze(nested, axis=ax).shape, shape)

    def test_squeeze_type(self):
        # Ticket #133: squeezing 1-d and 0-d inputs still yields an ndarray
        for arr in (np.array([3]), np.array(3)):
            assert type(arr.squeeze()) is np.ndarray

    @skip(reason="XXX: order='F' not implemented")
    def test_squeeze_contiguous(self):
        # Similar to GitHub issue #387
        squeezed_row = np.zeros((1, 2)).squeeze()
        squeezed_f = np.zeros((2, 2, 2), order="F")[:, :, ::2].squeeze()
        assert_(squeezed_row.flags.c_contiguous)
        assert_(squeezed_row.flags.f_contiguous)
        assert_(squeezed_f.flags.f_contiguous)

    @xpassIfTorchDynamo  # (reason="XXX: noop in torch, while numpy raises")
    def test_squeeze_axis_handling(self):
        # axis 0 has length 3, so asking to squeeze it is an error
        with assert_raises(ValueError):
            np.squeeze(np.array([[1], [2], [3]]), axis=0)
748
749
@instantiate_parametrized_tests
class TestKron(TestCase):
    def test_basic(self):
        square = np.array([[1, 2], [3, 4]])

        # A 0-d operand equal to 1 leaves the other operand unchanged,
        # on either side.
        unit = np.array(1)
        assert_array_equal(np.kron(unit, square), np.array([[1, 2], [3, 4]]))
        assert_array_equal(np.kron(square, unit), np.array([[1, 2], [3, 4]]))

        # A single-element 1-d operand scales the other operand.
        three = np.array([3])
        scaled = np.array([[3, 6], [9, 12]])
        assert_array_equal(np.kron(three, square), scaled)
        assert_array_equal(np.kron(square, three), scaled)

        # A 3-d operand of shape (2, 1, 1) gives a (2, 2, 2) result.
        column = np.array([[[1]], [[2]]])
        blocks = np.array([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
        assert_array_equal(np.kron(column, square), blocks)
        assert_array_equal(np.kron(square, column), blocks)

    @skip(reason="NP_VER: fails on CI")
    @parametrize(
        "shape_a,shape_b",
        [
            ((1, 1), (1, 1)),
            ((1, 2, 3), (4, 5, 6)),
            ((2, 2), (2, 2, 2)),
            ((1, 0), (1, 1)),
            ((2, 0, 2), (2, 2)),
            ((2, 0, 0, 2), (2, 0, 2)),
        ],
    )
    def test_kron_shape(self, shape_a, shape_b):
        a = np.ones(shape_a)
        b = np.ones(shape_b)
        # left-pad the shorter shape with ones so both have the same rank
        rank = max(len(shape_a), len(shape_b))
        padded_a = (1,) * (rank - len(shape_a)) + shape_a
        padded_b = (1,) * (rank - len(shape_b)) + shape_b
        expected_shape = np.multiply(padded_a, padded_b)

        k = np.kron(a, b)
        assert np.array_equal(k.shape, expected_shape), "Unexpected shape from kron"
802
803
class TestTile(TestCase):
    def test_basic(self):
        row = np.array([0, 1, 2])
        grid = [[1, 2], [3, 4]]
        cases = [
            (row, 2, [0, 1, 2, 0, 1, 2]),
            (row, (2, 2), [[0, 1, 2, 0, 1, 2], [0, 1, 2, 0, 1, 2]]),
            (row, (1, 2), [[0, 1, 2, 0, 1, 2]]),
            (grid, 2, [[1, 2, 1, 2], [3, 4, 3, 4]]),
            (grid, (2, 1), [[1, 2], [3, 4], [1, 2], [3, 4]]),
            (grid, (2, 2), [[1, 2, 1, 2], [3, 4, 3, 4], [1, 2, 1, 2], [3, 4, 3, 4]]),
        ]
        for src, reps, expected in cases:
            assert_equal(tile(src, reps), expected)

    def test_tile_one_repetition_on_array_gh4679(self):
        # modifying the result of tile(x, 1) must leave x untouched
        source = np.arange(5)
        tiled = tile(source, 1)
        tiled += 2
        assert_equal(source, np.arange(5))

    def test_empty(self):
        cube = np.array([[[]]])
        flat = np.array([[], []])
        shape_2d = tile(flat, 2).shape
        shape_3d = tile(cube, (3, 2, 5)).shape
        assert_equal(shape_2d, (2, 0))
        assert_equal(shape_3d, (3, 2, 0))

    def test_kroncompare(self):
        # tiling must agree with a Kronecker product against an all-ones array
        reps = [(2,), (1, 2), (2, 1), (2, 2), (2, 3, 2), (3, 2)]
        shapes = [(3,), (2, 3), (3, 4, 3), (3, 2, 3), (4, 3, 2, 4), (2, 2)]
        for s in shapes:
            b = randint(0, 10, size=s)
            for r in reps:
                ones = np.ones(r, b.dtype)
                assert_equal(tile(b, r), kron(ones, b))
841
842
@xfail  # Maybe implement one day
class TestMayShareMemory(TestCase):
    def test_basic(self):
        d = np.ones((50, 60))
        d2 = np.ones((30, 60, 6))

        # d and its views must be reported as possibly sharing memory with d
        for view in (d, d[::-1], d[::2], d[1:, ::-1]):
            assert_(np.may_share_memory(d, view))

        # views of d must not be reported as sharing memory with the unrelated d2
        for view in (d[::-1], d[::2], d[1:, ::-1]):
            assert_(not np.may_share_memory(view, d2))

        assert_(np.may_share_memory(d2[1:, ::-1], d2))
857
858
859# Utility
def compare_results(res, desired):
    """Assert that two sequences of arrays match pairwise.

    Raises ``ValueError`` when the sequences differ in length; otherwise each
    pair is checked with ``assert_array_equal``, which raises on a mismatch.
    """
    n_res, n_desired = len(res), len(desired)
    if n_res != n_desired:
        raise ValueError("Iterables have different lengths")
    # The lengths already agree, so zip cannot silently truncate here
    # (``zip(..., strict=True)`` from PEP 618 needs Python 3.10).
    for actual, expected in zip(res, desired):
        assert_array_equal(actual, expected)
867
868
if __name__ == "__main__":
    # Executed directly: hand off to torch's common test runner.
    run_tests()
871