1# Owner(s): ["module: intel"] 2 3import itertools 4import math 5import random 6from functools import partial 7from itertools import product 8 9import numpy as np 10 11import torch 12from torch.testing import make_tensor 13from torch.testing._internal.common_device_type import ( 14 dtypes, 15 instantiate_device_type_tests, 16 precisionOverride, 17) 18from torch.testing._internal.common_utils import iter_indices, run_tests, TestCase 19 20 21class TestBasicGEMM(TestCase): 22 def _test_addmm_addmv( 23 self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None 24 ): 25 dtype = t.dtype 26 numpy_dtype = dtype 27 if dtype in {torch.bfloat16, torch.half}: 28 numpy_dtype = torch.float 29 if dtype.is_complex: 30 alpha = 0.9 + 0.3j if alpha is None else alpha 31 beta = 0.5 + 0.6j if beta is None else beta 32 else: 33 alpha = 1.2 if alpha is None else alpha 34 beta = 0.8 if beta is None else beta 35 if activation == "gelu": 36 res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True) 37 else: 38 res1 = f(t, m, v, alpha=alpha, beta=beta) 39 res2 = torch.full_like(res1, math.nan) 40 if transpose_out: 41 res2 = res2.t().clone(memory_format=torch.contiguous_format).t() 42 if activation == "gelu": 43 f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True) 44 else: 45 f(t, m, v, alpha=alpha, beta=beta, out=res2) 46 m.to(numpy_dtype).cpu().numpy() 47 v.to(numpy_dtype).cpu().numpy() 48 res3 = alpha * ( 49 m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy() 50 ) 51 if beta != 0: 52 res3 += (beta * t).to(numpy_dtype).cpu().numpy() 53 if activation == "relu": 54 res3 = res3 * (res3 > 0) 55 elif activation == "gelu": 56 res3_t = torch.from_numpy(res3).to(dtype) 57 approximate = "tanh" if t.is_cuda else "none" 58 res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate) 59 res3 = res3_t.to(numpy_dtype).cpu().numpy() 60 else: 61 assert activation is None, f"unsupported activation {activation}" 62 res3 = torch.from_numpy(res3).to(dtype) 63 self.assertEqual(res1, res2) 64 self.assertEqual(res1, res3) 65 66 def _test_addmm_impl(self, func, activation, device, dtype): 67 M = torch.randn(10, 25, device="cpu", dtype=torch.float32).to(dtype).to(device) 68 m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device) 69 m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device) 70 self._test_addmm_addmv(func, M, m1, m2, activation=activation) 71 72 # vector-shaped bias and beta=1 result in epilogue fusion in CUDA 73 V = torch.randn(25, device="cpu", dtype=torch.float32).to(dtype).to(device) 74 self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation) 75 76 # Test 0-strided 77 M = ( 78 torch.randn(10, 1, device="cpu", dtype=torch.float32) 79 .to(dtype) 80 .expand(10, 25) 81 .to(device) 82 ) 83 m1 = ( 84 torch.randn(10, 1, device="cpu", dtype=torch.float32) 85 .to(dtype) 86 .expand(10, 50) 87 .to(device) 88 ) 89 m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device) 90 self._test_addmm_addmv(func, M, m1, m2, activation=activation) 91 92 # Test beta=0, M=nan 93 M = ( 94 torch.full((10, 25), math.nan, device="cpu", dtype=torch.float32) 95 .to(dtype) 96 .to(device) 97 ) 98 m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device) 99 m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device) 100 self._test_addmm_addmv(func, M, m1, m2, beta=0, activation=activation) 101 102 # Test transpose 103 for t1, t2, t3, t4 in itertools.product([True, False], repeat=4): 104 105 def maybe_transpose(cond, m): 106 if not cond: 107 return m 108 return m.t().clone(memory_format=torch.contiguous_format).t() 109 110 M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype)) 111 m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype)) 112 m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype)) 113 self._test_addmm_addmv( 114 func, M, m1, m2, transpose_out=t4, activation=activation 115 ) 116 117 if t1: 118 # use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1) 119 self._test_addmm_addmv( 120 func, 121 V, 122 m1, 123 m2, 124 beta=1, 125 transpose_out=t4, 126 activation=activation, 127 ) 128 129 @precisionOverride( 130 { 131 torch.float: 1e-4, 132 torch.half: 1e-1, 133 } 134 ) 135 @dtypes(torch.float32, torch.half) 136 def test_addmm(self, device, dtype): 137 self._test_addmm_impl(torch.addmm, None, device, dtype) 138 139 @precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4}) 140 @dtypes(torch.bfloat16, torch.half, torch.float) 141 def test_addmv(self, device, dtype): 142 # have to use torch.randn(...).to(bfloat16) instead of 143 # torch.randn(..., dtype=bfloat16). randn does not support 144 # bfloat16 yet. 145 # "*0.2" to reduce errors for low precision 146 ts = [ 147 0.2 * torch.randn(50, device=device).to(dtype), 148 0.2 * torch.randn(1, device=device).to(dtype).expand(50), 149 ] 150 vs = [ 151 0.2 * torch.randn(100, device=device).to(dtype), 152 0.2 153 * torch.ones(1, device=device) 154 .to(dtype) 155 .expand(100), # to reduce errors for low precision 156 ] 157 ms = [ 158 # 0d 159 0.2 160 * torch.ones((), device=device) 161 .to(dtype) 162 .expand(50, 100), # to reduce errors for low precision 163 # 1d 164 0.2 * torch.randn((1, 100), device=device).to(dtype).expand(50, 100), 165 # this initialization reduces errors for low precision for broadcasted matrices 166 # by making sure that intermediate and result values are exactly representable 167 # in low precision type 168 0.2 169 * torch.randint(3, (50, 1), dtype=torch.float, device=device) 170 .to(dtype) 171 .expand(50, 100), 172 # 2d 173 0.2 * torch.randn((50, 100), device=device).to(dtype), 174 0.2 * torch.randn((100, 50), device=device).to(dtype).t(), 175 ] 176 for m, v, t in itertools.product(ms, vs, ts): 177 self._test_addmm_addmv(torch.addmv, t, m, v) 178 # Test beta=0, t=nan 179 t = torch.full((50,), math.nan, device=device).to(dtype) 180 for m, v in itertools.product(ms, vs): 181 self._test_addmm_addmv(torch.addmv, t, m, v, beta=0) 182 183 @dtypes( 184 torch.half, 185 torch.float32, 186 ) 187 def test_mm(self, device, dtype): 188 def _test_mm(n, m, p, dtype, genf): 189 # helper function 190 def matrixmultiply(mat1, mat2): 191 n = mat1.size(0) 192 m = mat1.size(1) 193 p = mat2.size(1) 194 dtype_ = torch.float if dtype == torch.half else dtype 195 if dtype == torch.half: 196 mat1 = mat1.float() 197 mat2 = mat2.float() 198 res = torch.zeros(n, p, dtype=dtype_, device=device) 199 for i, j in iter_indices(res): 200 res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m)) 201 return res.half() if dtype == torch.half else res 202 203 # contiguous case 204 mat1 = genf(n, m) 205 mat2 = genf(m, p) 206 res = torch.mm(mat1, mat2) 207 208 res2 = matrixmultiply(mat1, mat2) 209 self.assertEqual(res, res2) 210 211 # non contiguous case 1 212 mat1 = genf(n, m) 213 mat2 = genf(p, m).t() 214 res = torch.mm(mat1, mat2) 215 216 res2 = matrixmultiply(mat1, mat2) 217 self.assertEqual(res, res2) 218 219 # non contiguous case 2 220 mat1 = genf(m, n).t() 221 mat2 = genf(m, p) 222 res = torch.mm(mat1, mat2) 223 224 res2 = matrixmultiply(mat1, mat2) 225 self.assertEqual(res, res2) 226 227 # non contiguous case 3 228 mat1 = genf(m, n).t() 229 mat2 = genf(p, m).t() 230 res = torch.mm(mat1, mat2) 231 232 res2 = matrixmultiply(mat1, mat2) 233 self.assertEqual(res, res2) 234 235 # test with zero stride 236 mat1 = genf(n, m) 237 mat2 = genf(m, 1).expand(m, p) 238 res = torch.mm(mat1, mat2) 239 240 res2 = matrixmultiply(mat1, mat2) 241 self.assertEqual(res, res2) 242 243 # explicitly exercise the _out variant in torch.mm(). 244 # contiguous case 245 mat1 = genf(n, m) 246 mat2 = genf(m, p) 247 res = genf(n, p) 248 torch.mm(mat1, mat2, out=res) 249 250 res2 = matrixmultiply(mat1, mat2) 251 self.assertEqual(res, res2) 252 253 # explicitly exercise the _out variant in torch.mm(). 254 # non contiguous case 3 255 mat1 = genf(m, n).t() 256 mat2 = genf(p, m).t() 257 res = genf(n, p) 258 torch.mm(mat1, mat2, out=res) 259 260 res2 = matrixmultiply(mat1, mat2) 261 self.assertEqual(res, res2) 262 263 def genf_int(x, y): 264 return torch.randint(0, 100, (x, y), dtype=dtype, device=device) 265 266 def genf_bfloat(x, y): 267 return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) * 0.1 268 269 def genf_float(x, y): 270 return torch.randn(x, y, dtype=dtype, device=device) 271 272 def genf_Half(x, y): 273 return torch.randn(x, y, dtype=dtype, device=device) 274 275 for n, m, p in [(20, 10, 15), (15, 20, 10), (25, 18, 10)]: 276 if (dtype == torch.int32) or (dtype == torch.int64): 277 genf = genf_int 278 elif dtype == torch.bfloat16: 279 genf = genf_bfloat 280 elif dtype == torch.half: 281 genf = genf_Half 282 else: 283 genf = genf_float 284 285 _test_mm(n, m, p, dtype, genf) 286 287 @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05}) 288 @dtypes(torch.float32, torch.bfloat16, torch.half) 289 def test_bmm(self, device, dtype): 290 batch_sizes = [1, 10] 291 M, N, O = 23, 15, 12 292 numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32 293 294 def invert_perm(p): 295 d = {x: i for i, x in enumerate(p)} 296 return (d[0], d[1], d[2]) 297 298 def generate_inputs(num_batches): 299 # transposed tensors 300 for perm1, perm2 in itertools.product( 301 itertools.permutations((0, 1, 2)), repeat=2 302 ): 303 b1 = make_tensor( 304 (num_batches, M, N), dtype=dtype, device=device, low=-0.1, high=0.1 305 ) 306 b2 = make_tensor( 307 (num_batches, N, O), dtype=dtype, device=device, low=-0.1, high=0.1 308 ) 309 b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) 310 b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) 311 yield b1, b2 312 # broadcasting tensors 313 for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6): 314 shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1) 315 shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1) 316 b1 = make_tensor( 317 shape1, dtype=dtype, device=device, low=-0.1, high=0.1 318 ).expand(num_batches, M, N) 319 b2 = make_tensor( 320 shape2, dtype=dtype, device=device, low=-0.1, high=0.1 321 ).expand(num_batches, N, O) 322 yield b1, b2 323 # zero-sized tensors 324 for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): 325 shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) 326 shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) 327 b1 = torch.randn(shape1, dtype=dtype, device=device) 328 b2 = torch.randn(shape2, dtype=dtype, device=device) 329 yield b1, b2 330 331 for num_batches in batch_sizes: 332 for (b1, b2), perm3 in itertools.product( 333 generate_inputs(num_batches), itertools.permutations((0, 1, 2)) 334 ): 335 res1 = torch.bmm(b1, b2) 336 res2 = ( 337 torch.full( 338 (num_batches, M, O), math.nan, dtype=dtype, device=device 339 ) 340 .permute(perm3) 341 .contiguous() 342 .permute(invert_perm(perm3)) 343 ) 344 torch.bmm(b1, b2, out=res2) 345 expect = torch.from_numpy( 346 b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() 347 ).to(device=device, dtype=dtype) 348 self.assertEqual(expect, res1) 349 self.assertEqual(expect, res2) 350 351 if self.device_type == "cuda": 352 # check that mixed arguments are rejected 353 self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu())) 354 self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2)) 355 self.assertRaises( 356 RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu()) 357 ) 358 359 def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor): 360 getattr(out_tensor, func + "_")(b1, b2) 361 self.assertEqual(out_tensor, ref) 362 res3 = out_tensor.clone() 363 364 with self.assertWarnsOnceRegex( 365 UserWarning, f"This overload of {func}_ is deprecated" 366 ): 367 getattr(out_tensor, func + "_")(1, b1, b2) 368 self.assertEqual(out_tensor, ref * 2), 369 getattr(res3, func + "_")(b1, b2, beta=1) 370 self.assertEqual(out_tensor, res3) 371 372 with self.assertWarnsOnceRegex( 373 UserWarning, f"This overload of {func}_ is deprecated" 374 ): 375 getattr(out_tensor, func + "_")(1.0, 0.5, b1, b2) 376 self.assertEqual(out_tensor, ref * 2.5) 377 getattr(res3, func + "_")(b1, b2, beta=1.0, alpha=0.5) 378 self.assertEqual(out_tensor, res3) 379 380 with self.assertWarnsOnceRegex( 381 UserWarning, f"This overload of {func} is deprecated" 382 ): 383 self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2)) 384 385 res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=0.5) 386 self.assertEqual(res4, ref * 3), 387 388 nan = torch.full_like(out_tensor, math.nan) 389 res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1) 390 self.assertEqual(res5, ref) 391 392 if b1.is_complex(): 393 res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1j, alpha=0.5j) 394 self.assertEqual(res6, out_tensor * 0.1j + 0.5j * ref) 395 else: 396 res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1, alpha=0.5) 397 self.assertEqual(res6, out_tensor * 0.1 + 0.5 * ref) 398 399 res7 = torch.full_like(out_tensor, math.nan) 400 getattr(torch, func)(nan, b1, b2, beta=0, out=res7) 401 self.assertEqual(res7, ref) 402 403 @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05}) 404 @dtypes(torch.float32, torch.bfloat16, torch.half) 405 def test_addbmm(self, device, dtype): 406 num_batches = 2 407 M, N, O = 16, 17, 18 408 409 is_supported = True 410 411 if not is_supported: 412 b1 = make_tensor( 413 (num_batches, M, N), dtype=dtype, device=device, low=-1, high=1 414 ) 415 b2 = make_tensor( 416 (num_batches, N, O), dtype=dtype, device=device, low=-1, high=1 417 ) 418 t = make_tensor((M, O), dtype=dtype, device=device, low=-1, high=1) 419 self.assertRaisesRegex( 420 RuntimeError, 421 "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED", 422 lambda: torch.addbmm(t, b1, b2), 423 ) 424 return 425 426 def invert_perm(p): 427 d = {x: i for i, x in enumerate(p)} 428 return (d[0], d[1], d[2]) 429 430 def generate_tensor(): 431 numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32 432 # transposed tensors 433 for perm1, perm2 in itertools.product( 434 itertools.permutations((0, 1, 2)), repeat=2 435 ): 436 for perm3 in itertools.permutations((0, 1)): 437 b1 = ( 438 make_tensor( 439 (num_batches, M, N), 440 dtype=dtype, 441 device=device, 442 low=-1, 443 high=1, 444 ) 445 * 0.1 446 ) 447 b2 = ( 448 make_tensor( 449 (num_batches, N, O), 450 dtype=dtype, 451 device=device, 452 low=-1, 453 high=1, 454 ) 455 * 0.1 456 ) 457 b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) 458 b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) 459 ref = ( 460 torch.from_numpy( 461 b1.to(numpy_dtype).cpu().numpy() 462 @ b2.to(numpy_dtype).cpu().numpy() 463 ) 464 .to(device=device, dtype=dtype) 465 .sum(0) 466 ) 467 out_tensor = ( 468 torch.zeros_like(ref).permute(perm3).contiguous().permute(perm3) 469 ) 470 yield b1, b2, ref, out_tensor 471 # broadcasting tensors 472 for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6): 473 shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1) 474 shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1) 475 b1 = ( 476 make_tensor( 477 shape1, dtype=dtype, device=device, low=-1, high=1 478 ).expand(num_batches, M, N) 479 * 0.1 480 ) 481 b2 = ( 482 make_tensor( 483 shape2, dtype=dtype, device=device, low=-1, high=1 484 ).expand(num_batches, N, O) 485 * 0.1 486 ) 487 ref = ( 488 torch.from_numpy( 489 b1.to(numpy_dtype).cpu().numpy() 490 @ b2.to(numpy_dtype).cpu().numpy() 491 ) 492 .to(device=device, dtype=dtype) 493 .sum(0) 494 ) 495 out_tensor = torch.zeros_like(ref) 496 yield b1, b2, ref, out_tensor 497 # zero-sized tensors 498 for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): 499 shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) 500 shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) 501 b1 = ( 502 make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1) 503 * 0.1 504 ) 505 b2 = ( 506 make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1) 507 * 0.1 508 ) 509 ref = ( 510 torch.from_numpy( 511 b1.to(numpy_dtype).cpu().numpy() 512 @ b2.to(numpy_dtype).cpu().numpy() 513 ) 514 .to(device=device, dtype=dtype) 515 .sum(0) 516 ) 517 out_tensor = torch.zeros_like(ref) 518 yield b1, b2, ref, out_tensor 519 520 for b1, b2, ref, out_tensor in generate_tensor(): 521 self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor) 522 523 @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5}) 524 @dtypes(torch.float32, torch.bfloat16, torch.half) 525 def test_baddbmm(self, device, dtype): 526 num_batches = 10 527 M, N, O = 12, 8, 50 528 529 def invert_perm(p): 530 d = {x: i for i, x in enumerate(p)} 531 return (d[0], d[1], d[2]) 532 533 def generate_tensor(): 534 numpy_dtype = ( 535 dtype if dtype not in [torch.bfloat16, torch.half] else torch.float32 536 ) 537 # transposed tensors 538 for perm1, perm2, perm3 in itertools.product( 539 itertools.permutations((0, 1, 2)), repeat=3 540 ): 541 b1 = make_tensor( 542 (num_batches, M, N), dtype=dtype, device=device, low=-1, high=1 543 ) 544 b2 = make_tensor( 545 (num_batches, N, O), dtype=dtype, device=device, low=-1, high=1 546 ) 547 b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) 548 b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) 549 ref = torch.from_numpy( 550 b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() 551 ).to(device=device, dtype=dtype) 552 out_tensor = torch.zeros_like(ref) 553 out_tensor = ( 554 out_tensor.permute(perm3).contiguous().permute(invert_perm(perm3)) 555 ) 556 yield b1, b2, ref, out_tensor 557 # broadcasting tensors 558 for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6): 559 shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1) 560 shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1) 561 b1 = make_tensor( 562 shape1, dtype=dtype, device=device, low=-1, high=1 563 ).expand(num_batches, M, N) 564 b2 = make_tensor( 565 shape2, dtype=dtype, device=device, low=-1, high=1 566 ).expand(num_batches, N, O) 567 ref = torch.from_numpy( 568 b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() 569 ).to(device=device, dtype=dtype) 570 out_tensor = torch.zeros_like(ref) 571 yield b1, b2, ref, out_tensor 572 # zero-sized tensors 573 for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): 574 shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) 575 shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) 576 b1 = make_tensor(shape1, dtype=dtype, device=device, low=-2, high=2) 577 b2 = make_tensor(shape2, dtype=dtype, device=device, low=-2, high=2) 578 ref = torch.from_numpy( 579 b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() 580 ).to(device=device, dtype=dtype) 581 out_tensor = torch.zeros_like(ref) 582 yield b1, b2, ref, out_tensor 583 584 for b1, b2, ref, out_tensor in generate_tensor(): 585 self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor) 586 587 def test_tensordot(self, device): 588 a = torch.arange(60.0, device=device).reshape(3, 4, 5) 589 b = torch.arange(24.0, device=device).reshape(4, 3, 2) 590 c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu() 591 cn = torch.from_numpy( 592 np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=([1, 0], [0, 1])) 593 ) 594 self.assertEqual(c, cn) 595 596 cout = torch.zeros((5, 2), device=device) 597 torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout).cpu() 598 self.assertEqual(c, cout) 599 600 a = torch.randn(2, 3, 4, 5, device=device) 601 b = torch.randn(4, 5, 6, 7, device=device) 602 c = torch.tensordot(a, b, dims=2).cpu() 603 cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=2)) 604 605 with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"): 606 torch.tensordot(a, b, dims=-1) 607 608 self.assertEqual(c, cn) 609 c = torch.tensordot(a, b).cpu() 610 cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy())) 611 self.assertEqual(c, cn) 612 613 a = torch.tensordot(torch.tensor(0.0), torch.tensor(0.0), 0) 614 an = torch.from_numpy( 615 np.tensordot( 616 np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0 617 ) 618 ) 619 self.assertEqual(a, an) 620 621 @dtypes(torch.float) 622 @precisionOverride({torch.float32: 1e-4}) 623 def test_1_sized_with_0_strided(self, device, dtype): 624 a = make_tensor((8, 1, 64), dtype=dtype, device=device) 625 a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1]) 626 b = make_tensor((8, 64, 512), dtype=dtype, device=device) 627 b_strided = torch.as_strided(b, size=[8, 64, 512], stride=[64, 1, 512]) 628 res = torch.bmm(a_strided, b_strided) 629 expect = torch.from_numpy(a_strided.cpu().numpy() @ b_strided.cpu().numpy()).to( 630 device=device, dtype=dtype 631 ) 632 self.assertEqual(expect, res) 633 634 def _select_broadcastable_dims(self, dims_full=None): 635 # select full dimensionality 636 if dims_full is None: 637 dims_full = [] 638 ndims = random.randint(1, 4) 639 dims_full = [random.randint(1, 8) for _ in range(ndims)] 640 else: 641 ndims = len(dims_full) 642 643 # select actual dimensions for ops: 644 # larger: full ndims, individual sizes may be reduced 645 # smaller: possibly reduced ndims, sizes may be reduced 646 smaller_ndims = random.randint(1, ndims) 647 dims_small = [] 648 dims_large = [] 649 for i in range(ndims - 1, -1, -1): 650 j = random.randint(1, 3) 651 if j == 1: # no reduced singleton dimension 652 ds = dims_full[i] 653 dl = dims_full[i] 654 elif j == 2: # larger may have reduced singleton dimension 655 ds = dims_full[i] 656 dl = 1 if len(dims_small) < smaller_ndims else dims_full[i] 657 elif j == 3: # smaller may have reduced singleton dimension 658 ds = 1 659 dl = dims_full[i] 660 dims_large = [dl] + dims_large 661 if len(dims_small) < smaller_ndims: 662 dims_small = [ds] + dims_small 663 return (dims_small, dims_large, dims_full) 664 665 def test_broadcast_fused_matmul(self, device): 666 fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"] 667 668 for fn in fns: 669 batch_dim = random.randint(1, 8) 670 n_dim = random.randint(1, 8) 671 m_dim = random.randint(1, 8) 672 p_dim = random.randint(1, 8) 673 674 def dims_full_for_fn(): 675 if fn == "baddbmm": 676 return ( 677 [batch_dim, n_dim, p_dim], 678 [batch_dim, n_dim, m_dim], 679 [batch_dim, m_dim, p_dim], 680 ) 681 elif fn == "addbmm": 682 return ( 683 [n_dim, p_dim], 684 [batch_dim, n_dim, m_dim], 685 [batch_dim, m_dim, p_dim], 686 ) 687 elif fn == "addmm": 688 return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim]) 689 elif fn == "addmv": 690 return ([n_dim], [n_dim, m_dim], [m_dim]) 691 elif fn == "addr": 692 return ([n_dim, m_dim], [n_dim], [m_dim]) 693 else: 694 raise AssertionError("unknown function") 695 696 (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn() 697 (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full) 698 699 t0_small = torch.randn(*t0_dims_small, device=device).float() 700 t1 = torch.randn(*t1_dims, device=device).float() 701 t2 = torch.randn(*t2_dims, device=device).float() 702 703 t0_full = t0_small.expand(*t0_dims_full).to(device) 704 705 fntorch = getattr(torch, fn) 706 r0 = fntorch(t0_small, t1, t2) 707 r1 = fntorch(t0_full, t1, t2) 708 self.assertEqual(r0, r1) 709 710 @dtypes(torch.float32) 711 def test_strided_mm_bmm(self, device, dtype): 712 # Tests strided view case with stride smaller than corresponding dimension size 713 x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype, device=device) 714 new_shape = [2, 2, 2] 715 new_stride = [3, 1, 1] 716 sx = torch.as_strided(x, size=new_shape, stride=new_stride) 717 718 torch_fn = lambda x: torch.bmm(x, x) # noqa: E731 719 np_fn = lambda x: np.matmul(x, x) # noqa: E731 720 self.compare_with_numpy(torch_fn, np_fn, sx) 721 722 torch_fn = lambda x: torch.mm(x, x) # noqa: E731 723 self.compare_with_numpy(torch_fn, np_fn, sx[0]) 724 725 def test_mm_empty_inputs_mixed_dtype_errors(self, device): 726 a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device) 727 b = torch.randn(10, 20, dtype=torch.float32, device=device) 728 with self.assertRaisesRegex( 729 RuntimeError, "expected .* and .* to have the same dtype, but got:" 730 ): 731 torch.mm(a, b) 732 733 def test_matmul_45724(self, device): 734 # https://github.com/pytorch/pytorch/issues/45724 735 a = torch.rand(65537, 22, 64, device=device, dtype=torch.half) 736 b = torch.rand(65537, 64, 22, device=device, dtype=torch.half) 737 c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device=device) 738 cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).half() 739 torch.matmul(a, b, out=c) 740 self.assertEqual(c, cpu_result) 741 742 @dtypes( 743 torch.int16, 744 torch.int32, 745 torch.int64, 746 torch.float16, 747 torch.float32, 748 torch.float64, 749 ) 750 def test_baddbmm_input_dtypes_compatibility(self, device, dtype): 751 batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device) 752 batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device) 753 input_tensor = torch.rand((1, 2, 2), device=device).to(dtype) 754 if dtype != torch.float32: 755 with self.assertRaisesRegex(RuntimeError, "Input dtypes must be the same"): 756 y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0) 757 else: 758 out = torch.randn((1, 2, 2), dtype=dtype, device=device).fill_(torch.nan) 759 y_ref = torch.bmm(batch1, batch2) 760 y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out) 761 self.assertEqual(out, y_ref) 762 763 @dtypes(torch.float) 764 def test_baddbmm_nan_input_with_zero_beta(self, device, dtype): 765 for shape in [[3, 2, 2], [2, 20, 20]]: 766 mat1, mat2 = ( 767 torch.randn(shape, dtype=dtype, device=device) for _ in range(2) 768 ) 769 inputs = [ 770 torch.randn(shape, dtype=dtype, device=device), 771 torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan), 772 ] 773 outs = [ 774 None, 775 torch.randn(shape, dtype=dtype, device=device), 776 torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan), 777 ] 778 options = itertools.product(inputs, outs) 779 for input, out in options: 780 y_ref = torch.bmm(mat1, mat2) 781 y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out) 782 self.assertEqual(y_ref, y) 783 784 @dtypes(torch.float) 785 def test_addmm_sizes(self, device, dtype): 786 for m in [0, 1, 25]: 787 for n in [0, 1, 10]: 788 for k in [0, 1, 8]: 789 M = torch.randn(n, m, device=device).to(dtype) 790 m1 = torch.randn(n, k, device=device).to(dtype) 791 m2 = torch.randn(k, m, device=device).to(dtype) 792 self._test_addmm_addmv(torch.addmm, M, m1, m2) 793 794 m1 = torch.randn(n, k + 1, device=device).to(dtype) 795 m2 = torch.randn(k, m, device=device).to(dtype) 796 self.assertRaisesRegex( 797 RuntimeError, 798 f"{n}x{k + 1}.*{k}x{m}", 799 lambda: torch.addmm(M, m1, m2), 800 ) 801 self.assertRaisesRegex( 802 RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2) 803 ) 804 805 @precisionOverride( 806 { 807 torch.double: 1e-8, 808 torch.float: 1e-4, 809 torch.bfloat16: 5e-2, 810 torch.half: 5e-2, 811 torch.cfloat: 1e-4, 812 torch.cdouble: 1e-8, 813 } 814 ) 815 @dtypes(torch.float32, torch.bfloat16, torch.half) 816 def test_addmm_gelu(self, device, dtype): 817 self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype) 818 819 @precisionOverride( 820 { 821 torch.double: 1e-8, 822 torch.float: 1e-4, 823 torch.bfloat16: 5e-2, 824 torch.half: 5e-2, 825 torch.cfloat: 1e-4, 826 torch.cdouble: 1e-8, 827 } 828 ) 829 @dtypes(torch.float32, torch.bfloat16, torch.half) 830 def test_addmm_relu(self, device, dtype): 831 self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype) 832 833 @dtypes(torch.float, torch.bfloat16, torch.half) 834 def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype): 835 # tests (o, s)*(s). o is output size, s is summed size. 836 o = 5 837 s = 3 838 a_data = torch.arange(1, o * s + 1, device=device, dtype=dtype).view(o, s) 839 x_data = torch.arange(1, s + 1, 1, device=device, dtype=dtype) 840 y_data = torch.ones(o, device=device, dtype=dtype) 841 control = torch.tensor( 842 [15.0, 33.0, 51.0, 69.0, 87.0], device=device, dtype=dtype 843 ) 844 845 def _test(row_major, incx, incy, lda_tail): 846 if row_major: 847 a_storage = torch.full( 848 (o, s + lda_tail), float("nan"), device=device, dtype=dtype 849 ) 850 else: 851 a_storage = torch.full( 852 (s, o + lda_tail), float("nan"), device=device, dtype=dtype 853 ).permute(1, 0) 854 a = a_storage[:o, :s].copy_(a_data) 855 856 x_storage = torch.full((s, incx), float("nan"), device=device, dtype=dtype) 857 x = x_storage[:, 0].copy_(x_data) 858 859 y_storage = torch.full((o, incy), float("nan"), device=device, dtype=dtype) 860 y = y_storage[:, 0].copy_(y_data) 861 862 self._test_addmm_addmv(torch.addmv, y, a, x) 863 864 for row_major, incx, incy, lda_tail in itertools.product( 865 (False, True), (1, 2), (1, 2), (0, 1) 866 ): 867 _test(row_major, incx, incy, lda_tail) 868 869 @precisionOverride( 870 { 871 torch.double: 1e-8, 872 torch.float: 1e-4, 873 torch.bfloat16: 0.6, 874 torch.half: 1e-1, 875 torch.cfloat: 1e-4, 876 torch.cdouble: 1e-8, 877 } 878 ) 879 @dtypes(torch.bfloat16, torch.half, torch.float32) 880 def test_corner_cases_of_cublasltmatmul(self, device, dtype): 881 # common case 882 M = torch.randn(128, device=device).to(dtype) 883 m1 = torch.randn(2048, 2400, device=device).to(dtype) 884 m2 = torch.randn(128, 2400, device=device).to(dtype) 885 torch.nn.functional.linear(m1, m2, M) 886 # Ntrans_B has ld >> rows 887 m1 = torch.rand([128, 2400]).to(dtype).to(device).t() 888 m2 = torch.rand([2048, 25272]).to(dtype).to(device).t()[21940:24340] 889 M = torch.rand([128]).to(dtype).to(device) 890 torch.addmm(M, m2.t(), m1) 891 # trans_A has ld >> rows 892 m1 = torch.rand([128, 25272]).to(dtype).to(device)[:, 21940:24340].t() 893 m2 = torch.randn(2048, 2400, device=device).to(dtype) 894 M = torch.rand([128]).to(dtype).to(device) 895 torch.addmm(M, m2, m1) 896 # large tensor dim > 65535 897 M = torch.randn(16, device=device).to(dtype) 898 m1 = torch.randn(32, 131071, device=device).to(dtype) 899 m2 = torch.randn(16, 131071, device=device).to(dtype) 900 torch.nn.functional.linear(m1, m2, M) 901 902 def test_blas_empty(self, device): 903 def fn(torchfn, *args, test_out=False, **kwargs): 904 def call_torch_fn(*args, **kwargs): 905 return torchfn( 906 *tuple( 907 torch.randn(shape, device=device) 908 if isinstance(shape, tuple) 909 else shape 910 for shape in args 911 ), 912 **kwargs, 913 ) 914 915 result = call_torch_fn(*args, **kwargs) 916 if not test_out: 917 return result 918 else: 919 out = torch.full_like(result, math.nan) 920 out1 = call_torch_fn(*args, **kwargs, out=out) 921 return out 922 923 # mm, addmm 924 self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape) 925 self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape) 926 self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape) 927 self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape) 928 self.assertEqual( 929 torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6)) 930 ) 931 self.assertEqual( 932 torch.zeros((5, 6), device=device), 933 fn(torch.mm, (5, 0), (0, 6), test_out=True), 934 ) 935 936 self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape) 937 self.assertEqual((0, 1), fn(torch.addmm, (1,), (0, 17), (17, 1)).shape) 938 t = torch.randn((5, 6), device=device) 939 self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6))) 940 self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True)) 941 942 # mv, addmv 943 self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape) 944 self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape) 945 self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,))) 946 self.assertEqual( 947 torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), test_out=True) 948 ) 949 950 self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape) 951 t = torch.randn((3,), device=device) 952 self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,))) 953 self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True)) 954 955 # bmm, baddbmm 956 self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape) 957 self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape) 958 self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape) 959 self.assertEqual( 960 torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6)) 961 ) 962 self.assertEqual( 963 torch.zeros((3, 5, 6), device=device), 964 fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True), 965 ) 966 967 self.assertEqual( 968 (0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape 969 ) 970 self.assertEqual( 971 (3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape 972 ) 973 self.assertEqual( 974 (0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape 975 ) 976 self.assertEqual( 977 (3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape 978 ) 979 c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5) 980 self.assertEqual( 981 -2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2) 982 ) # Issue #33467 983 self.assertEqual( 984 -2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True) 985 ) # Issue #33467 986 987 # addbmm 988 self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape) 989 self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape) 990 t = torch.randn((5, 6), device=device) 991 self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6))) 992 self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True)) 993 994 # matmul 995 self.assertEqual(torch.tensor(0.0, device=device), fn(torch.matmul, (0,), (0,))) 996 self.assertEqual( 997 torch.tensor(0.0, device=device), 998 fn(torch.matmul, (0,), (0,), test_out=True), 999 ) 1000 self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape) 1001 self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape) 1002 self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape) 1003 self.assertEqual( 1004 torch.zeros((5, 3, 4), device=device), 1005 fn(torch.matmul, (5, 3, 0), (5, 0, 4)), 1006 ) 1007 self.assertEqual( 1008 torch.zeros((5, 3, 4), device=device), 1009 fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True), 1010 ) 1011 1012 # dot 1013 self.assertEqual(torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,))) 1014 self.assertEqual( 1015 torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,), test_out=True) 1016 ) 1017 1018 def test_large_bmm_backward(self, device): 1019 A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT 1020 B = torch.randn([1, 1024, 65536], device=device, requires_grad=True) 1021 G = torch.randn([1024, 2, 65536], device=device) 1022 1023 # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM 1024 (A @ B).backward(G) 1025 1026 def test_large_bmm_mm_backward(self, device): 1027 A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT 1028 B = torch.randn([1024, 65536], device=device, requires_grad=True) 1029 G = torch.randn([1024, 2, 65536], device=device) 1030 1031 # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM 1032 (A @ B).backward(G) 1033 1034 def check_single_matmul(self, x, y): 1035 def assertEqual(answer, expected): 1036 if x.dtype.is_floating_point or x.dtype.is_complex: 1037 k = max(x.shape[-1], 1) # Scale the atol with the size of the matrix 1038 self.assertEqual( 1039 answer, 1040 expected, 1041 msg=f"{x.shape} x {y.shape} = {answer.shape}", 1042 atol=k * 5e-5, 1043 rtol=1e-4, 1044 ) 1045 else: 1046 self.assertEqual( 1047 answer, expected, msg=f"{x.shape} x {y.shape} = {answer.shape}" 1048 ) 1049 1050 # test x @ y 1051 expected = np.matmul(x.cpu(), y.cpu()) 1052 ans = torch.matmul(x, y) 1053 self.assertTrue(ans.is_contiguous()) 1054 assertEqual(ans, expected) 1055 1056 # test out 1057 out = torch.empty_like(ans) 1058 ans = torch.matmul(x, y, out=out) 1059 self.assertIs(ans, out) 1060 self.assertTrue(ans.is_contiguous()) 1061 assertEqual(ans, expected) 1062 1063 def gen_sizes_matmul(self, x_dim, y_dim=4, matrix_size=4, batch_size=3): 1064 """ 1065 Generates sequences of tuples (x, y) of with size(x) = x_dim and 1066 size(y) <= y_dim that are compatible wrt. matmul 1067 """ 1068 assert x_dim >= 1 1069 assert y_dim >= 2 1070 x = x_dim 1071 for y in range(1, y_dim + 1): 1072 for batch, mn in product( 1073 product(range(batch_size), repeat=max(x - 2, y - 2, 0)), 1074 product(range(matrix_size), repeat=min(y, 2)), 1075 ): 1076 if x == 1: 1077 size_x = mn[:1] 1078 size_y = batch + mn 1079 yield size_x, size_y 1080 else: 1081 for k in range(matrix_size): 1082 size_x = (k,) + mn[:1] 1083 if x > 2: 1084 size_x = batch[-(x - 2) :] + size_x 1085 size_y = mn 1086 if y > 2: 1087 size_y = batch[-(y - 2) :] + size_y 1088 yield size_x, size_y 1089 1090 @dtypes(torch.float) 1091 def test_matmul_small_brute_force_1d_Nd(self, device, dtype): 1092 make_arg = partial(make_tensor, device=device, dtype=dtype) 1093 1094 for (size_x, size_y), nctg_x, nctg_y in product( 1095 self.gen_sizes_matmul(1), (True, False), (True, False) 1096 ): 1097 x = make_arg(size_x, noncontiguous=nctg_x) 1098 y = make_arg(size_y, noncontiguous=nctg_y) 1099 self.check_single_matmul(x, y) 1100 1101 @dtypes(torch.float) 1102 def test_matmul_small_brute_force_2d_Nd(self, device, dtype): 1103 make_arg = partial(make_tensor, device=device, dtype=dtype) 1104 1105 for (size_x, size_y), nctg_x, nctg_y in product( 1106 self.gen_sizes_matmul(2), (True, False), (True, False) 1107 ): 1108 x = make_arg(size_x, noncontiguous=nctg_x) 1109 y = make_arg(size_y, noncontiguous=nctg_y) 1110 self.check_single_matmul(x, y) 1111 1112 @dtypes(torch.float) 1113 def test_matmul_small_brute_force_3d_Nd(self, device, dtype): 1114 make_arg = partial(make_tensor, device=device, dtype=dtype) 1115 1116 for (size_x, size_y), nctg_x, nctg_y in product( 1117 self.gen_sizes_matmul(3), (True, False), (True, False) 1118 ): 1119 x = make_arg(size_x, noncontiguous=nctg_x) 1120 y = make_arg(size_y, noncontiguous=nctg_y) 1121 self.check_single_matmul(x, y) 1122 1123 @dtypes(torch.float) 1124 def test_matmul_out_kernel_errors_with_autograd(self, device, dtype): 1125 a = torch.empty( 1126 (256, 512), device=device, dtype=dtype, requires_grad=True 1127 ).unsqueeze(0) 1128 b = torch.empty( 1129 (4, 128, 512), device=device, dtype=dtype, requires_grad=True 1130 ).transpose(-1, -2) 1131 c = torch.empty((256, 4, 128), device=device, dtype=dtype).movedim(1, 0) 1132 1133 torch.matmul(a.detach(), b.detach(), out=c) 1134 1135 with self.assertRaisesRegex( 1136 RuntimeError, 1137 "functions with out=... arguments don't support automatic differentiation", 1138 ): 1139 torch.matmul(a, b, out=c) 1140 1141 with torch.no_grad(): 1142 torch.matmul(a, b, out=c) 1143 1144 1145instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True) 1146 1147if __name__ == "__main__": 1148 run_tests() 1149