# Owner(s): ["module: intel"]

import itertools
import math
import random
from functools import partial
from itertools import product

import numpy as np

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    precisionOverride,
)
from torch.testing._internal.common_utils import iter_indices, run_tests, TestCase


class TestBasicGEMM(TestCase):
    def _test_addmm_addmv(
        self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None
    ):
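        # Shared checker for addmm/addmv-style ops: runs f(t, m, v, alpha, beta)
        # (optionally with a relu/gelu epilogue), repeats the call through the
        # out= variant (optionally into a transposed output buffer), and compares
        # both against a NumPy reference alpha * (m @ v) + beta * t computed in
        # float32 for half/bfloat16 inputs.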
        dtype = t.dtype
        numpy_dtype = dtype
        if dtype in {torch.bfloat16, torch.half}:
            numpy_dtype = torch.float
        if dtype.is_complex:
            alpha = 0.9 + 0.3j if alpha is None else alpha
            beta = 0.5 + 0.6j if beta is None else beta
        else:
            alpha = 1.2 if alpha is None else alpha
            beta = 0.8 if beta is None else beta
        if activation == "gelu":
            res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True)
        else:
            res1 = f(t, m, v, alpha=alpha, beta=beta)
        res2 = torch.full_like(res1, math.nan)
        if transpose_out:
            res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
        if activation == "gelu":
            f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True)
        else:
            f(t, m, v, alpha=alpha, beta=beta, out=res2)
        m.to(numpy_dtype).cpu().numpy()
        v.to(numpy_dtype).cpu().numpy()
        res3 = alpha * (
            m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy()
        )
        if beta != 0:
            res3 += (beta * t).to(numpy_dtype).cpu().numpy()
        if activation == "relu":
            res3 = res3 * (res3 > 0)
        elif activation == "gelu":
            res3_t = torch.from_numpy(res3).to(dtype)
            approximate = "tanh" if t.is_cuda else "none"
            res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate)
            res3 = res3_t.to(numpy_dtype).cpu().numpy()
        else:
            assert activation is None, f"unsupported activation {activation}"
        res3 = torch.from_numpy(res3).to(dtype)
        self.assertEqual(res1, res2)
        self.assertEqual(res1, res3)

    def _test_addmm_impl(self, func, activation, device, dtype):
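        # Runs the shared addmm checker over several layouts: contiguous inputs,
        # a vector-shaped bias with beta=1, zero-strided (expanded) operands,
        # beta=0 with a NaN-filled bias, and every transpose combination.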
        M = torch.randn(10, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device)
        m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, M, m1, m2, activation=activation)

        # vector-shaped bias and beta=1 result in epilogue fusion in CUDA
        V = torch.randn(25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)

        # Test 0-strided
        M = (
            torch.randn(10, 1, device="cpu", dtype=torch.float32)
            .to(dtype)
            .expand(10, 25)
            .to(device)
        )
        m1 = (
            torch.randn(10, 1, device="cpu", dtype=torch.float32)
            .to(dtype)
            .expand(10, 50)
            .to(device)
        )
        m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, M, m1, m2, activation=activation)

        # Test beta=0, M=nan
        M = (
            torch.full((10, 25), math.nan, device="cpu", dtype=torch.float32)
            .to(dtype)
            .to(device)
        )
        m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device)
        m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, M, m1, m2, beta=0, activation=activation)

        # Test transpose
        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):

            def maybe_transpose(cond, m):
                if not cond:
                    return m
                return m.t().clone(memory_format=torch.contiguous_format).t()

            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
            self._test_addmm_addmv(
                func, M, m1, m2, transpose_out=t4, activation=activation
            )

            if t1:
                # use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
                self._test_addmm_addmv(
                    func,
                    V,
                    m1,
                    m2,
                    beta=1,
                    transpose_out=t4,
                    activation=activation,
                )

    @precisionOverride(
        {
            torch.float: 1e-4,
            torch.half: 1e-1,
        }
    )
    @dtypes(torch.float32, torch.half)
    def test_addmm(self, device, dtype):
        self._test_addmm_impl(torch.addmm, None, device, dtype)

    @precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4})
    @dtypes(torch.bfloat16, torch.half, torch.float)
    def test_addmv(self, device, dtype):
        # have to use torch.randn(...).to(bfloat16) instead of
        # torch.randn(..., dtype=bfloat16). randn does not support
        # bfloat16 yet.
        # "*0.2" to reduce errors for low precision
        ts = [
            0.2 * torch.randn(50, device=device).to(dtype),
            0.2 * torch.randn(1, device=device).to(dtype).expand(50),
        ]
        vs = [
            0.2 * torch.randn(100, device=device).to(dtype),
            0.2
            * torch.ones(1, device=device)
            .to(dtype)
            .expand(100),  # to reduce errors for low precision
        ]
        ms = [
            # 0d
            0.2
            * torch.ones((), device=device)
            .to(dtype)
            .expand(50, 100),  # to reduce errors for low precision
            # 1d
            0.2 * torch.randn((1, 100), device=device).to(dtype).expand(50, 100),
            # this initialization reduces errors for low precision for broadcasted matrices
            # by making sure that intermediate and result values are exactly representable
            # in low precision type
            0.2
            * torch.randint(3, (50, 1), dtype=torch.float, device=device)
            .to(dtype)
            .expand(50, 100),
            # 2d
            0.2 * torch.randn((50, 100), device=device).to(dtype),
            0.2 * torch.randn((100, 50), device=device).to(dtype).t(),
        ]
        for m, v, t in itertools.product(ms, vs, ts):
            self._test_addmm_addmv(torch.addmv, t, m, v)
        # Test beta=0, t=nan
        t = torch.full((50,), math.nan, device=device).to(dtype)
        for m, v in itertools.product(ms, vs):
            self._test_addmm_addmv(torch.addmv, t, m, v, beta=0)

    @dtypes(
        torch.half,
        torch.float32,
    )
    def test_mm(self, device, dtype):
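        # Checks torch.mm against a naive Python reference (computed in float32
        # when dtype is half) for contiguous, transposed, zero-strided, and
        # explicit out= operands.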
        def _test_mm(n, m, p, dtype, genf):
            # helper function
            def matrixmultiply(mat1, mat2):
                n = mat1.size(0)
                m = mat1.size(1)
                p = mat2.size(1)
                dtype_ = torch.float if dtype == torch.half else dtype
                if dtype == torch.half:
                    mat1 = mat1.float()
                    mat2 = mat2.float()
                res = torch.zeros(n, p, dtype=dtype_, device=device)
                for i, j in iter_indices(res):
                    res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m))
                return res.half() if dtype == torch.half else res

            # contiguous case
            mat1 = genf(n, m)
            mat2 = genf(m, p)
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # non contiguous case 1
            mat1 = genf(n, m)
            mat2 = genf(p, m).t()
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # non contiguous case 2
            mat1 = genf(m, n).t()
            mat2 = genf(m, p)
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # non contiguous case 3
            mat1 = genf(m, n).t()
            mat2 = genf(p, m).t()
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # test with zero stride
            mat1 = genf(n, m)
            mat2 = genf(m, 1).expand(m, p)
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # explicitly exercise the _out variant in torch.mm().
            # contiguous case
            mat1 = genf(n, m)
            mat2 = genf(m, p)
            res = genf(n, p)
            torch.mm(mat1, mat2, out=res)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # explicitly exercise the _out variant in torch.mm().
            # non contiguous case 3
            mat1 = genf(m, n).t()
            mat2 = genf(p, m).t()
            res = genf(n, p)
            torch.mm(mat1, mat2, out=res)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

        def genf_int(x, y):
            return torch.randint(0, 100, (x, y), dtype=dtype, device=device)

        def genf_bfloat(x, y):
            return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) * 0.1

        def genf_float(x, y):
            return torch.randn(x, y, dtype=dtype, device=device)

        def genf_Half(x, y):
            return torch.randn(x, y, dtype=dtype, device=device)

        for n, m, p in [(20, 10, 15), (15, 20, 10), (25, 18, 10)]:
            if (dtype == torch.int32) or (dtype == torch.int64):
                genf = genf_int
            elif dtype == torch.bfloat16:
                genf = genf_bfloat
            elif dtype == torch.half:
                genf = genf_Half
            else:
                genf = genf_float

            _test_mm(n, m, p, dtype, genf)

    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_bmm(self, device, dtype):
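        # Checks torch.bmm (and its out= variant written through permuted
        # buffers) against NumPy batched matmul for permuted-contiguous,
        # broadcast-expanded, and zero-sized inputs.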
        batch_sizes = [1, 10]
        M, N, O = 23, 15, 12
        numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32

        def invert_perm(p):
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2])

        def generate_inputs(num_batches):
            # transposed tensors
            for perm1, perm2 in itertools.product(
                itertools.permutations((0, 1, 2)), repeat=2
            ):
                b1 = make_tensor(
                    (num_batches, M, N), dtype=dtype, device=device, low=-0.1, high=0.1
                )
                b2 = make_tensor(
                    (num_batches, N, O), dtype=dtype, device=device, low=-0.1, high=0.1
                )
                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
                yield b1, b2
            # broadcasting tensors
            for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6):
                shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1)
                shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1)
                b1 = make_tensor(
                    shape1, dtype=dtype, device=device, low=-0.1, high=0.1
                ).expand(num_batches, M, N)
                b2 = make_tensor(
                    shape2, dtype=dtype, device=device, low=-0.1, high=0.1
                ).expand(num_batches, N, O)
                yield b1, b2
            # zero-sized tensors
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
                b1 = torch.randn(shape1, dtype=dtype, device=device)
                b2 = torch.randn(shape2, dtype=dtype, device=device)
                yield b1, b2

        for num_batches in batch_sizes:
            for (b1, b2), perm3 in itertools.product(
                generate_inputs(num_batches), itertools.permutations((0, 1, 2))
            ):
                res1 = torch.bmm(b1, b2)
                res2 = (
                    torch.full(
                        (num_batches, M, O), math.nan, dtype=dtype, device=device
                    )
                    .permute(perm3)
                    .contiguous()
                    .permute(invert_perm(perm3))
                )
                torch.bmm(b1, b2, out=res2)
                expect = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype)
                self.assertEqual(expect, res1)
                self.assertEqual(expect, res2)

                if self.device_type == "cuda":
                    # check that mixed arguments are rejected
                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu()))
                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2))
                    self.assertRaises(
                        RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu())
                    )

    def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor):
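        # Shared checker for addbmm/baddbmm: exercises the in-place method, the
        # deprecated positional alpha/beta overloads (expecting deprecation
        # warnings), the functional form with keyword alpha/beta, beta=0 with a
        # NaN-filled input, and the out= variant, comparing each against ref.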
        getattr(out_tensor, func + "_")(b1, b2)
        self.assertEqual(out_tensor, ref)
        res3 = out_tensor.clone()

        with self.assertWarnsOnceRegex(
            UserWarning, f"This overload of {func}_ is deprecated"
        ):
            getattr(out_tensor, func + "_")(1, b1, b2)
        self.assertEqual(out_tensor, ref * 2)
        getattr(res3, func + "_")(b1, b2, beta=1)
        self.assertEqual(out_tensor, res3)

        with self.assertWarnsOnceRegex(
            UserWarning, f"This overload of {func}_ is deprecated"
        ):
            getattr(out_tensor, func + "_")(1.0, 0.5, b1, b2)
        self.assertEqual(out_tensor, ref * 2.5)
        getattr(res3, func + "_")(b1, b2, beta=1.0, alpha=0.5)
        self.assertEqual(out_tensor, res3)

        with self.assertWarnsOnceRegex(
            UserWarning, f"This overload of {func} is deprecated"
        ):
            self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2))

        res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=0.5)
        self.assertEqual(res4, ref * 3)

        nan = torch.full_like(out_tensor, math.nan)
        res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1)
        self.assertEqual(res5, ref)

        if b1.is_complex():
            res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1j, alpha=0.5j)
            self.assertEqual(res6, out_tensor * 0.1j + 0.5j * ref)
        else:
            res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1, alpha=0.5)
            self.assertEqual(res6, out_tensor * 0.1 + 0.5 * ref)

        res7 = torch.full_like(out_tensor, math.nan)
        getattr(torch, func)(nan, b1, b2, beta=0, out=res7)
        self.assertEqual(res7, ref)

    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_addbmm(self, device, dtype):
        num_batches = 2
        M, N, O = 16, 17, 18

        is_supported = True
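        # is_supported is hard-coded to True here, so the unsupported-dtype
        # branch below is never taken on this backend.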

        if not is_supported:
            b1 = make_tensor(
                (num_batches, M, N), dtype=dtype, device=device, low=-1, high=1
            )
            b2 = make_tensor(
                (num_batches, N, O), dtype=dtype, device=device, low=-1, high=1
            )
            t = make_tensor((M, O), dtype=dtype, device=device, low=-1, high=1)
            self.assertRaisesRegex(
                RuntimeError,
                "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
                lambda: torch.addbmm(t, b1, b2),
            )
            return

        def invert_perm(p):
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2])

        def generate_tensor():
            numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32
            # transposed tensors
            for perm1, perm2 in itertools.product(
                itertools.permutations((0, 1, 2)), repeat=2
            ):
                for perm3 in itertools.permutations((0, 1)):
                    b1 = (
                        make_tensor(
                            (num_batches, M, N),
                            dtype=dtype,
                            device=device,
                            low=-1,
                            high=1,
                        )
                        * 0.1
                    )
                    b2 = (
                        make_tensor(
                            (num_batches, N, O),
                            dtype=dtype,
                            device=device,
                            low=-1,
                            high=1,
                        )
                        * 0.1
                    )
                    b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
                    b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
                    ref = (
                        torch.from_numpy(
                            b1.to(numpy_dtype).cpu().numpy()
                            @ b2.to(numpy_dtype).cpu().numpy()
                        )
                        .to(device=device, dtype=dtype)
                        .sum(0)
                    )
                    out_tensor = (
                        torch.zeros_like(ref).permute(perm3).contiguous().permute(perm3)
                    )
                    yield b1, b2, ref, out_tensor
            # broadcasting tensors
            for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
                shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
                shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
                b1 = (
                    make_tensor(
                        shape1, dtype=dtype, device=device, low=-1, high=1
                    ).expand(num_batches, M, N)
                    * 0.1
                )
                b2 = (
                    make_tensor(
                        shape2, dtype=dtype, device=device, low=-1, high=1
                    ).expand(num_batches, N, O)
                    * 0.1
                )
                ref = (
                    torch.from_numpy(
                        b1.to(numpy_dtype).cpu().numpy()
                        @ b2.to(numpy_dtype).cpu().numpy()
                    )
                    .to(device=device, dtype=dtype)
                    .sum(0)
                )
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor
            # zero-sized tensors
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
                b1 = (
                    make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1)
                    * 0.1
                )
                b2 = (
                    make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1)
                    * 0.1
                )
                ref = (
                    torch.from_numpy(
                        b1.to(numpy_dtype).cpu().numpy()
                        @ b2.to(numpy_dtype).cpu().numpy()
                    )
                    .to(device=device, dtype=dtype)
                    .sum(0)
                )
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor

        for b1, b2, ref, out_tensor in generate_tensor():
            self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor)

    @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5})
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_baddbmm(self, device, dtype):
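        # Same structure as test_addbmm, but for baddbmm: the NumPy reference
        # keeps the batch dimension instead of summing over it.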
        num_batches = 10
        M, N, O = 12, 8, 50

        def invert_perm(p):
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2])

        def generate_tensor():
            numpy_dtype = (
                dtype if dtype not in [torch.bfloat16, torch.half] else torch.float32
            )
            # transposed tensors
            for perm1, perm2, perm3 in itertools.product(
                itertools.permutations((0, 1, 2)), repeat=3
            ):
                b1 = make_tensor(
                    (num_batches, M, N), dtype=dtype, device=device, low=-1, high=1
                )
                b2 = make_tensor(
                    (num_batches, N, O), dtype=dtype, device=device, low=-1, high=1
                )
                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
                ref = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype)
                out_tensor = torch.zeros_like(ref)
                out_tensor = (
                    out_tensor.permute(perm3).contiguous().permute(invert_perm(perm3))
                )
                yield b1, b2, ref, out_tensor
            # broadcasting tensors
            for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
                shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
                shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
                b1 = make_tensor(
                    shape1, dtype=dtype, device=device, low=-1, high=1
                ).expand(num_batches, M, N)
                b2 = make_tensor(
                    shape2, dtype=dtype, device=device, low=-1, high=1
                ).expand(num_batches, N, O)
                ref = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype)
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor
            # zero-sized tensors
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-2, high=2)
                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-2, high=2)
                ref = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype)
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor

        for b1, b2, ref, out_tensor in generate_tensor():
            self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor)

    def test_tensordot(self, device):
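        # Checks torch.tensordot against np.tensordot for explicit dim pairs,
        # integer dims, the default contraction, and 0-d inputs, including the
        # out= variant and the error raised for negative dims.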
        a = torch.arange(60.0, device=device).reshape(3, 4, 5)
        b = torch.arange(24.0, device=device).reshape(4, 3, 2)
        c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu()
        cn = torch.from_numpy(
            np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=([1, 0], [0, 1]))
        )
        self.assertEqual(c, cn)

        cout = torch.zeros((5, 2), device=device)
        torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout).cpu()
        self.assertEqual(c, cout)

        a = torch.randn(2, 3, 4, 5, device=device)
        b = torch.randn(4, 5, 6, 7, device=device)
        c = torch.tensordot(a, b, dims=2).cpu()
        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=2))

        with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"):
            torch.tensordot(a, b, dims=-1)

        self.assertEqual(c, cn)
        c = torch.tensordot(a, b).cpu()
        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy()))
        self.assertEqual(c, cn)

        a = torch.tensordot(torch.tensor(0.0), torch.tensor(0.0), 0)
        an = torch.from_numpy(
            np.tensordot(
                np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0
            )
        )
        self.assertEqual(a, an)

    @dtypes(torch.float)
    @precisionOverride({torch.float32: 1e-4})
    def test_1_sized_with_0_strided(self, device, dtype):
        a = make_tensor((8, 1, 64), dtype=dtype, device=device)
        a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1])
        b = make_tensor((8, 64, 512), dtype=dtype, device=device)
        b_strided = torch.as_strided(b, size=[8, 64, 512], stride=[64, 1, 512])
        res = torch.bmm(a_strided, b_strided)
        expect = torch.from_numpy(a_strided.cpu().numpy() @ b_strided.cpu().numpy()).to(
            device=device, dtype=dtype
        )
        self.assertEqual(expect, res)

    def _select_broadcastable_dims(self, dims_full=None):
        # select full dimensionality
        if dims_full is None:
            dims_full = []
            ndims = random.randint(1, 4)
            dims_full = [random.randint(1, 8) for _ in range(ndims)]
        else:
            ndims = len(dims_full)

        # select actual dimensions for ops:
        # larger: full ndims, individual sizes may be reduced
        # smaller: possibly reduced ndims, sizes may be reduced
        smaller_ndims = random.randint(1, ndims)
        dims_small = []
        dims_large = []
        for i in range(ndims - 1, -1, -1):
            j = random.randint(1, 3)
            if j == 1:  # no reduced singleton dimension
                ds = dims_full[i]
                dl = dims_full[i]
            elif j == 2:  # larger may have reduced singleton dimension
                ds = dims_full[i]
                dl = 1 if len(dims_small) < smaller_ndims else dims_full[i]
            elif j == 3:  # smaller may have reduced singleton dimension
                ds = 1
                dl = dims_full[i]
            dims_large = [dl] + dims_large
            if len(dims_small) < smaller_ndims:
                dims_small = [ds] + dims_small
        return (dims_small, dims_large, dims_full)

    def test_broadcast_fused_matmul(self, device):
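        # For each fused op, verifies the result is the same whether the
        # self/bias argument is passed in its broadcastable (small) shape or
        # pre-expanded to the full shape.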
        fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"]

        for fn in fns:
            batch_dim = random.randint(1, 8)
            n_dim = random.randint(1, 8)
            m_dim = random.randint(1, 8)
            p_dim = random.randint(1, 8)

            def dims_full_for_fn():
                if fn == "baddbmm":
                    return (
                        [batch_dim, n_dim, p_dim],
                        [batch_dim, n_dim, m_dim],
                        [batch_dim, m_dim, p_dim],
                    )
                elif fn == "addbmm":
                    return (
                        [n_dim, p_dim],
                        [batch_dim, n_dim, m_dim],
                        [batch_dim, m_dim, p_dim],
                    )
                elif fn == "addmm":
                    return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim])
                elif fn == "addmv":
                    return ([n_dim], [n_dim, m_dim], [m_dim])
                elif fn == "addr":
                    return ([n_dim, m_dim], [n_dim], [m_dim])
                else:
                    raise AssertionError("unknown function")

            (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn()
            (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full)

            t0_small = torch.randn(*t0_dims_small, device=device).float()
            t1 = torch.randn(*t1_dims, device=device).float()
            t2 = torch.randn(*t2_dims, device=device).float()

            t0_full = t0_small.expand(*t0_dims_full).to(device)

            fntorch = getattr(torch, fn)
            r0 = fntorch(t0_small, t1, t2)
            r1 = fntorch(t0_full, t1, t2)
            self.assertEqual(r0, r1)

    @dtypes(torch.float32)
    def test_strided_mm_bmm(self, device, dtype):
        # Tests strided view case with stride smaller than corresponding dimension size
        x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype, device=device)
        new_shape = [2, 2, 2]
        new_stride = [3, 1, 1]
        sx = torch.as_strided(x, size=new_shape, stride=new_stride)

        torch_fn = lambda x: torch.bmm(x, x)  # noqa: E731
        np_fn = lambda x: np.matmul(x, x)  # noqa: E731
        self.compare_with_numpy(torch_fn, np_fn, sx)

        torch_fn = lambda x: torch.mm(x, x)  # noqa: E731
        self.compare_with_numpy(torch_fn, np_fn, sx[0])

    def test_mm_empty_inputs_mixed_dtype_errors(self, device):
        a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device)
        b = torch.randn(10, 20, dtype=torch.float32, device=device)
        with self.assertRaisesRegex(
            RuntimeError, "expected .* and .* to have the same dtype, but got:"
        ):
            torch.mm(a, b)

    def test_matmul_45724(self, device):
        # https://github.com/pytorch/pytorch/issues/45724
        a = torch.rand(65537, 22, 64, device=device, dtype=torch.half)
        b = torch.rand(65537, 64, 22, device=device, dtype=torch.half)
        c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device=device)
        cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).half()
        torch.matmul(a, b, out=c)
        self.assertEqual(c, cpu_result)

    @dtypes(
        torch.int16,
        torch.int32,
        torch.int64,
        torch.float16,
        torch.float32,
        torch.float64,
    )
    def test_baddbmm_input_dtypes_compatibility(self, device, dtype):
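        # baddbmm must reject an input whose dtype differs from the batches;
        # for matching float32 inputs, beta=0 with a NaN-filled out buffer must
        # still produce exactly bmm(batch1, batch2).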
        batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
        batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
        input_tensor = torch.rand((1, 2, 2), device=device).to(dtype)
        if dtype != torch.float32:
            with self.assertRaisesRegex(RuntimeError, "Input dtypes must be the same"):
                y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0)
        else:
            out = torch.randn((1, 2, 2), dtype=dtype, device=device).fill_(torch.nan)
            y_ref = torch.bmm(batch1, batch2)
            y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out)
            self.assertEqual(out, y_ref)

    @dtypes(torch.float)
    def test_baddbmm_nan_input_with_zero_beta(self, device, dtype):
        for shape in [[3, 2, 2], [2, 20, 20]]:
            mat1, mat2 = (
                torch.randn(shape, dtype=dtype, device=device) for _ in range(2)
            )
            inputs = [
                torch.randn(shape, dtype=dtype, device=device),
                torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan),
            ]
            outs = [
                None,
                torch.randn(shape, dtype=dtype, device=device),
                torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan),
            ]
            options = itertools.product(inputs, outs)
            for input, out in options:
                y_ref = torch.bmm(mat1, mat2)
                y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out)
                self.assertEqual(y_ref, y)

    @dtypes(torch.float)
    def test_addmm_sizes(self, device, dtype):
        for m in [0, 1, 25]:
            for n in [0, 1, 10]:
                for k in [0, 1, 8]:
                    M = torch.randn(n, m, device=device).to(dtype)
                    m1 = torch.randn(n, k, device=device).to(dtype)
                    m2 = torch.randn(k, m, device=device).to(dtype)
                    self._test_addmm_addmv(torch.addmm, M, m1, m2)

                    m1 = torch.randn(n, k + 1, device=device).to(dtype)
                    m2 = torch.randn(k, m, device=device).to(dtype)
                    self.assertRaisesRegex(
                        RuntimeError,
                        f"{n}x{k + 1}.*{k}x{m}",
                        lambda: torch.addmm(M, m1, m2),
                    )
                    self.assertRaisesRegex(
                        RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2)
                    )

    @precisionOverride(
        {
            torch.double: 1e-8,
            torch.float: 1e-4,
            torch.bfloat16: 5e-2,
            torch.half: 5e-2,
            torch.cfloat: 1e-4,
            torch.cdouble: 1e-8,
        }
    )
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_addmm_gelu(self, device, dtype):
        self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype)

    @precisionOverride(
        {
            torch.double: 1e-8,
            torch.float: 1e-4,
            torch.bfloat16: 5e-2,
            torch.half: 5e-2,
            torch.cfloat: 1e-4,
            torch.cdouble: 1e-8,
        }
    )
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_addmm_relu(self, device, dtype):
        self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)

    @dtypes(torch.float, torch.bfloat16, torch.half)
    def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype):
        # tests (o, s)*(s).  o is output size, s is summed size.
        o = 5
        s = 3
        a_data = torch.arange(1, o * s + 1, device=device, dtype=dtype).view(o, s)
        x_data = torch.arange(1, s + 1, 1, device=device, dtype=dtype)
        y_data = torch.ones(o, device=device, dtype=dtype)
        control = torch.tensor(
            [15.0, 33.0, 51.0, 69.0, 87.0], device=device, dtype=dtype
        )

        def _test(row_major, incx, incy, lda_tail):
            if row_major:
                a_storage = torch.full(
                    (o, s + lda_tail), float("nan"), device=device, dtype=dtype
                )
            else:
                a_storage = torch.full(
                    (s, o + lda_tail), float("nan"), device=device, dtype=dtype
                ).permute(1, 0)
            a = a_storage[:o, :s].copy_(a_data)

            x_storage = torch.full((s, incx), float("nan"), device=device, dtype=dtype)
            x = x_storage[:, 0].copy_(x_data)

            y_storage = torch.full((o, incy), float("nan"), device=device, dtype=dtype)
            y = y_storage[:, 0].copy_(y_data)

            self._test_addmm_addmv(torch.addmv, y, a, x)

        for row_major, incx, incy, lda_tail in itertools.product(
            (False, True), (1, 2), (1, 2), (0, 1)
        ):
            _test(row_major, incx, incy, lda_tail)

    @precisionOverride(
        {
            torch.double: 1e-8,
            torch.float: 1e-4,
            torch.bfloat16: 0.6,
            torch.half: 1e-1,
            torch.cfloat: 1e-4,
            torch.cdouble: 1e-8,
        }
    )
    @dtypes(torch.bfloat16, torch.half, torch.float32)
    def test_corner_cases_of_cublasltmatmul(self, device, dtype):
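        # Smoke test for shapes that exercise cublasLt-style corner cases
        # (leading dimension much larger than the number of rows, and a
        # dimension above 65535); only checks that the calls run, with no
        # reference comparison.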
        # common case
        M = torch.randn(128, device=device).to(dtype)
        m1 = torch.randn(2048, 2400, device=device).to(dtype)
        m2 = torch.randn(128, 2400, device=device).to(dtype)
        torch.nn.functional.linear(m1, m2, M)
        # Ntrans_B has ld >> rows
        m1 = torch.rand([128, 2400]).to(dtype).to(device).t()
        m2 = torch.rand([2048, 25272]).to(dtype).to(device).t()[21940:24340]
        M = torch.rand([128]).to(dtype).to(device)
        torch.addmm(M, m2.t(), m1)
        # trans_A has ld >> rows
        m1 = torch.rand([128, 25272]).to(dtype).to(device)[:, 21940:24340].t()
        m2 = torch.randn(2048, 2400, device=device).to(dtype)
        M = torch.rand([128]).to(dtype).to(device)
        torch.addmm(M, m2, m1)
        # large tensor dim > 65535
        M = torch.randn(16, device=device).to(dtype)
        m1 = torch.randn(32, 131071, device=device).to(dtype)
        m2 = torch.randn(16, 131071, device=device).to(dtype)
        torch.nn.functional.linear(m1, m2, M)

    def test_blas_empty(self, device):
        def fn(torchfn, *args, test_out=False, **kwargs):
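            # Replaces tuple-valued args with random tensors of those shapes and
            # calls torchfn; when test_out is set, also runs the out= variant
            # into a NaN-filled buffer and returns that buffer instead.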
            def call_torch_fn(*args, **kwargs):
                return torchfn(
                    *tuple(
                        torch.randn(shape, device=device)
                        if isinstance(shape, tuple)
                        else shape
                        for shape in args
                    ),
                    **kwargs,
                )

            result = call_torch_fn(*args, **kwargs)
            if not test_out:
                return result
            else:
                out = torch.full_like(result, math.nan)
                out1 = call_torch_fn(*args, **kwargs, out=out)
                return out

        # mm, addmm
        self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape)
        self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape)
        self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape)
        self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape)
        self.assertEqual(
            torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6))
        )
        self.assertEqual(
            torch.zeros((5, 6), device=device),
            fn(torch.mm, (5, 0), (0, 6), test_out=True),
        )

        self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape)
        self.assertEqual((0, 1), fn(torch.addmm, (1,), (0, 17), (17, 1)).shape)
        t = torch.randn((5, 6), device=device)
        self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6)))
        self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True))

        # mv, addmv
        self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape)
        self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape)
        self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,)))
        self.assertEqual(
            torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), test_out=True)
        )

        self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape)
        t = torch.randn((3,), device=device)
        self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,)))
        self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True))

        # bmm, baddbmm
        self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape)
        self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape)
        self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape)
        self.assertEqual(
            torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6))
        )
        self.assertEqual(
            torch.zeros((3, 5, 6), device=device),
            fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True),
        )

        self.assertEqual(
            (0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape
        )
        self.assertEqual(
            (3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape
        )
        self.assertEqual(
            (0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape
        )
        self.assertEqual(
            (3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape
        )
        c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5)
        self.assertEqual(
            -2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2)
        )  # Issue #33467
        self.assertEqual(
            -2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True)
        )  # Issue #33467

        # addbmm
        self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape)
        self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape)
        t = torch.randn((5, 6), device=device)
        self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6)))
        self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True))

        # matmul
        self.assertEqual(torch.tensor(0.0, device=device), fn(torch.matmul, (0,), (0,)))
        self.assertEqual(
            torch.tensor(0.0, device=device),
            fn(torch.matmul, (0,), (0,), test_out=True),
        )
        self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape)
        self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape)
        self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape)
        self.assertEqual(
            torch.zeros((5, 3, 4), device=device),
            fn(torch.matmul, (5, 3, 0), (5, 0, 4)),
        )
        self.assertEqual(
            torch.zeros((5, 3, 4), device=device),
            fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True),
        )

        # dot
        self.assertEqual(torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,)))
        self.assertEqual(
            torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,), test_out=True)
        )

    def test_large_bmm_backward(self, device):
        A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT
        B = torch.randn([1, 1024, 65536], device=device, requires_grad=True)
        G = torch.randn([1024, 2, 65536], device=device)

        # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
        (A @ B).backward(G)

    def test_large_bmm_mm_backward(self, device):
        A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT
        B = torch.randn([1024, 65536], device=device, requires_grad=True)
        G = torch.randn([1024, 2, 65536], device=device)

        # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
        (A @ B).backward(G)

    def check_single_matmul(self, x, y):
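        # Compares torch.matmul against np.matmul, scaling atol with the
        # reduced dimension for floating/complex dtypes, and checks that the
        # out= variant returns the out tensor with a contiguous result.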
        def assertEqual(answer, expected):
            if x.dtype.is_floating_point or x.dtype.is_complex:
                k = max(x.shape[-1], 1)  # Scale the atol with the size of the matrix
                self.assertEqual(
                    answer,
                    expected,
                    msg=f"{x.shape} x {y.shape} = {answer.shape}",
                    atol=k * 5e-5,
                    rtol=1e-4,
                )
            else:
                self.assertEqual(
                    answer, expected, msg=f"{x.shape} x {y.shape} = {answer.shape}"
                )

        # test x @ y
        expected = np.matmul(x.cpu(), y.cpu())
        ans = torch.matmul(x, y)
        self.assertTrue(ans.is_contiguous())
        assertEqual(ans, expected)

        # test out
        out = torch.empty_like(ans)
        ans = torch.matmul(x, y, out=out)
        self.assertIs(ans, out)
        self.assertTrue(ans.is_contiguous())
        assertEqual(ans, expected)

    def gen_sizes_matmul(self, x_dim, y_dim=4, matrix_size=4, batch_size=3):
        """
        Generates sequences of tuples (x, y) with size(x) = x_dim and
        size(y) <= y_dim that are compatible w.r.t. matmul.
1067        """
1068        assert x_dim >= 1
1069        assert y_dim >= 2
1070        x = x_dim
1071        for y in range(1, y_dim + 1):
1072            for batch, mn in product(
1073                product(range(batch_size), repeat=max(x - 2, y - 2, 0)),
1074                product(range(matrix_size), repeat=min(y, 2)),
1075            ):
1076                if x == 1:
1077                    size_x = mn[:1]
1078                    size_y = batch + mn
1079                    yield size_x, size_y
1080                else:
1081                    for k in range(matrix_size):
1082                        size_x = (k,) + mn[:1]
1083                        if x > 2:
1084                            size_x = batch[-(x - 2) :] + size_x
1085                        size_y = mn
1086                        if y > 2:
1087                            size_y = batch[-(y - 2) :] + size_y
1088                        yield size_x, size_y
1089
1090    @dtypes(torch.float)
1091    def test_matmul_small_brute_force_1d_Nd(self, device, dtype):
1092        make_arg = partial(make_tensor, device=device, dtype=dtype)
1093
1094        for (size_x, size_y), nctg_x, nctg_y in product(
1095            self.gen_sizes_matmul(1), (True, False), (True, False)
1096        ):
1097            x = make_arg(size_x, noncontiguous=nctg_x)
1098            y = make_arg(size_y, noncontiguous=nctg_y)
1099            self.check_single_matmul(x, y)
1100
1101    @dtypes(torch.float)
1102    def test_matmul_small_brute_force_2d_Nd(self, device, dtype):
1103        make_arg = partial(make_tensor, device=device, dtype=dtype)
1104
1105        for (size_x, size_y), nctg_x, nctg_y in product(
1106            self.gen_sizes_matmul(2), (True, False), (True, False)
1107        ):
1108            x = make_arg(size_x, noncontiguous=nctg_x)
1109            y = make_arg(size_y, noncontiguous=nctg_y)
1110            self.check_single_matmul(x, y)
1111
1112    @dtypes(torch.float)
1113    def test_matmul_small_brute_force_3d_Nd(self, device, dtype):
1114        make_arg = partial(make_tensor, device=device, dtype=dtype)
1115
1116        for (size_x, size_y), nctg_x, nctg_y in product(
1117            self.gen_sizes_matmul(3), (True, False), (True, False)
1118        ):
1119            x = make_arg(size_x, noncontiguous=nctg_x)
1120            y = make_arg(size_y, noncontiguous=nctg_y)
1121            self.check_single_matmul(x, y)
1122
1123    @dtypes(torch.float)
1124    def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
1125        a = torch.empty(
1126            (256, 512), device=device, dtype=dtype, requires_grad=True
1127        ).unsqueeze(0)
1128        b = torch.empty(
1129            (4, 128, 512), device=device, dtype=dtype, requires_grad=True
1130        ).transpose(-1, -2)
1131        c = torch.empty((256, 4, 128), device=device, dtype=dtype).movedim(1, 0)
1132
1133        torch.matmul(a.detach(), b.detach(), out=c)
1134
1135        with self.assertRaisesRegex(
1136            RuntimeError,
1137            "functions with out=... arguments don't support automatic differentiation",
1138        ):
1139            torch.matmul(a, b, out=c)
1140
1141        with torch.no_grad():
1142            torch.matmul(a, b, out=c)
1143
1144
1145instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True)
1146
1147if __name__ == "__main__":
1148    run_tests()
1149