# Owner(s): ["module: dynamo"]

"""Tests for shape-manipulation routines.

Covers take_along_axis, put_along_axis, apply_along_axis, apply_over_axes,
expand_dims, array_split/split, column_stack, dstack, hsplit/vsplit/dsplit,
squeeze, kron, tile and may_share_memory.

When running under TorchDynamo (``TEST_WITH_TORCHDYNAMO``) the tests are traced
against real NumPy; in eager mode they run against ``torch._numpy``.
"""

import functools
import sys
from unittest import expectedFailure as xfail, skipIf as skipif

from pytest import raises as assert_raises

from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
    xfailIfTorchDynamo,
    xpassIfTorchDynamo,
)


# If we are going to trace through these, we should use NumPy
# If testing on eager mode, we use torch._numpy
if TEST_WITH_TORCHDYNAMO:
    import numpy as np
    from numpy import (
        apply_along_axis,
        array_split,
        column_stack,
        dsplit,
        dstack,
        expand_dims,
        hsplit,
        kron,
        put_along_axis,
        split,
        take_along_axis,
        tile,
        vsplit,
    )
    from numpy.random import rand, randint
    from numpy.testing import assert_, assert_array_equal, assert_equal

else:
    import torch._numpy as np

    # NOTE(review): apply_along_axis is not imported in this branch. Its only
    # user, TestApplyAlongAxis, is decorated @xpassIfTorchDynamo with the note
    # "apply_along_axis not implemented". The eager-mode NameError is
    # presumably the expected failure; TODO confirm before adding it.
    from torch._numpy import (
        array_split,
        column_stack,
        dsplit,
        dstack,
        expand_dims,
        hsplit,
        kron,
        put_along_axis,
        split,
        take_along_axis,
        tile,
        vsplit,
    )
    from torch._numpy.random import rand, randint
    from torch._numpy.testing import assert_, assert_array_equal, assert_equal


# Unconditional skip decorator; usable bare (``@skip``) or as ``@skip(reason=...)``.
skip = functools.partial(skipif, True)


# True when the interpreter's native word size exceeds 32 bits.
IS_64BIT = sys.maxsize > 2**32


def _add_keepdims(func):
    """Wrap ``func(a, axis=...)`` so the reduced axis is kept as a length-1 dim.

    This hacks keepdims behavior into functions such as ``np.argmax`` that take
    an ``axis`` argument. The result is passed through ``np.expand_dims`` at
    ``axis``. When ``axis`` is None the result is a scalar, so the new axis is
    inserted at position 0.
    """

    @functools.wraps(func)
    def wrapped(a, axis, **kwargs):
        res = func(a, axis=axis, **kwargs)
        if axis is None:
            axis = 0  # res is now a scalar, so we can insert this anywhere
        return np.expand_dims(res, axis=axis)

    return wrapped


class TestTakeAlongAxis(TestCase):
    def test_argequivalent(self):
        """take_along_axis(a, arg<func>(a)) must reproduce <func>(a) on every axis."""
        a = rand(3, 4, 5)

        funcs = [
            (np.sort, np.argsort, {}),
            (_add_keepdims(np.min), _add_keepdims(np.argmin), {}),
            (_add_keepdims(np.max), _add_keepdims(np.argmax), {}),
            # FIXME: disabled entry (reason not recorded here):
            #   (np.partition, np.argpartition, dict(kth=2)),
        ]

        for func, argfunc, kwargs in funcs:
            for axis in list(range(a.ndim)) + [None]:
                a_func = func(a, axis=axis, **kwargs)
                ai_func = argfunc(a, axis=axis, **kwargs)
                assert_equal(a_func, take_along_axis(a, ai_func, axis=axis))

    def test_invalid(self):
        """Test it errors when indices has too few dimensions"""
        a = np.ones((10, 10))
        ai = np.ones((10, 2), dtype=np.intp)

        # sanity check
        take_along_axis(a, ai, axis=1)

        # not enough indices
        assert_raises(
            (ValueError, RuntimeError), take_along_axis, a, np.array(1), axis=1
        )
        # bool arrays not allowed
        assert_raises(
            (IndexError, RuntimeError), take_along_axis, a, ai.astype(bool), axis=1
        )
        # float arrays not allowed
        assert_raises(
            (IndexError, RuntimeError), take_along_axis, a, ai.astype(float), axis=1
        )
        # invalid axis
        assert_raises(np.AxisError, take_along_axis, a, ai, axis=10)

    def test_empty(self):
        """Test everything is ok with empty results, even with inserted dims"""
        a = np.ones((3, 4, 5))
        ai = np.ones((3, 0, 5), dtype=np.intp)

        actual = take_along_axis(a, ai, axis=1)
        assert_equal(actual.shape, ai.shape)

    def test_broadcast(self):
        """Test that non-indexing dimensions are broadcast in both directions"""
        a = np.ones((3, 4, 1))
        ai = np.ones((1, 2, 5), dtype=np.intp)
        actual = take_along_axis(a, ai, axis=1)
        assert_equal(actual.shape, (3, 2, 5))


class TestPutAlongAxis(TestCase):
    def test_replace_max(self):
        a_base = np.array([[10, 30, 20], [60, 40, 50]])

        for axis in list(range(a_base.ndim)) + [None]:
            # we mutate this in the loop
            a = a_base.copy()

            # replace the max with a small value
            i_max = _add_keepdims(np.argmax)(a, axis=axis)
            put_along_axis(a, i_max, -99, axis=axis)

            # find the new minimum, which should be where the max used to be
            i_min = _add_keepdims(np.argmin)(a, axis=axis)

            assert_equal(i_min, i_max)

    # reason="RuntimeError: Expected index [1, 2, 5] to be smaller than self [3, 4, 1] apart from dimension 1"
    @xpassIfTorchDynamo
    def test_broadcast(self):
        """Test that non-indexing dimensions are broadcast in both directions"""
        a = np.ones((3, 4, 1))
        ai = np.arange(10, dtype=np.intp).reshape((1, 2, 5)) % 4
        put_along_axis(a, ai, 20, axis=1)
        assert_equal(take_along_axis(a, ai, axis=1), 20)


@xpassIfTorchDynamo  # (reason="apply_along_axis not implemented")
class TestApplyAlongAxis(TestCase):
    def test_simple(self):
        a = np.ones((20, 10), "d")
        assert_array_equal(apply_along_axis(len, 0, a), len(a) * np.ones(a.shape[1]))

    def test_simple101(self):
        a = np.ones((10, 101), "d")
        assert_array_equal(apply_along_axis(len, 0, a), len(a) * np.ones(a.shape[1]))

    def test_3d(self):
        a = np.arange(27).reshape((3, 3, 3))
        assert_array_equal(
            apply_along_axis(np.sum, 0, a), [[27, 30, 33], [36, 39, 42], [45, 48, 51]]
        )

    def test_scalar_array(self, cls=np.ndarray):
        a = np.ones((6, 3)).view(cls)
        res = apply_along_axis(np.sum, 0, a)
        assert_(isinstance(res, cls))
        assert_array_equal(res, np.array([6, 6, 6]).view(cls))

    def test_0d_array(self, cls=np.ndarray):
        def sum_to_0d(x):
            """Sum x, returning a 0d array of the same class"""
            assert_equal(x.ndim, 1)
            return np.squeeze(np.sum(x, keepdims=True))

        a = np.ones((6, 3)).view(cls)
        res = apply_along_axis(sum_to_0d, 0, a)
        assert_(isinstance(res, cls))
        assert_array_equal(res, np.array([6, 6, 6]).view(cls))

        res = apply_along_axis(sum_to_0d, 1, a)
        assert_(isinstance(res, cls))
        assert_array_equal(res, np.array([3, 3, 3, 3, 3, 3]).view(cls))

    def test_axis_insertion(self, cls=np.ndarray):
        def f1to2(x):
            """produces an asymmetric non-square matrix from x"""
            assert_equal(x.ndim, 1)
            return (x[::-1] * x[1:, None]).view(cls)

        a2d = np.arange(6 * 3).reshape((6, 3))

        # 2d insertion along first axis
        actual = apply_along_axis(f1to2, 0, a2d)
        expected = np.stack(
            [f1to2(a2d[:, i]) for i in range(a2d.shape[1])], axis=-1
        ).view(cls)
        assert_equal(type(actual), type(expected))
        assert_equal(actual, expected)

        # 2d insertion along last axis
        actual = apply_along_axis(f1to2, 1, a2d)
        expected = np.stack(
            [f1to2(a2d[i, :]) for i in range(a2d.shape[0])], axis=0
        ).view(cls)
        assert_equal(type(actual), type(expected))
        assert_equal(actual, expected)

        # 3d insertion along middle axis
        a3d = np.arange(6 * 5 * 3).reshape((6, 5, 3))

        actual = apply_along_axis(f1to2, 1, a3d)
        expected = np.stack(
            [
                np.stack([f1to2(a3d[i, :, j]) for i in range(a3d.shape[0])], axis=0)
                for j in range(a3d.shape[2])
            ],
            axis=-1,
        ).view(cls)
        assert_equal(type(actual), type(expected))
        assert_equal(actual, expected)

    def test_axis_insertion_ma(self):
        def f1to2(x):
            """produces an asymmetric non-square matrix from x"""
            assert_equal(x.ndim, 1)
            res = x[::-1] * x[1:, None]
            return np.ma.masked_where(res % 5 == 0, res)

        a = np.arange(6 * 3).reshape((6, 3))
        res = apply_along_axis(f1to2, 0, a)
        assert_(isinstance(res, np.ma.masked_array))
        assert_equal(res.ndim, 3)
        assert_array_equal(res[:, :, 0].mask, f1to2(a[:, 0]).mask)
        assert_array_equal(res[:, :, 1].mask, f1to2(a[:, 1]).mask)
        assert_array_equal(res[:, :, 2].mask, f1to2(a[:, 2]).mask)

    def test_tuple_func1d(self):
        def sample_1d(x):
            return x[1], x[0]

        res = np.apply_along_axis(sample_1d, 1, np.array([[1, 2], [3, 4]]))
        assert_array_equal(res, np.array([[2, 1], [4, 3]]))

    def test_empty(self):
        # can't apply_along_axis when there's no chance to call the function
        def never_call(x):
            assert_(False)  # should never be reached

        a = np.empty((0, 0))
        assert_raises(ValueError, np.apply_along_axis, never_call, 0, a)
        assert_raises(ValueError, np.apply_along_axis, never_call, 1, a)

        # but it's sometimes ok with some non-zero dimensions
        def empty_to_1(x):
            assert_(len(x) == 0)
            return 1

        a = np.empty((10, 0))
        actual = np.apply_along_axis(empty_to_1, 1, a)
        assert_equal(actual, np.ones(10))
        assert_raises(ValueError, np.apply_along_axis, empty_to_1, 0, a)

    @skip  # TypeError: descriptor 'union' for 'set' objects doesn't apply to a 'numpy.int64' object
    def test_with_iterable_object(self):
        # from issue 5248
        d = np.array([[{1, 11}, {2, 22}, {3, 33}], [{4, 44}, {5, 55}, {6, 66}]])
        actual = np.apply_along_axis(lambda a: set.union(*a), 0, d)
        expected = np.array([{1, 11, 4, 44}, {2, 22, 5, 55}, {3, 33, 6, 66}])

        assert_equal(actual, expected)

        # issue 8642 - assert_equal doesn't detect this!
        for i in np.ndindex(actual.shape):
            assert_equal(type(actual[i]), type(expected[i]))


@xfail  # (reason="apply_over_axes not implemented")
class TestApplyOverAxes(TestCase):
    # NOTE(review): apply_over_axes is not imported anywhere in this module, so
    # this test fails with NameError, which is presumably the expected failure.
    def test_simple(self):
        a = np.arange(24).reshape(2, 3, 4)
        aoa_a = apply_over_axes(np.sum, a, [0, 2])
        assert_array_equal(aoa_a, np.array([[[60], [92], [124]]]))


class TestExpandDims(TestCase):
    def test_functionality(self):
        s = (2, 3, 4, 5)
        a = np.empty(s)
        for axis in range(-5, 4):
            b = expand_dims(a, axis)
            assert_(b.shape[axis] == 1)
            assert_(np.squeeze(b).shape == s)

    def test_axis_tuple(self):
        a = np.empty((3, 3, 3))
        assert np.expand_dims(a, axis=(0, 1, 2)).shape == (1, 1, 1, 3, 3, 3)
        assert np.expand_dims(a, axis=(0, -1, -2)).shape == (1, 3, 3, 3, 1, 1)
        assert np.expand_dims(a, axis=(0, 3, 5)).shape == (1, 3, 3, 1, 3, 1)
        assert np.expand_dims(a, axis=(0, -3, -5)).shape == (1, 1, 3, 1, 3, 3)

    def test_axis_out_of_range(self):
        s = (2, 3, 4, 5)
        a = np.empty(s)
        assert_raises(np.AxisError, expand_dims, a, -6)
        assert_raises(np.AxisError, expand_dims, a, 5)

        a = np.empty((3, 3, 3))
        assert_raises(np.AxisError, expand_dims, a, (0, -6))
        assert_raises(np.AxisError, expand_dims, a, (0, 5))

    def test_repeated_axis(self):
        a = np.empty((3, 3, 3))
        assert_raises(ValueError, expand_dims, a, axis=(1, 1))


class TestArraySplit(TestCase):
    def test_integer_0_split(self):
        a = np.arange(10)
        assert_raises(ValueError, array_split, a, 0)

    def test_integer_split(self):
        a = np.arange(10)
        res = array_split(a, 1)
        desired = [np.arange(10)]
        compare_results(res, desired)

        res = array_split(a, 2)
        desired = [np.arange(5), np.arange(5, 10)]
        compare_results(res, desired)

        res = array_split(a, 3)
        desired = [np.arange(4), np.arange(4, 7), np.arange(7, 10)]
        compare_results(res, desired)

        res = array_split(a, 4)
        desired = [np.arange(3), np.arange(3, 6), np.arange(6, 8), np.arange(8, 10)]
        compare_results(res, desired)

        res = array_split(a, 5)
        desired = [
            np.arange(2),
            np.arange(2, 4),
            np.arange(4, 6),
            np.arange(6, 8),
            np.arange(8, 10),
        ]
        compare_results(res, desired)

        res = array_split(a, 6)
        desired = [
            np.arange(2),
            np.arange(2, 4),
            np.arange(4, 6),
            np.arange(6, 8),
            np.arange(8, 9),
            np.arange(9, 10),
        ]
        compare_results(res, desired)

        res = array_split(a, 7)
        desired = [
            np.arange(2),
            np.arange(2, 4),
            np.arange(4, 6),
            np.arange(6, 7),
            np.arange(7, 8),
            np.arange(8, 9),
            np.arange(9, 10),
        ]
        compare_results(res, desired)

        res = array_split(a, 8)
        desired = [
            np.arange(2),
            np.arange(2, 4),
            np.arange(4, 5),
            np.arange(5, 6),
            np.arange(6, 7),
            np.arange(7, 8),
            np.arange(8, 9),
            np.arange(9, 10),
        ]
        compare_results(res, desired)

        res = array_split(a, 9)
        desired = [
            np.arange(2),
            np.arange(2, 3),
            np.arange(3, 4),
            np.arange(4, 5),
            np.arange(5, 6),
            np.arange(6, 7),
            np.arange(7, 8),
            np.arange(8, 9),
            np.arange(9, 10),
        ]
        compare_results(res, desired)

        res = array_split(a, 10)
        desired = [
            np.arange(1),
            np.arange(1, 2),
            np.arange(2, 3),
            np.arange(3, 4),
            np.arange(4, 5),
            np.arange(5, 6),
            np.arange(6, 7),
            np.arange(7, 8),
            np.arange(8, 9),
            np.arange(9, 10),
        ]
        compare_results(res, desired)

        res = array_split(a, 11)
        desired = [
            np.arange(1),
            np.arange(1, 2),
            np.arange(2, 3),
            np.arange(3, 4),
            np.arange(4, 5),
            np.arange(5, 6),
            np.arange(6, 7),
            np.arange(7, 8),
            np.arange(8, 9),
            np.arange(9, 10),
            np.array([]),
        ]
        compare_results(res, desired)

    def test_integer_split_2D_rows(self):
        a = np.array([np.arange(10), np.arange(10)])
        res = array_split(a, 3, axis=0)
        tgt = [np.array([np.arange(10)]), np.array([np.arange(10)]), np.zeros((0, 10))]
        compare_results(res, tgt)
        assert_(a.dtype.type is res[-1].dtype.type)

        # Same thing for manual splits:
        res = array_split(a, [0, 1], axis=0)
        tgt = [np.zeros((0, 10)), np.array([np.arange(10)]), np.array([np.arange(10)])]
        compare_results(res, tgt)
        assert_(a.dtype.type is res[-1].dtype.type)

    def test_integer_split_2D_cols(self):
        a = np.array([np.arange(10), np.arange(10)])
        res = array_split(a, 3, axis=-1)
        desired = [
            np.array([np.arange(4), np.arange(4)]),
            np.array([np.arange(4, 7), np.arange(4, 7)]),
            np.array([np.arange(7, 10), np.arange(7, 10)]),
        ]
        compare_results(res, desired)

    def test_integer_split_2D_default(self):
        """This will fail if we change default axis"""
        a = np.array([np.arange(10), np.arange(10)])
        res = array_split(a, 3)
        tgt = [np.array([np.arange(10)]), np.array([np.arange(10)]), np.zeros((0, 10))]
        compare_results(res, tgt)
        assert_(a.dtype.type is res[-1].dtype.type)
        # perhaps should check higher dimensions

    @skipif(not IS_64BIT, reason="Needs 64bit platform")
    def test_integer_split_2D_rows_greater_max_int32(self):
        # broadcast_to keeps the 2**32-row input cheap to build.
        a = np.broadcast_to([0], (1 << 32, 2))
        res = array_split(a, 4)
        chunk = np.broadcast_to([0], (1 << 30, 2))
        tgt = [chunk] * 4
        for i in range(len(tgt)):
            assert_equal(res[i].shape, tgt[i].shape)

    def test_index_split_simple(self):
        a = np.arange(10)
        indices = [1, 5, 7]
        res = array_split(a, indices, axis=-1)
        desired = [np.arange(0, 1), np.arange(1, 5), np.arange(5, 7), np.arange(7, 10)]
        compare_results(res, desired)

    def test_index_split_low_bound(self):
        a = np.arange(10)
        indices = [0, 5, 7]
        res = array_split(a, indices, axis=-1)
        desired = [np.array([]), np.arange(0, 5), np.arange(5, 7), np.arange(7, 10)]
        compare_results(res, desired)

    def test_index_split_high_bound(self):
        a = np.arange(10)
        indices = [0, 5, 7, 10, 12]
        res = array_split(a, indices, axis=-1)
        desired = [
            np.array([]),
            np.arange(0, 5),
            np.arange(5, 7),
            np.arange(7, 10),
            np.array([]),
            np.array([]),
        ]
        compare_results(res, desired)


class TestSplit(TestCase):
    # The split function is essentially the same as array_split,
    # except that it tests whether splitting will result in an
    # equal split. Only test for this case.

    def test_equal_split(self):
        a = np.arange(10)
        res = split(a, 2)
        desired = [np.arange(5), np.arange(5, 10)]
        compare_results(res, desired)

    def test_unequal_split(self):
        a = np.arange(10)
        assert_raises(ValueError, split, a, 3)


class TestColumnStack(TestCase):
    def test_non_iterable(self):
        assert_raises(TypeError, column_stack, 1)

    def test_1D_arrays(self):
        # example from docstring
        a = np.array((1, 2, 3))
        b = np.array((2, 3, 4))
        expected = np.array([[1, 2], [2, 3], [3, 4]])
        actual = np.column_stack((a, b))
        assert_equal(actual, expected)

    def test_2D_arrays(self):
        # same as hstack 2D docstring example
        a = np.array([[1], [2], [3]])
        b = np.array([[2], [3], [4]])
        expected = np.array([[1, 2], [2, 3], [3, 4]])
        actual = np.column_stack((a, b))
        assert_equal(actual, expected)

    def test_generator(self):
        # numpy 1.24 emits a warning but we don't
        # with assert_warns(FutureWarning):
        column_stack([np.arange(3) for _ in range(2)])


class TestDstack(TestCase):
    def test_non_iterable(self):
        assert_raises(TypeError, dstack, 1)

    def test_0D_array(self):
        a = np.array(1)
        b = np.array(2)
        res = dstack([a, b])
        desired = np.array([[[1, 2]]])
        assert_array_equal(res, desired)

    def test_1D_array(self):
        a = np.array([1])
        b = np.array([2])
        res = dstack([a, b])
        desired = np.array([[[1, 2]]])
        assert_array_equal(res, desired)

    def test_2D_array(self):
        a = np.array([[1], [2]])
        b = np.array([[1], [2]])
        res = dstack([a, b])
        desired = np.array([[[1, 1]], [[2, 2]]])
        assert_array_equal(res, desired)

    def test_2D_array2(self):
        a = np.array([1, 2])
        b = np.array([1, 2])
        res = dstack([a, b])
        desired = np.array([[[1, 1], [2, 2]]])
        assert_array_equal(res, desired)

    def test_generator(self):
        # numpy 1.24 emits a warning but we don't
        # with assert_warns(FutureWarning):
        dstack([np.arange(3) for _ in range(2)])


# array_split has more comprehensive test of splitting.
# only do simple test on hsplit, vsplit, and dsplit
class TestHsplit(TestCase):
    """Only testing for integer splits."""

    def test_non_iterable(self):
        assert_raises(ValueError, hsplit, 1, 1)

    def test_0D_array(self):
        a = np.array(1)
        assert_raises(ValueError, hsplit, a, 2)

    def test_1D_array(self):
        a = np.array([1, 2, 3, 4])
        res = hsplit(a, 2)
        desired = [np.array([1, 2]), np.array([3, 4])]
        compare_results(res, desired)

    def test_2D_array(self):
        a = np.array([[1, 2, 3, 4], [1, 2, 3, 4]])
        res = hsplit(a, 2)
        desired = [np.array([[1, 2], [1, 2]]), np.array([[3, 4], [3, 4]])]
        compare_results(res, desired)


class TestVsplit(TestCase):
    """Only testing for integer splits."""

    def test_non_iterable(self):
        assert_raises(ValueError, vsplit, 1, 1)

    def test_0D_array(self):
        a = np.array(1)
        assert_raises(ValueError, vsplit, a, 2)

    def test_1D_array(self):
        a = np.array([1, 2, 3, 4])
        assert_raises(ValueError, vsplit, a, 2)

    def test_2D_array(self):
        a = np.array([[1, 2, 3, 4], [1, 2, 3, 4]])
        res = vsplit(a, 2)
        desired = [np.array([[1, 2, 3, 4]]), np.array([[1, 2, 3, 4]])]
        compare_results(res, desired)


class TestDsplit(TestCase):
    # Only testing for integer splits.
    def test_non_iterable(self):
        assert_raises(ValueError, dsplit, 1, 1)

    def test_0D_array(self):
        a = np.array(1)
        assert_raises(ValueError, dsplit, a, 2)

    def test_1D_array(self):
        a = np.array([1, 2, 3, 4])
        assert_raises(ValueError, dsplit, a, 2)

    def test_2D_array(self):
        a = np.array([[1, 2, 3, 4], [1, 2, 3, 4]])
        assert_raises(ValueError, dsplit, a, 2)

    def test_3D_array(self):
        a = np.array([[[1, 2, 3, 4], [1, 2, 3, 4]], [[1, 2, 3, 4], [1, 2, 3, 4]]])
        res = dsplit(a, 2)
        desired = [
            np.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]]),
            np.array([[[3, 4], [3, 4]], [[3, 4], [3, 4]]]),
        ]
        compare_results(res, desired)


class TestSqueeze(TestCase):
    def test_basic(self):
        a = rand(20, 10, 10, 1, 1)
        b = rand(20, 1, 10, 1, 20)
        c = rand(1, 1, 20, 10)
        assert_array_equal(np.squeeze(a), np.reshape(a, (20, 10, 10)))
        assert_array_equal(np.squeeze(b), np.reshape(b, (20, 10, 20)))
        assert_array_equal(np.squeeze(c), np.reshape(c, (20, 10)))

        # Squeezing to 0-dim should still give an ndarray
        a = [[[1.5]]]
        res = np.squeeze(a)
        assert_equal(res, 1.5)
        assert_equal(res.ndim, 0)
        assert type(res) is np.ndarray

    @xfailIfTorchDynamo
    def test_basic_2(self):
        # torch._numpy-specific: the squeezed result should be a view whose
        # underlying tensor's _base is the original tensor.
        aa = np.ones((3, 1, 4, 1, 1))
        assert aa.squeeze().tensor._base is aa.tensor

    def test_squeeze_axis(self):
        A = [[[1, 1, 1], [2, 2, 2], [3, 3, 3]]]
        assert_equal(np.squeeze(A).shape, (3, 3))
        assert_equal(np.squeeze(A, axis=()), A)

        assert_equal(np.squeeze(np.zeros((1, 3, 1))).shape, (3,))
        assert_equal(np.squeeze(np.zeros((1, 3, 1)), axis=0).shape, (3, 1))
        assert_equal(np.squeeze(np.zeros((1, 3, 1)), axis=-1).shape, (1, 3))
        assert_equal(np.squeeze(np.zeros((1, 3, 1)), axis=2).shape, (1, 3))
        assert_equal(np.squeeze([np.zeros((3, 1))]).shape, (3,))
        assert_equal(np.squeeze([np.zeros((3, 1))], axis=0).shape, (3, 1))
        assert_equal(np.squeeze([np.zeros((3, 1))], axis=2).shape, (1, 3))
        assert_equal(np.squeeze([np.zeros((3, 1))], axis=-1).shape, (1, 3))

    def test_squeeze_type(self):
        # Ticket #133
        a = np.array([3])
        b = np.array(3)
        assert type(a.squeeze()) is np.ndarray
        assert type(b.squeeze()) is np.ndarray

    @skip(reason="XXX: order='F' not implemented")
    def test_squeeze_contiguous(self):
        # Similar to GitHub issue #387
        a = np.zeros((1, 2)).squeeze()
        b = np.zeros((2, 2, 2), order="F")[:, :, ::2].squeeze()
        assert_(a.flags.c_contiguous)
        assert_(a.flags.f_contiguous)
        assert_(b.flags.f_contiguous)

    @xpassIfTorchDynamo  # (reason="XXX: noop in torch, while numpy raises")
    def test_squeeze_axis_handling(self):
        with assert_raises(ValueError):
            np.squeeze(np.array([[1], [2], [3]]), axis=0)


@instantiate_parametrized_tests
class TestKron(TestCase):
    def test_basic(self):
        # Using 0-dimensional ndarray
        a = np.array(1)
        b = np.array([[1, 2], [3, 4]])
        k = np.array([[1, 2], [3, 4]])
        assert_array_equal(np.kron(a, b), k)
        a = np.array([[1, 2], [3, 4]])
        b = np.array(1)
        assert_array_equal(np.kron(a, b), k)

        # Using 1-dimensional ndarray
        a = np.array([3])
        b = np.array([[1, 2], [3, 4]])
        k = np.array([[3, 6], [9, 12]])
        assert_array_equal(np.kron(a, b), k)
        a = np.array([[1, 2], [3, 4]])
        b = np.array([3])
        assert_array_equal(np.kron(a, b), k)

        # Using 3-dimensional ndarray
        a = np.array([[[1]], [[2]]])
        b = np.array([[1, 2], [3, 4]])
        k = np.array([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
        assert_array_equal(np.kron(a, b), k)
        a = np.array([[1, 2], [3, 4]])
        b = np.array([[[1]], [[2]]])
        k = np.array([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
        assert_array_equal(np.kron(a, b), k)

    @skip(reason="NP_VER: fails on CI")
    @parametrize(
        "shape_a,shape_b",
        [
            ((1, 1), (1, 1)),
            ((1, 2, 3), (4, 5, 6)),
            ((2, 2), (2, 2, 2)),
            ((1, 0), (1, 1)),
            ((2, 0, 2), (2, 2)),
            ((2, 0, 0, 2), (2, 0, 2)),
        ],
    )
    def test_kron_shape(self, shape_a, shape_b):
        a = np.ones(shape_a)
        b = np.ones(shape_b)
        # Left-pad the shorter shape with 1s so both have the same rank; the
        # expected kron shape is then the elementwise product of the two.
        normalised_shape_a = (1,) * max(0, len(shape_b) - len(shape_a)) + shape_a
        normalised_shape_b = (1,) * max(0, len(shape_a) - len(shape_b)) + shape_b
        expected_shape = np.multiply(normalised_shape_a, normalised_shape_b)

        k = np.kron(a, b)
        assert np.array_equal(k.shape, expected_shape), "Unexpected shape from kron"


class TestTile(TestCase):
    def test_basic(self):
        a = np.array([0, 1, 2])
        b = [[1, 2], [3, 4]]
        assert_equal(tile(a, 2), [0, 1, 2, 0, 1, 2])
        assert_equal(tile(a, (2, 2)), [[0, 1, 2, 0, 1, 2], [0, 1, 2, 0, 1, 2]])
        assert_equal(tile(a, (1, 2)), [[0, 1, 2, 0, 1, 2]])
        assert_equal(tile(b, 2), [[1, 2, 1, 2], [3, 4, 3, 4]])
        assert_equal(tile(b, (2, 1)), [[1, 2], [3, 4], [1, 2], [3, 4]])
        assert_equal(
            tile(b, (2, 2)), [[1, 2, 1, 2], [3, 4, 3, 4], [1, 2, 1, 2], [3, 4, 3, 4]]
        )

    def test_tile_one_repetition_on_array_gh4679(self):
        # tile(a, 1) must return a copy: mutating it must not touch `a`.
        a = np.arange(5)
        b = tile(a, 1)
        b += 2
        assert_equal(a, np.arange(5))

    def test_empty(self):
        a = np.array([[[]]])
        b = np.array([[], []])
        c = tile(b, 2).shape
        d = tile(a, (3, 2, 5)).shape
        assert_equal(c, (2, 0))
        assert_equal(d, (3, 2, 0))

    def test_kroncompare(self):
        # tile(b, r) must agree with kron(ones(r), b) for mixed ranks.
        reps = [(2,), (1, 2), (2, 1), (2, 2), (2, 3, 2), (3, 2)]
        shape = [(3,), (2, 3), (3, 4, 3), (3, 2, 3), (4, 3, 2, 4), (2, 2)]
        for s in shape:
            b = randint(0, 10, size=s)
            for r in reps:
                a = np.ones(r, b.dtype)
                large = tile(b, r)
                klarge = kron(a, b)
                assert_equal(large, klarge)


@xfail  # Maybe implement one day
class TestMayShareMemory(TestCase):
    def test_basic(self):
        d = np.ones((50, 60))
        d2 = np.ones((30, 60, 6))
        assert_(np.may_share_memory(d, d))
        assert_(np.may_share_memory(d, d[::-1]))
        assert_(np.may_share_memory(d, d[::2]))
        assert_(np.may_share_memory(d, d[1:, ::-1]))

        assert_(not np.may_share_memory(d[::-1], d2))
        assert_(not np.may_share_memory(d[::2], d2))
        assert_(not np.may_share_memory(d[1:, ::-1], d2))
        assert_(np.may_share_memory(d2[1:, ::-1], d2))


# Utility
def compare_results(res, desired):
    """Compare two sequences of arrays element by element.

    Raises:
        ValueError: if ``res`` and ``desired`` have different lengths.
        AssertionError: (from ``assert_array_equal``) on the first mismatch.
    """
    if len(res) != len(desired):
        raise ValueError("Iterables have different lengths")
    # Lengths were checked above. On Python 3.10+, zip(..., strict=True)
    # (PEP 618) would enforce the same thing.
    for x, y in zip(res, desired):
        assert_array_equal(x, y)


if __name__ == "__main__":
    run_tests()