xref: /aosp_15_r20/external/pytorch/test/xpu/test_gemm.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: intel"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport itertools
4*da0073e9SAndroid Build Coastguard Workerimport math
5*da0073e9SAndroid Build Coastguard Workerimport random
6*da0073e9SAndroid Build Coastguard Workerfrom functools import partial
7*da0073e9SAndroid Build Coastguard Workerfrom itertools import product
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport numpy as np
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Workerimport torch
12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor
13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import (
14*da0073e9SAndroid Build Coastguard Worker    dtypes,
15*da0073e9SAndroid Build Coastguard Worker    instantiate_device_type_tests,
16*da0073e9SAndroid Build Coastguard Worker    precisionOverride,
17*da0073e9SAndroid Build Coastguard Worker)
18*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import iter_indices, run_tests, TestCase
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Workerclass TestBasicGEMM(TestCase):
22*da0073e9SAndroid Build Coastguard Worker    def _test_addmm_addmv(
23*da0073e9SAndroid Build Coastguard Worker        self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None
24*da0073e9SAndroid Build Coastguard Worker    ):
25*da0073e9SAndroid Build Coastguard Worker        dtype = t.dtype
26*da0073e9SAndroid Build Coastguard Worker        numpy_dtype = dtype
27*da0073e9SAndroid Build Coastguard Worker        if dtype in {torch.bfloat16, torch.half}:
28*da0073e9SAndroid Build Coastguard Worker            numpy_dtype = torch.float
29*da0073e9SAndroid Build Coastguard Worker        if dtype.is_complex:
30*da0073e9SAndroid Build Coastguard Worker            alpha = 0.9 + 0.3j if alpha is None else alpha
31*da0073e9SAndroid Build Coastguard Worker            beta = 0.5 + 0.6j if beta is None else beta
32*da0073e9SAndroid Build Coastguard Worker        else:
33*da0073e9SAndroid Build Coastguard Worker            alpha = 1.2 if alpha is None else alpha
34*da0073e9SAndroid Build Coastguard Worker            beta = 0.8 if beta is None else beta
35*da0073e9SAndroid Build Coastguard Worker        if activation == "gelu":
36*da0073e9SAndroid Build Coastguard Worker            res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True)
37*da0073e9SAndroid Build Coastguard Worker        else:
38*da0073e9SAndroid Build Coastguard Worker            res1 = f(t, m, v, alpha=alpha, beta=beta)
39*da0073e9SAndroid Build Coastguard Worker        res2 = torch.full_like(res1, math.nan)
40*da0073e9SAndroid Build Coastguard Worker        if transpose_out:
41*da0073e9SAndroid Build Coastguard Worker            res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
42*da0073e9SAndroid Build Coastguard Worker        if activation == "gelu":
43*da0073e9SAndroid Build Coastguard Worker            f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True)
44*da0073e9SAndroid Build Coastguard Worker        else:
45*da0073e9SAndroid Build Coastguard Worker            f(t, m, v, alpha=alpha, beta=beta, out=res2)
46*da0073e9SAndroid Build Coastguard Worker        m.to(numpy_dtype).cpu().numpy()
47*da0073e9SAndroid Build Coastguard Worker        v.to(numpy_dtype).cpu().numpy()
48*da0073e9SAndroid Build Coastguard Worker        res3 = alpha * (
49*da0073e9SAndroid Build Coastguard Worker            m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy()
50*da0073e9SAndroid Build Coastguard Worker        )
51*da0073e9SAndroid Build Coastguard Worker        if beta != 0:
52*da0073e9SAndroid Build Coastguard Worker            res3 += (beta * t).to(numpy_dtype).cpu().numpy()
53*da0073e9SAndroid Build Coastguard Worker        if activation == "relu":
54*da0073e9SAndroid Build Coastguard Worker            res3 = res3 * (res3 > 0)
55*da0073e9SAndroid Build Coastguard Worker        elif activation == "gelu":
56*da0073e9SAndroid Build Coastguard Worker            res3_t = torch.from_numpy(res3).to(dtype)
57*da0073e9SAndroid Build Coastguard Worker            approximate = "tanh" if t.is_cuda else "none"
58*da0073e9SAndroid Build Coastguard Worker            res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate)
59*da0073e9SAndroid Build Coastguard Worker            res3 = res3_t.to(numpy_dtype).cpu().numpy()
60*da0073e9SAndroid Build Coastguard Worker        else:
61*da0073e9SAndroid Build Coastguard Worker            assert activation is None, f"unsupported activation {activation}"
62*da0073e9SAndroid Build Coastguard Worker        res3 = torch.from_numpy(res3).to(dtype)
63*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
64*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res3)
65*da0073e9SAndroid Build Coastguard Worker
    def _test_addmm_impl(self, func, activation, device, dtype):
        """Drive _test_addmm_addmv for ``func`` over: matrix bias, vector
        bias with beta=1, zero-stride inputs, beta=0 with a NaN bias, and
        every transpose combination of inputs/output."""
        # basic case: contiguous matrix bias
        M = torch.randn(10, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device)
        m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, M, m1, m2, activation=activation)

        # vector-shaped bias and beta=1 result in epilogue fusion in CUDA
        V = torch.randn(25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)

        # Test 0-strided: expand a single column, so M and m1 carry a
        # zero stride along dim 1
        M = (
            torch.randn(10, 1, device="cpu", dtype=torch.float32)
            .to(dtype)
            .expand(10, 25)
            .to(device)
        )
        m1 = (
            torch.randn(10, 1, device="cpu", dtype=torch.float32)
            .to(dtype)
            .expand(10, 50)
            .to(device)
        )
        m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, M, m1, m2, activation=activation)

        # Test beta=0, M=nan: NaNs in the bias must be ignored when beta == 0
        M = (
            torch.full((10, 25), math.nan, device="cpu", dtype=torch.float32)
            .to(dtype)
            .to(device)
        )
        m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device)
        m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, M, m1, m2, beta=0, activation=activation)

        # Test transpose: all 16 combinations of transposing M/m1/m2/out
        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):

            def maybe_transpose(cond, m):
                if not cond:
                    return m
                # t()/clone/t() yields a tensor with transposed memory layout
                return m.t().clone(memory_format=torch.contiguous_format).t()

            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
            self._test_addmm_addmv(
                func, M, m1, m2, transpose_out=t4, activation=activation
            )

            if t1:
                # use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
                # NOTE(review): V here is the vector created before the loop;
                # the t1 guard appears to only halve the repetitions — confirm.
                self._test_addmm_addmv(
                    func,
                    V,
                    m1,
                    m2,
                    beta=1,
                    transpose_out=t4,
                    activation=activation,
                )
128*da0073e9SAndroid Build Coastguard Worker
129*da0073e9SAndroid Build Coastguard Worker    @precisionOverride(
130*da0073e9SAndroid Build Coastguard Worker        {
131*da0073e9SAndroid Build Coastguard Worker            torch.float: 1e-4,
132*da0073e9SAndroid Build Coastguard Worker            torch.half: 1e-1,
133*da0073e9SAndroid Build Coastguard Worker        }
134*da0073e9SAndroid Build Coastguard Worker    )
135*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.half)
136*da0073e9SAndroid Build Coastguard Worker    def test_addmm(self, device, dtype):
137*da0073e9SAndroid Build Coastguard Worker        self._test_addmm_impl(torch.addmm, None, device, dtype)
138*da0073e9SAndroid Build Coastguard Worker
139*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4})
140*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.half, torch.float)
141*da0073e9SAndroid Build Coastguard Worker    def test_addmv(self, device, dtype):
142*da0073e9SAndroid Build Coastguard Worker        # have to use torch.randn(...).to(bfloat16) instead of
143*da0073e9SAndroid Build Coastguard Worker        # torch.randn(..., dtype=bfloat16). randn does not support
144*da0073e9SAndroid Build Coastguard Worker        # bfloat16 yet.
145*da0073e9SAndroid Build Coastguard Worker        # "*0.2" to reduce errors for low precision
146*da0073e9SAndroid Build Coastguard Worker        ts = [
147*da0073e9SAndroid Build Coastguard Worker            0.2 * torch.randn(50, device=device).to(dtype),
148*da0073e9SAndroid Build Coastguard Worker            0.2 * torch.randn(1, device=device).to(dtype).expand(50),
149*da0073e9SAndroid Build Coastguard Worker        ]
150*da0073e9SAndroid Build Coastguard Worker        vs = [
151*da0073e9SAndroid Build Coastguard Worker            0.2 * torch.randn(100, device=device).to(dtype),
152*da0073e9SAndroid Build Coastguard Worker            0.2
153*da0073e9SAndroid Build Coastguard Worker            * torch.ones(1, device=device)
154*da0073e9SAndroid Build Coastguard Worker            .to(dtype)
155*da0073e9SAndroid Build Coastguard Worker            .expand(100),  # to reduce errors for low precision
156*da0073e9SAndroid Build Coastguard Worker        ]
157*da0073e9SAndroid Build Coastguard Worker        ms = [
158*da0073e9SAndroid Build Coastguard Worker            # 0d
159*da0073e9SAndroid Build Coastguard Worker            0.2
160*da0073e9SAndroid Build Coastguard Worker            * torch.ones((), device=device)
161*da0073e9SAndroid Build Coastguard Worker            .to(dtype)
162*da0073e9SAndroid Build Coastguard Worker            .expand(50, 100),  # to reduce errors for low precision
163*da0073e9SAndroid Build Coastguard Worker            # 1d
164*da0073e9SAndroid Build Coastguard Worker            0.2 * torch.randn((1, 100), device=device).to(dtype).expand(50, 100),
165*da0073e9SAndroid Build Coastguard Worker            # this initialization reduces errors for low precision for broadcasted matrices
166*da0073e9SAndroid Build Coastguard Worker            # by making sure that intermediate and result values are exactly representable
167*da0073e9SAndroid Build Coastguard Worker            # in low precision type
168*da0073e9SAndroid Build Coastguard Worker            0.2
169*da0073e9SAndroid Build Coastguard Worker            * torch.randint(3, (50, 1), dtype=torch.float, device=device)
170*da0073e9SAndroid Build Coastguard Worker            .to(dtype)
171*da0073e9SAndroid Build Coastguard Worker            .expand(50, 100),
172*da0073e9SAndroid Build Coastguard Worker            # 2d
173*da0073e9SAndroid Build Coastguard Worker            0.2 * torch.randn((50, 100), device=device).to(dtype),
174*da0073e9SAndroid Build Coastguard Worker            0.2 * torch.randn((100, 50), device=device).to(dtype).t(),
175*da0073e9SAndroid Build Coastguard Worker        ]
176*da0073e9SAndroid Build Coastguard Worker        for m, v, t in itertools.product(ms, vs, ts):
177*da0073e9SAndroid Build Coastguard Worker            self._test_addmm_addmv(torch.addmv, t, m, v)
178*da0073e9SAndroid Build Coastguard Worker        # Test beta=0, t=nan
179*da0073e9SAndroid Build Coastguard Worker        t = torch.full((50,), math.nan, device=device).to(dtype)
180*da0073e9SAndroid Build Coastguard Worker        for m, v in itertools.product(ms, vs):
181*da0073e9SAndroid Build Coastguard Worker            self._test_addmm_addmv(torch.addmv, t, m, v, beta=0)
182*da0073e9SAndroid Build Coastguard Worker
183*da0073e9SAndroid Build Coastguard Worker    @dtypes(
184*da0073e9SAndroid Build Coastguard Worker        torch.half,
185*da0073e9SAndroid Build Coastguard Worker        torch.float32,
186*da0073e9SAndroid Build Coastguard Worker    )
187*da0073e9SAndroid Build Coastguard Worker    def test_mm(self, device, dtype):
188*da0073e9SAndroid Build Coastguard Worker        def _test_mm(n, m, p, dtype, genf):
189*da0073e9SAndroid Build Coastguard Worker            # helper function
190*da0073e9SAndroid Build Coastguard Worker            def matrixmultiply(mat1, mat2):
191*da0073e9SAndroid Build Coastguard Worker                n = mat1.size(0)
192*da0073e9SAndroid Build Coastguard Worker                m = mat1.size(1)
193*da0073e9SAndroid Build Coastguard Worker                p = mat2.size(1)
194*da0073e9SAndroid Build Coastguard Worker                dtype_ = torch.float if dtype == torch.half else dtype
195*da0073e9SAndroid Build Coastguard Worker                if dtype == torch.half:
196*da0073e9SAndroid Build Coastguard Worker                    mat1 = mat1.float()
197*da0073e9SAndroid Build Coastguard Worker                    mat2 = mat2.float()
198*da0073e9SAndroid Build Coastguard Worker                res = torch.zeros(n, p, dtype=dtype_, device=device)
199*da0073e9SAndroid Build Coastguard Worker                for i, j in iter_indices(res):
200*da0073e9SAndroid Build Coastguard Worker                    res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m))
201*da0073e9SAndroid Build Coastguard Worker                return res.half() if dtype == torch.half else res
202*da0073e9SAndroid Build Coastguard Worker
203*da0073e9SAndroid Build Coastguard Worker            # contiguous case
204*da0073e9SAndroid Build Coastguard Worker            mat1 = genf(n, m)
205*da0073e9SAndroid Build Coastguard Worker            mat2 = genf(m, p)
206*da0073e9SAndroid Build Coastguard Worker            res = torch.mm(mat1, mat2)
207*da0073e9SAndroid Build Coastguard Worker
208*da0073e9SAndroid Build Coastguard Worker            res2 = matrixmultiply(mat1, mat2)
209*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, res2)
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker            # non contiguous case 1
212*da0073e9SAndroid Build Coastguard Worker            mat1 = genf(n, m)
213*da0073e9SAndroid Build Coastguard Worker            mat2 = genf(p, m).t()
214*da0073e9SAndroid Build Coastguard Worker            res = torch.mm(mat1, mat2)
215*da0073e9SAndroid Build Coastguard Worker
216*da0073e9SAndroid Build Coastguard Worker            res2 = matrixmultiply(mat1, mat2)
217*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, res2)
218*da0073e9SAndroid Build Coastguard Worker
219*da0073e9SAndroid Build Coastguard Worker            # non contiguous case 2
220*da0073e9SAndroid Build Coastguard Worker            mat1 = genf(m, n).t()
221*da0073e9SAndroid Build Coastguard Worker            mat2 = genf(m, p)
222*da0073e9SAndroid Build Coastguard Worker            res = torch.mm(mat1, mat2)
223*da0073e9SAndroid Build Coastguard Worker
224*da0073e9SAndroid Build Coastguard Worker            res2 = matrixmultiply(mat1, mat2)
225*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, res2)
226*da0073e9SAndroid Build Coastguard Worker
227*da0073e9SAndroid Build Coastguard Worker            # non contiguous case 3
228*da0073e9SAndroid Build Coastguard Worker            mat1 = genf(m, n).t()
229*da0073e9SAndroid Build Coastguard Worker            mat2 = genf(p, m).t()
230*da0073e9SAndroid Build Coastguard Worker            res = torch.mm(mat1, mat2)
231*da0073e9SAndroid Build Coastguard Worker
232*da0073e9SAndroid Build Coastguard Worker            res2 = matrixmultiply(mat1, mat2)
233*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, res2)
234*da0073e9SAndroid Build Coastguard Worker
235*da0073e9SAndroid Build Coastguard Worker            # test with zero stride
236*da0073e9SAndroid Build Coastguard Worker            mat1 = genf(n, m)
237*da0073e9SAndroid Build Coastguard Worker            mat2 = genf(m, 1).expand(m, p)
238*da0073e9SAndroid Build Coastguard Worker            res = torch.mm(mat1, mat2)
239*da0073e9SAndroid Build Coastguard Worker
240*da0073e9SAndroid Build Coastguard Worker            res2 = matrixmultiply(mat1, mat2)
241*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, res2)
242*da0073e9SAndroid Build Coastguard Worker
243*da0073e9SAndroid Build Coastguard Worker            # explicitly exercise the _out variant in torch.mm().
244*da0073e9SAndroid Build Coastguard Worker            # contiguous case
245*da0073e9SAndroid Build Coastguard Worker            mat1 = genf(n, m)
246*da0073e9SAndroid Build Coastguard Worker            mat2 = genf(m, p)
247*da0073e9SAndroid Build Coastguard Worker            res = genf(n, p)
248*da0073e9SAndroid Build Coastguard Worker            torch.mm(mat1, mat2, out=res)
249*da0073e9SAndroid Build Coastguard Worker
250*da0073e9SAndroid Build Coastguard Worker            res2 = matrixmultiply(mat1, mat2)
251*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, res2)
252*da0073e9SAndroid Build Coastguard Worker
253*da0073e9SAndroid Build Coastguard Worker            # explicitly exercise the _out variant in torch.mm().
254*da0073e9SAndroid Build Coastguard Worker            # non contiguous case 3
255*da0073e9SAndroid Build Coastguard Worker            mat1 = genf(m, n).t()
256*da0073e9SAndroid Build Coastguard Worker            mat2 = genf(p, m).t()
257*da0073e9SAndroid Build Coastguard Worker            res = genf(n, p)
258*da0073e9SAndroid Build Coastguard Worker            torch.mm(mat1, mat2, out=res)
259*da0073e9SAndroid Build Coastguard Worker
260*da0073e9SAndroid Build Coastguard Worker            res2 = matrixmultiply(mat1, mat2)
261*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, res2)
262*da0073e9SAndroid Build Coastguard Worker
263*da0073e9SAndroid Build Coastguard Worker        def genf_int(x, y):
264*da0073e9SAndroid Build Coastguard Worker            return torch.randint(0, 100, (x, y), dtype=dtype, device=device)
265*da0073e9SAndroid Build Coastguard Worker
266*da0073e9SAndroid Build Coastguard Worker        def genf_bfloat(x, y):
267*da0073e9SAndroid Build Coastguard Worker            return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) * 0.1
268*da0073e9SAndroid Build Coastguard Worker
269*da0073e9SAndroid Build Coastguard Worker        def genf_float(x, y):
270*da0073e9SAndroid Build Coastguard Worker            return torch.randn(x, y, dtype=dtype, device=device)
271*da0073e9SAndroid Build Coastguard Worker
272*da0073e9SAndroid Build Coastguard Worker        def genf_Half(x, y):
273*da0073e9SAndroid Build Coastguard Worker            return torch.randn(x, y, dtype=dtype, device=device)
274*da0073e9SAndroid Build Coastguard Worker
275*da0073e9SAndroid Build Coastguard Worker        for n, m, p in [(20, 10, 15), (15, 20, 10), (25, 18, 10)]:
276*da0073e9SAndroid Build Coastguard Worker            if (dtype == torch.int32) or (dtype == torch.int64):
277*da0073e9SAndroid Build Coastguard Worker                genf = genf_int
278*da0073e9SAndroid Build Coastguard Worker            elif dtype == torch.bfloat16:
279*da0073e9SAndroid Build Coastguard Worker                genf = genf_bfloat
280*da0073e9SAndroid Build Coastguard Worker            elif dtype == torch.half:
281*da0073e9SAndroid Build Coastguard Worker                genf = genf_Half
282*da0073e9SAndroid Build Coastguard Worker            else:
283*da0073e9SAndroid Build Coastguard Worker                genf = genf_float
284*da0073e9SAndroid Build Coastguard Worker
285*da0073e9SAndroid Build Coastguard Worker            _test_mm(n, m, p, dtype, genf)
286*da0073e9SAndroid Build Coastguard Worker
    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_bmm(self, device, dtype):
        """Compare torch.bmm (and its out= variant, written through every
        output permutation) against a numpy matmul reference, over permuted,
        broadcast-expanded and zero-sized batched inputs."""
        batch_sizes = [1, 10]
        M, N, O = 23, 15, 12
        # bfloat16 has no numpy equivalent; reference runs in float32
        numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32

        def invert_perm(p):
            # inverse permutation, so permute(p).permute(invert_perm(p)) is identity
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2])

        def generate_inputs(num_batches):
            # transposed tensors: every pair of memory-layout permutations,
            # with logical shapes restored via the inverse permutation
            for perm1, perm2 in itertools.product(
                itertools.permutations((0, 1, 2)), repeat=2
            ):
                b1 = make_tensor(
                    (num_batches, M, N), dtype=dtype, device=device, low=-0.1, high=0.1
                )
                b2 = make_tensor(
                    (num_batches, N, O), dtype=dtype, device=device, low=-0.1, high=0.1
                )
                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
                yield b1, b2
            # broadcasting tensors: each dim independently collapsed to 1
            # then expanded back (zero stride along collapsed dims)
            for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6):
                shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1)
                shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1)
                b1 = make_tensor(
                    shape1, dtype=dtype, device=device, low=-0.1, high=0.1
                ).expand(num_batches, M, N)
                b2 = make_tensor(
                    shape2, dtype=dtype, device=device, low=-0.1, high=0.1
                ).expand(num_batches, N, O)
                yield b1, b2
            # zero-sized tensors; z1 (batch) and z3 (inner dim N) are shared
            # between both shapes so the operands stay mutually compatible
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
                b1 = torch.randn(shape1, dtype=dtype, device=device)
                b2 = torch.randn(shape2, dtype=dtype, device=device)
                yield b1, b2

        for num_batches in batch_sizes:
            for (b1, b2), perm3 in itertools.product(
                generate_inputs(num_batches), itertools.permutations((0, 1, 2))
            ):
                res1 = torch.bmm(b1, b2)
                # out tensor pre-filled with NaN and laid out per perm3 so
                # every output memory order is exercised
                res2 = (
                    torch.full(
                        (num_batches, M, O), math.nan, dtype=dtype, device=device
                    )
                    .permute(perm3)
                    .contiguous()
                    .permute(invert_perm(perm3))
                )
                torch.bmm(b1, b2, out=res2)
                expect = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype)
                self.assertEqual(expect, res1)
                self.assertEqual(expect, res2)

                if self.device_type == "cuda":
                    # check that mixed arguments are rejected
                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu()))
                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2))
                    self.assertRaises(
                        RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu())
                    )
358*da0073e9SAndroid Build Coastguard Worker
359*da0073e9SAndroid Build Coastguard Worker    def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor):
360*da0073e9SAndroid Build Coastguard Worker        getattr(out_tensor, func + "_")(b1, b2)
361*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_tensor, ref)
362*da0073e9SAndroid Build Coastguard Worker        res3 = out_tensor.clone()
363*da0073e9SAndroid Build Coastguard Worker
364*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(
365*da0073e9SAndroid Build Coastguard Worker            UserWarning, f"This overload of {func}_ is deprecated"
366*da0073e9SAndroid Build Coastguard Worker        ):
367*da0073e9SAndroid Build Coastguard Worker            getattr(out_tensor, func + "_")(1, b1, b2)
368*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_tensor, ref * 2),
369*da0073e9SAndroid Build Coastguard Worker        getattr(res3, func + "_")(b1, b2, beta=1)
370*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_tensor, res3)
371*da0073e9SAndroid Build Coastguard Worker
372*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(
373*da0073e9SAndroid Build Coastguard Worker            UserWarning, f"This overload of {func}_ is deprecated"
374*da0073e9SAndroid Build Coastguard Worker        ):
375*da0073e9SAndroid Build Coastguard Worker            getattr(out_tensor, func + "_")(1.0, 0.5, b1, b2)
376*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_tensor, ref * 2.5)
377*da0073e9SAndroid Build Coastguard Worker        getattr(res3, func + "_")(b1, b2, beta=1.0, alpha=0.5)
378*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_tensor, res3)
379*da0073e9SAndroid Build Coastguard Worker
380*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(
381*da0073e9SAndroid Build Coastguard Worker            UserWarning, f"This overload of {func} is deprecated"
382*da0073e9SAndroid Build Coastguard Worker        ):
383*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2))
384*da0073e9SAndroid Build Coastguard Worker
385*da0073e9SAndroid Build Coastguard Worker        res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=0.5)
386*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res4, ref * 3),
387*da0073e9SAndroid Build Coastguard Worker
388*da0073e9SAndroid Build Coastguard Worker        nan = torch.full_like(out_tensor, math.nan)
389*da0073e9SAndroid Build Coastguard Worker        res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1)
390*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res5, ref)
391*da0073e9SAndroid Build Coastguard Worker
392*da0073e9SAndroid Build Coastguard Worker        if b1.is_complex():
393*da0073e9SAndroid Build Coastguard Worker            res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1j, alpha=0.5j)
394*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res6, out_tensor * 0.1j + 0.5j * ref)
395*da0073e9SAndroid Build Coastguard Worker        else:
396*da0073e9SAndroid Build Coastguard Worker            res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1, alpha=0.5)
397*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res6, out_tensor * 0.1 + 0.5 * ref)
398*da0073e9SAndroid Build Coastguard Worker
399*da0073e9SAndroid Build Coastguard Worker        res7 = torch.full_like(out_tensor, math.nan)
400*da0073e9SAndroid Build Coastguard Worker        getattr(torch, func)(nan, b1, b2, beta=0, out=res7)
401*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res7, ref)
402*da0073e9SAndroid Build Coastguard Worker
403*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
404*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.bfloat16, torch.half)
405*da0073e9SAndroid Build Coastguard Worker    def test_addbmm(self, device, dtype):
406*da0073e9SAndroid Build Coastguard Worker        num_batches = 2
407*da0073e9SAndroid Build Coastguard Worker        M, N, O = 16, 17, 18
408*da0073e9SAndroid Build Coastguard Worker
409*da0073e9SAndroid Build Coastguard Worker        is_supported = True
410*da0073e9SAndroid Build Coastguard Worker
411*da0073e9SAndroid Build Coastguard Worker        if not is_supported:
412*da0073e9SAndroid Build Coastguard Worker            b1 = make_tensor(
413*da0073e9SAndroid Build Coastguard Worker                (num_batches, M, N), dtype=dtype, device=device, low=-1, high=1
414*da0073e9SAndroid Build Coastguard Worker            )
415*da0073e9SAndroid Build Coastguard Worker            b2 = make_tensor(
416*da0073e9SAndroid Build Coastguard Worker                (num_batches, N, O), dtype=dtype, device=device, low=-1, high=1
417*da0073e9SAndroid Build Coastguard Worker            )
418*da0073e9SAndroid Build Coastguard Worker            t = make_tensor((M, O), dtype=dtype, device=device, low=-1, high=1)
419*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(
420*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
421*da0073e9SAndroid Build Coastguard Worker                "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
422*da0073e9SAndroid Build Coastguard Worker                lambda: torch.addbmm(t, b1, b2),
423*da0073e9SAndroid Build Coastguard Worker            )
424*da0073e9SAndroid Build Coastguard Worker            return
425*da0073e9SAndroid Build Coastguard Worker
426*da0073e9SAndroid Build Coastguard Worker        def invert_perm(p):
427*da0073e9SAndroid Build Coastguard Worker            d = {x: i for i, x in enumerate(p)}
428*da0073e9SAndroid Build Coastguard Worker            return (d[0], d[1], d[2])
429*da0073e9SAndroid Build Coastguard Worker
430*da0073e9SAndroid Build Coastguard Worker        def generate_tensor():
431*da0073e9SAndroid Build Coastguard Worker            numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32
432*da0073e9SAndroid Build Coastguard Worker            # transposed tensors
433*da0073e9SAndroid Build Coastguard Worker            for perm1, perm2 in itertools.product(
434*da0073e9SAndroid Build Coastguard Worker                itertools.permutations((0, 1, 2)), repeat=2
435*da0073e9SAndroid Build Coastguard Worker            ):
436*da0073e9SAndroid Build Coastguard Worker                for perm3 in itertools.permutations((0, 1)):
437*da0073e9SAndroid Build Coastguard Worker                    b1 = (
438*da0073e9SAndroid Build Coastguard Worker                        make_tensor(
439*da0073e9SAndroid Build Coastguard Worker                            (num_batches, M, N),
440*da0073e9SAndroid Build Coastguard Worker                            dtype=dtype,
441*da0073e9SAndroid Build Coastguard Worker                            device=device,
442*da0073e9SAndroid Build Coastguard Worker                            low=-1,
443*da0073e9SAndroid Build Coastguard Worker                            high=1,
444*da0073e9SAndroid Build Coastguard Worker                        )
445*da0073e9SAndroid Build Coastguard Worker                        * 0.1
446*da0073e9SAndroid Build Coastguard Worker                    )
447*da0073e9SAndroid Build Coastguard Worker                    b2 = (
448*da0073e9SAndroid Build Coastguard Worker                        make_tensor(
449*da0073e9SAndroid Build Coastguard Worker                            (num_batches, N, O),
450*da0073e9SAndroid Build Coastguard Worker                            dtype=dtype,
451*da0073e9SAndroid Build Coastguard Worker                            device=device,
452*da0073e9SAndroid Build Coastguard Worker                            low=-1,
453*da0073e9SAndroid Build Coastguard Worker                            high=1,
454*da0073e9SAndroid Build Coastguard Worker                        )
455*da0073e9SAndroid Build Coastguard Worker                        * 0.1
456*da0073e9SAndroid Build Coastguard Worker                    )
457*da0073e9SAndroid Build Coastguard Worker                    b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
458*da0073e9SAndroid Build Coastguard Worker                    b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
459*da0073e9SAndroid Build Coastguard Worker                    ref = (
460*da0073e9SAndroid Build Coastguard Worker                        torch.from_numpy(
461*da0073e9SAndroid Build Coastguard Worker                            b1.to(numpy_dtype).cpu().numpy()
462*da0073e9SAndroid Build Coastguard Worker                            @ b2.to(numpy_dtype).cpu().numpy()
463*da0073e9SAndroid Build Coastguard Worker                        )
464*da0073e9SAndroid Build Coastguard Worker                        .to(device=device, dtype=dtype)
465*da0073e9SAndroid Build Coastguard Worker                        .sum(0)
466*da0073e9SAndroid Build Coastguard Worker                    )
467*da0073e9SAndroid Build Coastguard Worker                    out_tensor = (
468*da0073e9SAndroid Build Coastguard Worker                        torch.zeros_like(ref).permute(perm3).contiguous().permute(perm3)
469*da0073e9SAndroid Build Coastguard Worker                    )
470*da0073e9SAndroid Build Coastguard Worker                    yield b1, b2, ref, out_tensor
471*da0073e9SAndroid Build Coastguard Worker            # broadcasting tensors
472*da0073e9SAndroid Build Coastguard Worker            for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
473*da0073e9SAndroid Build Coastguard Worker                shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
474*da0073e9SAndroid Build Coastguard Worker                shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
475*da0073e9SAndroid Build Coastguard Worker                b1 = (
476*da0073e9SAndroid Build Coastguard Worker                    make_tensor(
477*da0073e9SAndroid Build Coastguard Worker                        shape1, dtype=dtype, device=device, low=-1, high=1
478*da0073e9SAndroid Build Coastguard Worker                    ).expand(num_batches, M, N)
479*da0073e9SAndroid Build Coastguard Worker                    * 0.1
480*da0073e9SAndroid Build Coastguard Worker                )
481*da0073e9SAndroid Build Coastguard Worker                b2 = (
482*da0073e9SAndroid Build Coastguard Worker                    make_tensor(
483*da0073e9SAndroid Build Coastguard Worker                        shape2, dtype=dtype, device=device, low=-1, high=1
484*da0073e9SAndroid Build Coastguard Worker                    ).expand(num_batches, N, O)
485*da0073e9SAndroid Build Coastguard Worker                    * 0.1
486*da0073e9SAndroid Build Coastguard Worker                )
487*da0073e9SAndroid Build Coastguard Worker                ref = (
488*da0073e9SAndroid Build Coastguard Worker                    torch.from_numpy(
489*da0073e9SAndroid Build Coastguard Worker                        b1.to(numpy_dtype).cpu().numpy()
490*da0073e9SAndroid Build Coastguard Worker                        @ b2.to(numpy_dtype).cpu().numpy()
491*da0073e9SAndroid Build Coastguard Worker                    )
492*da0073e9SAndroid Build Coastguard Worker                    .to(device=device, dtype=dtype)
493*da0073e9SAndroid Build Coastguard Worker                    .sum(0)
494*da0073e9SAndroid Build Coastguard Worker                )
495*da0073e9SAndroid Build Coastguard Worker                out_tensor = torch.zeros_like(ref)
496*da0073e9SAndroid Build Coastguard Worker                yield b1, b2, ref, out_tensor
497*da0073e9SAndroid Build Coastguard Worker            # zero-sized tensors
498*da0073e9SAndroid Build Coastguard Worker            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
499*da0073e9SAndroid Build Coastguard Worker                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
500*da0073e9SAndroid Build Coastguard Worker                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
501*da0073e9SAndroid Build Coastguard Worker                b1 = (
502*da0073e9SAndroid Build Coastguard Worker                    make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1)
503*da0073e9SAndroid Build Coastguard Worker                    * 0.1
504*da0073e9SAndroid Build Coastguard Worker                )
505*da0073e9SAndroid Build Coastguard Worker                b2 = (
506*da0073e9SAndroid Build Coastguard Worker                    make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1)
507*da0073e9SAndroid Build Coastguard Worker                    * 0.1
508*da0073e9SAndroid Build Coastguard Worker                )
509*da0073e9SAndroid Build Coastguard Worker                ref = (
510*da0073e9SAndroid Build Coastguard Worker                    torch.from_numpy(
511*da0073e9SAndroid Build Coastguard Worker                        b1.to(numpy_dtype).cpu().numpy()
512*da0073e9SAndroid Build Coastguard Worker                        @ b2.to(numpy_dtype).cpu().numpy()
513*da0073e9SAndroid Build Coastguard Worker                    )
514*da0073e9SAndroid Build Coastguard Worker                    .to(device=device, dtype=dtype)
515*da0073e9SAndroid Build Coastguard Worker                    .sum(0)
516*da0073e9SAndroid Build Coastguard Worker                )
517*da0073e9SAndroid Build Coastguard Worker                out_tensor = torch.zeros_like(ref)
518*da0073e9SAndroid Build Coastguard Worker                yield b1, b2, ref, out_tensor
519*da0073e9SAndroid Build Coastguard Worker
520*da0073e9SAndroid Build Coastguard Worker        for b1, b2, ref, out_tensor in generate_tensor():
521*da0073e9SAndroid Build Coastguard Worker            self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor)
522*da0073e9SAndroid Build Coastguard Worker
523*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5})
524*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.bfloat16, torch.half)
525*da0073e9SAndroid Build Coastguard Worker    def test_baddbmm(self, device, dtype):
526*da0073e9SAndroid Build Coastguard Worker        num_batches = 10
527*da0073e9SAndroid Build Coastguard Worker        M, N, O = 12, 8, 50
528*da0073e9SAndroid Build Coastguard Worker
529*da0073e9SAndroid Build Coastguard Worker        def invert_perm(p):
530*da0073e9SAndroid Build Coastguard Worker            d = {x: i for i, x in enumerate(p)}
531*da0073e9SAndroid Build Coastguard Worker            return (d[0], d[1], d[2])
532*da0073e9SAndroid Build Coastguard Worker
533*da0073e9SAndroid Build Coastguard Worker        def generate_tensor():
534*da0073e9SAndroid Build Coastguard Worker            numpy_dtype = (
535*da0073e9SAndroid Build Coastguard Worker                dtype if dtype not in [torch.bfloat16, torch.half] else torch.float32
536*da0073e9SAndroid Build Coastguard Worker            )
537*da0073e9SAndroid Build Coastguard Worker            # transposed tensors
538*da0073e9SAndroid Build Coastguard Worker            for perm1, perm2, perm3 in itertools.product(
539*da0073e9SAndroid Build Coastguard Worker                itertools.permutations((0, 1, 2)), repeat=3
540*da0073e9SAndroid Build Coastguard Worker            ):
541*da0073e9SAndroid Build Coastguard Worker                b1 = make_tensor(
542*da0073e9SAndroid Build Coastguard Worker                    (num_batches, M, N), dtype=dtype, device=device, low=-1, high=1
543*da0073e9SAndroid Build Coastguard Worker                )
544*da0073e9SAndroid Build Coastguard Worker                b2 = make_tensor(
545*da0073e9SAndroid Build Coastguard Worker                    (num_batches, N, O), dtype=dtype, device=device, low=-1, high=1
546*da0073e9SAndroid Build Coastguard Worker                )
547*da0073e9SAndroid Build Coastguard Worker                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
548*da0073e9SAndroid Build Coastguard Worker                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
549*da0073e9SAndroid Build Coastguard Worker                ref = torch.from_numpy(
550*da0073e9SAndroid Build Coastguard Worker                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
551*da0073e9SAndroid Build Coastguard Worker                ).to(device=device, dtype=dtype)
552*da0073e9SAndroid Build Coastguard Worker                out_tensor = torch.zeros_like(ref)
553*da0073e9SAndroid Build Coastguard Worker                out_tensor = (
554*da0073e9SAndroid Build Coastguard Worker                    out_tensor.permute(perm3).contiguous().permute(invert_perm(perm3))
555*da0073e9SAndroid Build Coastguard Worker                )
556*da0073e9SAndroid Build Coastguard Worker                yield b1, b2, ref, out_tensor
557*da0073e9SAndroid Build Coastguard Worker            # broadcasting tensors
558*da0073e9SAndroid Build Coastguard Worker            for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
559*da0073e9SAndroid Build Coastguard Worker                shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
560*da0073e9SAndroid Build Coastguard Worker                shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
561*da0073e9SAndroid Build Coastguard Worker                b1 = make_tensor(
562*da0073e9SAndroid Build Coastguard Worker                    shape1, dtype=dtype, device=device, low=-1, high=1
563*da0073e9SAndroid Build Coastguard Worker                ).expand(num_batches, M, N)
564*da0073e9SAndroid Build Coastguard Worker                b2 = make_tensor(
565*da0073e9SAndroid Build Coastguard Worker                    shape2, dtype=dtype, device=device, low=-1, high=1
566*da0073e9SAndroid Build Coastguard Worker                ).expand(num_batches, N, O)
567*da0073e9SAndroid Build Coastguard Worker                ref = torch.from_numpy(
568*da0073e9SAndroid Build Coastguard Worker                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
569*da0073e9SAndroid Build Coastguard Worker                ).to(device=device, dtype=dtype)
570*da0073e9SAndroid Build Coastguard Worker                out_tensor = torch.zeros_like(ref)
571*da0073e9SAndroid Build Coastguard Worker                yield b1, b2, ref, out_tensor
572*da0073e9SAndroid Build Coastguard Worker            # zero-sized tensors
573*da0073e9SAndroid Build Coastguard Worker            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
574*da0073e9SAndroid Build Coastguard Worker                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
575*da0073e9SAndroid Build Coastguard Worker                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
576*da0073e9SAndroid Build Coastguard Worker                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-2, high=2)
577*da0073e9SAndroid Build Coastguard Worker                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-2, high=2)
578*da0073e9SAndroid Build Coastguard Worker                ref = torch.from_numpy(
579*da0073e9SAndroid Build Coastguard Worker                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
580*da0073e9SAndroid Build Coastguard Worker                ).to(device=device, dtype=dtype)
581*da0073e9SAndroid Build Coastguard Worker                out_tensor = torch.zeros_like(ref)
582*da0073e9SAndroid Build Coastguard Worker                yield b1, b2, ref, out_tensor
583*da0073e9SAndroid Build Coastguard Worker
584*da0073e9SAndroid Build Coastguard Worker        for b1, b2, ref, out_tensor in generate_tensor():
585*da0073e9SAndroid Build Coastguard Worker            self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor)
586*da0073e9SAndroid Build Coastguard Worker
587*da0073e9SAndroid Build Coastguard Worker    def test_tensordot(self, device):
588*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(60.0, device=device).reshape(3, 4, 5)
589*da0073e9SAndroid Build Coastguard Worker        b = torch.arange(24.0, device=device).reshape(4, 3, 2)
590*da0073e9SAndroid Build Coastguard Worker        c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu()
591*da0073e9SAndroid Build Coastguard Worker        cn = torch.from_numpy(
592*da0073e9SAndroid Build Coastguard Worker            np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=([1, 0], [0, 1]))
593*da0073e9SAndroid Build Coastguard Worker        )
594*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c, cn)
595*da0073e9SAndroid Build Coastguard Worker
596*da0073e9SAndroid Build Coastguard Worker        cout = torch.zeros((5, 2), device=device)
597*da0073e9SAndroid Build Coastguard Worker        torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout).cpu()
598*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c, cout)
599*da0073e9SAndroid Build Coastguard Worker
600*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, 4, 5, device=device)
601*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(4, 5, 6, 7, device=device)
602*da0073e9SAndroid Build Coastguard Worker        c = torch.tensordot(a, b, dims=2).cpu()
603*da0073e9SAndroid Build Coastguard Worker        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=2))
604*da0073e9SAndroid Build Coastguard Worker
605*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"):
606*da0073e9SAndroid Build Coastguard Worker            torch.tensordot(a, b, dims=-1)
607*da0073e9SAndroid Build Coastguard Worker
608*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c, cn)
609*da0073e9SAndroid Build Coastguard Worker        c = torch.tensordot(a, b).cpu()
610*da0073e9SAndroid Build Coastguard Worker        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy()))
611*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c, cn)
612*da0073e9SAndroid Build Coastguard Worker
613*da0073e9SAndroid Build Coastguard Worker        a = torch.tensordot(torch.tensor(0.0), torch.tensor(0.0), 0)
614*da0073e9SAndroid Build Coastguard Worker        an = torch.from_numpy(
615*da0073e9SAndroid Build Coastguard Worker            np.tensordot(
616*da0073e9SAndroid Build Coastguard Worker                np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0
617*da0073e9SAndroid Build Coastguard Worker            )
618*da0073e9SAndroid Build Coastguard Worker        )
619*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a, an)
620*da0073e9SAndroid Build Coastguard Worker
621*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
622*da0073e9SAndroid Build Coastguard Worker    @precisionOverride({torch.float32: 1e-4})
623*da0073e9SAndroid Build Coastguard Worker    def test_1_sized_with_0_strided(self, device, dtype):
624*da0073e9SAndroid Build Coastguard Worker        a = make_tensor((8, 1, 64), dtype=dtype, device=device)
625*da0073e9SAndroid Build Coastguard Worker        a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1])
626*da0073e9SAndroid Build Coastguard Worker        b = make_tensor((8, 64, 512), dtype=dtype, device=device)
627*da0073e9SAndroid Build Coastguard Worker        b_strided = torch.as_strided(b, size=[8, 64, 512], stride=[64, 1, 512])
628*da0073e9SAndroid Build Coastguard Worker        res = torch.bmm(a_strided, b_strided)
629*da0073e9SAndroid Build Coastguard Worker        expect = torch.from_numpy(a_strided.cpu().numpy() @ b_strided.cpu().numpy()).to(
630*da0073e9SAndroid Build Coastguard Worker            device=device, dtype=dtype
631*da0073e9SAndroid Build Coastguard Worker        )
632*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expect, res)
633*da0073e9SAndroid Build Coastguard Worker
634*da0073e9SAndroid Build Coastguard Worker    def _select_broadcastable_dims(self, dims_full=None):
635*da0073e9SAndroid Build Coastguard Worker        # select full dimensionality
636*da0073e9SAndroid Build Coastguard Worker        if dims_full is None:
637*da0073e9SAndroid Build Coastguard Worker            dims_full = []
638*da0073e9SAndroid Build Coastguard Worker            ndims = random.randint(1, 4)
639*da0073e9SAndroid Build Coastguard Worker            dims_full = [random.randint(1, 8) for _ in range(ndims)]
640*da0073e9SAndroid Build Coastguard Worker        else:
641*da0073e9SAndroid Build Coastguard Worker            ndims = len(dims_full)
642*da0073e9SAndroid Build Coastguard Worker
643*da0073e9SAndroid Build Coastguard Worker        # select actual dimensions for ops:
644*da0073e9SAndroid Build Coastguard Worker        # larger: full ndims, individual sizes may be reduced
645*da0073e9SAndroid Build Coastguard Worker        # smaller: possibly reduced ndims, sizes may be reduced
646*da0073e9SAndroid Build Coastguard Worker        smaller_ndims = random.randint(1, ndims)
647*da0073e9SAndroid Build Coastguard Worker        dims_small = []
648*da0073e9SAndroid Build Coastguard Worker        dims_large = []
649*da0073e9SAndroid Build Coastguard Worker        for i in range(ndims - 1, -1, -1):
650*da0073e9SAndroid Build Coastguard Worker            j = random.randint(1, 3)
651*da0073e9SAndroid Build Coastguard Worker            if j == 1:  # no reduced singleton dimension
652*da0073e9SAndroid Build Coastguard Worker                ds = dims_full[i]
653*da0073e9SAndroid Build Coastguard Worker                dl = dims_full[i]
654*da0073e9SAndroid Build Coastguard Worker            elif j == 2:  # larger may have reduced singleton dimension
655*da0073e9SAndroid Build Coastguard Worker                ds = dims_full[i]
656*da0073e9SAndroid Build Coastguard Worker                dl = 1 if len(dims_small) < smaller_ndims else dims_full[i]
657*da0073e9SAndroid Build Coastguard Worker            elif j == 3:  # smaller may have reduced singleton dimension
658*da0073e9SAndroid Build Coastguard Worker                ds = 1
659*da0073e9SAndroid Build Coastguard Worker                dl = dims_full[i]
660*da0073e9SAndroid Build Coastguard Worker            dims_large = [dl] + dims_large
661*da0073e9SAndroid Build Coastguard Worker            if len(dims_small) < smaller_ndims:
662*da0073e9SAndroid Build Coastguard Worker                dims_small = [ds] + dims_small
663*da0073e9SAndroid Build Coastguard Worker        return (dims_small, dims_large, dims_full)
664*da0073e9SAndroid Build Coastguard Worker
665*da0073e9SAndroid Build Coastguard Worker    def test_broadcast_fused_matmul(self, device):
666*da0073e9SAndroid Build Coastguard Worker        fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"]
667*da0073e9SAndroid Build Coastguard Worker
668*da0073e9SAndroid Build Coastguard Worker        for fn in fns:
669*da0073e9SAndroid Build Coastguard Worker            batch_dim = random.randint(1, 8)
670*da0073e9SAndroid Build Coastguard Worker            n_dim = random.randint(1, 8)
671*da0073e9SAndroid Build Coastguard Worker            m_dim = random.randint(1, 8)
672*da0073e9SAndroid Build Coastguard Worker            p_dim = random.randint(1, 8)
673*da0073e9SAndroid Build Coastguard Worker
674*da0073e9SAndroid Build Coastguard Worker            def dims_full_for_fn():
675*da0073e9SAndroid Build Coastguard Worker                if fn == "baddbmm":
676*da0073e9SAndroid Build Coastguard Worker                    return (
677*da0073e9SAndroid Build Coastguard Worker                        [batch_dim, n_dim, p_dim],
678*da0073e9SAndroid Build Coastguard Worker                        [batch_dim, n_dim, m_dim],
679*da0073e9SAndroid Build Coastguard Worker                        [batch_dim, m_dim, p_dim],
680*da0073e9SAndroid Build Coastguard Worker                    )
681*da0073e9SAndroid Build Coastguard Worker                elif fn == "addbmm":
682*da0073e9SAndroid Build Coastguard Worker                    return (
683*da0073e9SAndroid Build Coastguard Worker                        [n_dim, p_dim],
684*da0073e9SAndroid Build Coastguard Worker                        [batch_dim, n_dim, m_dim],
685*da0073e9SAndroid Build Coastguard Worker                        [batch_dim, m_dim, p_dim],
686*da0073e9SAndroid Build Coastguard Worker                    )
687*da0073e9SAndroid Build Coastguard Worker                elif fn == "addmm":
688*da0073e9SAndroid Build Coastguard Worker                    return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim])
689*da0073e9SAndroid Build Coastguard Worker                elif fn == "addmv":
690*da0073e9SAndroid Build Coastguard Worker                    return ([n_dim], [n_dim, m_dim], [m_dim])
691*da0073e9SAndroid Build Coastguard Worker                elif fn == "addr":
692*da0073e9SAndroid Build Coastguard Worker                    return ([n_dim, m_dim], [n_dim], [m_dim])
693*da0073e9SAndroid Build Coastguard Worker                else:
694*da0073e9SAndroid Build Coastguard Worker                    raise AssertionError("unknown function")
695*da0073e9SAndroid Build Coastguard Worker
696*da0073e9SAndroid Build Coastguard Worker            (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn()
697*da0073e9SAndroid Build Coastguard Worker            (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full)
698*da0073e9SAndroid Build Coastguard Worker
699*da0073e9SAndroid Build Coastguard Worker            t0_small = torch.randn(*t0_dims_small, device=device).float()
700*da0073e9SAndroid Build Coastguard Worker            t1 = torch.randn(*t1_dims, device=device).float()
701*da0073e9SAndroid Build Coastguard Worker            t2 = torch.randn(*t2_dims, device=device).float()
702*da0073e9SAndroid Build Coastguard Worker
703*da0073e9SAndroid Build Coastguard Worker            t0_full = t0_small.expand(*t0_dims_full).to(device)
704*da0073e9SAndroid Build Coastguard Worker
705*da0073e9SAndroid Build Coastguard Worker            fntorch = getattr(torch, fn)
706*da0073e9SAndroid Build Coastguard Worker            r0 = fntorch(t0_small, t1, t2)
707*da0073e9SAndroid Build Coastguard Worker            r1 = fntorch(t0_full, t1, t2)
708*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(r0, r1)
709*da0073e9SAndroid Build Coastguard Worker
710*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32)
711*da0073e9SAndroid Build Coastguard Worker    def test_strided_mm_bmm(self, device, dtype):
712*da0073e9SAndroid Build Coastguard Worker        # Tests strided view case with stride smaller than corresponding dimension size
713*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype, device=device)
714*da0073e9SAndroid Build Coastguard Worker        new_shape = [2, 2, 2]
715*da0073e9SAndroid Build Coastguard Worker        new_stride = [3, 1, 1]
716*da0073e9SAndroid Build Coastguard Worker        sx = torch.as_strided(x, size=new_shape, stride=new_stride)
717*da0073e9SAndroid Build Coastguard Worker
718*da0073e9SAndroid Build Coastguard Worker        torch_fn = lambda x: torch.bmm(x, x)  # noqa: E731
719*da0073e9SAndroid Build Coastguard Worker        np_fn = lambda x: np.matmul(x, x)  # noqa: E731
720*da0073e9SAndroid Build Coastguard Worker        self.compare_with_numpy(torch_fn, np_fn, sx)
721*da0073e9SAndroid Build Coastguard Worker
722*da0073e9SAndroid Build Coastguard Worker        torch_fn = lambda x: torch.mm(x, x)  # noqa: E731
723*da0073e9SAndroid Build Coastguard Worker        self.compare_with_numpy(torch_fn, np_fn, sx[0])
724*da0073e9SAndroid Build Coastguard Worker
725*da0073e9SAndroid Build Coastguard Worker    def test_mm_empty_inputs_mixed_dtype_errors(self, device):
726*da0073e9SAndroid Build Coastguard Worker        a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device)
727*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(10, 20, dtype=torch.float32, device=device)
728*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
729*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "expected .* and .* to have the same dtype, but got:"
730*da0073e9SAndroid Build Coastguard Worker        ):
731*da0073e9SAndroid Build Coastguard Worker            torch.mm(a, b)
732*da0073e9SAndroid Build Coastguard Worker
733*da0073e9SAndroid Build Coastguard Worker    def test_matmul_45724(self, device):
734*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/45724
735*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(65537, 22, 64, device=device, dtype=torch.half)
736*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(65537, 64, 22, device=device, dtype=torch.half)
737*da0073e9SAndroid Build Coastguard Worker        c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device=device)
738*da0073e9SAndroid Build Coastguard Worker        cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).half()
739*da0073e9SAndroid Build Coastguard Worker        torch.matmul(a, b, out=c)
740*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c, cpu_result)
741*da0073e9SAndroid Build Coastguard Worker
742*da0073e9SAndroid Build Coastguard Worker    @dtypes(
743*da0073e9SAndroid Build Coastguard Worker        torch.int16,
744*da0073e9SAndroid Build Coastguard Worker        torch.int32,
745*da0073e9SAndroid Build Coastguard Worker        torch.int64,
746*da0073e9SAndroid Build Coastguard Worker        torch.float16,
747*da0073e9SAndroid Build Coastguard Worker        torch.float32,
748*da0073e9SAndroid Build Coastguard Worker        torch.float64,
749*da0073e9SAndroid Build Coastguard Worker    )
750*da0073e9SAndroid Build Coastguard Worker    def test_baddbmm_input_dtypes_compatibility(self, device, dtype):
751*da0073e9SAndroid Build Coastguard Worker        batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
752*da0073e9SAndroid Build Coastguard Worker        batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
753*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.rand((1, 2, 2), device=device).to(dtype)
754*da0073e9SAndroid Build Coastguard Worker        if dtype != torch.float32:
755*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Input dtypes must be the same"):
756*da0073e9SAndroid Build Coastguard Worker                y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0)
757*da0073e9SAndroid Build Coastguard Worker        else:
758*da0073e9SAndroid Build Coastguard Worker            out = torch.randn((1, 2, 2), dtype=dtype, device=device).fill_(torch.nan)
759*da0073e9SAndroid Build Coastguard Worker            y_ref = torch.bmm(batch1, batch2)
760*da0073e9SAndroid Build Coastguard Worker            y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out)
761*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, y_ref)
762*da0073e9SAndroid Build Coastguard Worker
763*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
764*da0073e9SAndroid Build Coastguard Worker    def test_baddbmm_nan_input_with_zero_beta(self, device, dtype):
765*da0073e9SAndroid Build Coastguard Worker        for shape in [[3, 2, 2], [2, 20, 20]]:
766*da0073e9SAndroid Build Coastguard Worker            mat1, mat2 = (
767*da0073e9SAndroid Build Coastguard Worker                torch.randn(shape, dtype=dtype, device=device) for _ in range(2)
768*da0073e9SAndroid Build Coastguard Worker            )
769*da0073e9SAndroid Build Coastguard Worker            inputs = [
770*da0073e9SAndroid Build Coastguard Worker                torch.randn(shape, dtype=dtype, device=device),
771*da0073e9SAndroid Build Coastguard Worker                torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan),
772*da0073e9SAndroid Build Coastguard Worker            ]
773*da0073e9SAndroid Build Coastguard Worker            outs = [
774*da0073e9SAndroid Build Coastguard Worker                None,
775*da0073e9SAndroid Build Coastguard Worker                torch.randn(shape, dtype=dtype, device=device),
776*da0073e9SAndroid Build Coastguard Worker                torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan),
777*da0073e9SAndroid Build Coastguard Worker            ]
778*da0073e9SAndroid Build Coastguard Worker            options = itertools.product(inputs, outs)
779*da0073e9SAndroid Build Coastguard Worker            for input, out in options:
780*da0073e9SAndroid Build Coastguard Worker                y_ref = torch.bmm(mat1, mat2)
781*da0073e9SAndroid Build Coastguard Worker                y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out)
782*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y_ref, y)
783*da0073e9SAndroid Build Coastguard Worker
784*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
785*da0073e9SAndroid Build Coastguard Worker    def test_addmm_sizes(self, device, dtype):
786*da0073e9SAndroid Build Coastguard Worker        for m in [0, 1, 25]:
787*da0073e9SAndroid Build Coastguard Worker            for n in [0, 1, 10]:
788*da0073e9SAndroid Build Coastguard Worker                for k in [0, 1, 8]:
789*da0073e9SAndroid Build Coastguard Worker                    M = torch.randn(n, m, device=device).to(dtype)
790*da0073e9SAndroid Build Coastguard Worker                    m1 = torch.randn(n, k, device=device).to(dtype)
791*da0073e9SAndroid Build Coastguard Worker                    m2 = torch.randn(k, m, device=device).to(dtype)
792*da0073e9SAndroid Build Coastguard Worker                    self._test_addmm_addmv(torch.addmm, M, m1, m2)
793*da0073e9SAndroid Build Coastguard Worker
794*da0073e9SAndroid Build Coastguard Worker                    m1 = torch.randn(n, k + 1, device=device).to(dtype)
795*da0073e9SAndroid Build Coastguard Worker                    m2 = torch.randn(k, m, device=device).to(dtype)
796*da0073e9SAndroid Build Coastguard Worker                    self.assertRaisesRegex(
797*da0073e9SAndroid Build Coastguard Worker                        RuntimeError,
798*da0073e9SAndroid Build Coastguard Worker                        f"{n}x{k + 1}.*{k}x{m}",
799*da0073e9SAndroid Build Coastguard Worker                        lambda: torch.addmm(M, m1, m2),
800*da0073e9SAndroid Build Coastguard Worker                    )
801*da0073e9SAndroid Build Coastguard Worker                    self.assertRaisesRegex(
802*da0073e9SAndroid Build Coastguard Worker                        RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2)
803*da0073e9SAndroid Build Coastguard Worker                    )
804*da0073e9SAndroid Build Coastguard Worker
805*da0073e9SAndroid Build Coastguard Worker    @precisionOverride(
806*da0073e9SAndroid Build Coastguard Worker        {
807*da0073e9SAndroid Build Coastguard Worker            torch.double: 1e-8,
808*da0073e9SAndroid Build Coastguard Worker            torch.float: 1e-4,
809*da0073e9SAndroid Build Coastguard Worker            torch.bfloat16: 5e-2,
810*da0073e9SAndroid Build Coastguard Worker            torch.half: 5e-2,
811*da0073e9SAndroid Build Coastguard Worker            torch.cfloat: 1e-4,
812*da0073e9SAndroid Build Coastguard Worker            torch.cdouble: 1e-8,
813*da0073e9SAndroid Build Coastguard Worker        }
814*da0073e9SAndroid Build Coastguard Worker    )
815*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.bfloat16, torch.half)
816*da0073e9SAndroid Build Coastguard Worker    def test_addmm_gelu(self, device, dtype):
817*da0073e9SAndroid Build Coastguard Worker        self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype)
818*da0073e9SAndroid Build Coastguard Worker
819*da0073e9SAndroid Build Coastguard Worker    @precisionOverride(
820*da0073e9SAndroid Build Coastguard Worker        {
821*da0073e9SAndroid Build Coastguard Worker            torch.double: 1e-8,
822*da0073e9SAndroid Build Coastguard Worker            torch.float: 1e-4,
823*da0073e9SAndroid Build Coastguard Worker            torch.bfloat16: 5e-2,
824*da0073e9SAndroid Build Coastguard Worker            torch.half: 5e-2,
825*da0073e9SAndroid Build Coastguard Worker            torch.cfloat: 1e-4,
826*da0073e9SAndroid Build Coastguard Worker            torch.cdouble: 1e-8,
827*da0073e9SAndroid Build Coastguard Worker        }
828*da0073e9SAndroid Build Coastguard Worker    )
829*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float32, torch.bfloat16, torch.half)
830*da0073e9SAndroid Build Coastguard Worker    def test_addmm_relu(self, device, dtype):
831*da0073e9SAndroid Build Coastguard Worker        self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
832*da0073e9SAndroid Build Coastguard Worker
833*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.bfloat16, torch.half)
834*da0073e9SAndroid Build Coastguard Worker    def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype):
835*da0073e9SAndroid Build Coastguard Worker        # tests (o, s)*(s).  o is output size, s is summed size.
836*da0073e9SAndroid Build Coastguard Worker        o = 5
837*da0073e9SAndroid Build Coastguard Worker        s = 3
838*da0073e9SAndroid Build Coastguard Worker        a_data = torch.arange(1, o * s + 1, device=device, dtype=dtype).view(o, s)
839*da0073e9SAndroid Build Coastguard Worker        x_data = torch.arange(1, s + 1, 1, device=device, dtype=dtype)
840*da0073e9SAndroid Build Coastguard Worker        y_data = torch.ones(o, device=device, dtype=dtype)
841*da0073e9SAndroid Build Coastguard Worker        control = torch.tensor(
842*da0073e9SAndroid Build Coastguard Worker            [15.0, 33.0, 51.0, 69.0, 87.0], device=device, dtype=dtype
843*da0073e9SAndroid Build Coastguard Worker        )
844*da0073e9SAndroid Build Coastguard Worker
845*da0073e9SAndroid Build Coastguard Worker        def _test(row_major, incx, incy, lda_tail):
846*da0073e9SAndroid Build Coastguard Worker            if row_major:
847*da0073e9SAndroid Build Coastguard Worker                a_storage = torch.full(
848*da0073e9SAndroid Build Coastguard Worker                    (o, s + lda_tail), float("nan"), device=device, dtype=dtype
849*da0073e9SAndroid Build Coastguard Worker                )
850*da0073e9SAndroid Build Coastguard Worker            else:
851*da0073e9SAndroid Build Coastguard Worker                a_storage = torch.full(
852*da0073e9SAndroid Build Coastguard Worker                    (s, o + lda_tail), float("nan"), device=device, dtype=dtype
853*da0073e9SAndroid Build Coastguard Worker                ).permute(1, 0)
854*da0073e9SAndroid Build Coastguard Worker            a = a_storage[:o, :s].copy_(a_data)
855*da0073e9SAndroid Build Coastguard Worker
856*da0073e9SAndroid Build Coastguard Worker            x_storage = torch.full((s, incx), float("nan"), device=device, dtype=dtype)
857*da0073e9SAndroid Build Coastguard Worker            x = x_storage[:, 0].copy_(x_data)
858*da0073e9SAndroid Build Coastguard Worker
859*da0073e9SAndroid Build Coastguard Worker            y_storage = torch.full((o, incy), float("nan"), device=device, dtype=dtype)
860*da0073e9SAndroid Build Coastguard Worker            y = y_storage[:, 0].copy_(y_data)
861*da0073e9SAndroid Build Coastguard Worker
862*da0073e9SAndroid Build Coastguard Worker            self._test_addmm_addmv(torch.addmv, y, a, x)
863*da0073e9SAndroid Build Coastguard Worker
864*da0073e9SAndroid Build Coastguard Worker        for row_major, incx, incy, lda_tail in itertools.product(
865*da0073e9SAndroid Build Coastguard Worker            (False, True), (1, 2), (1, 2), (0, 1)
866*da0073e9SAndroid Build Coastguard Worker        ):
867*da0073e9SAndroid Build Coastguard Worker            _test(row_major, incx, incy, lda_tail)
868*da0073e9SAndroid Build Coastguard Worker
869*da0073e9SAndroid Build Coastguard Worker    @precisionOverride(
870*da0073e9SAndroid Build Coastguard Worker        {
871*da0073e9SAndroid Build Coastguard Worker            torch.double: 1e-8,
872*da0073e9SAndroid Build Coastguard Worker            torch.float: 1e-4,
873*da0073e9SAndroid Build Coastguard Worker            torch.bfloat16: 0.6,
874*da0073e9SAndroid Build Coastguard Worker            torch.half: 1e-1,
875*da0073e9SAndroid Build Coastguard Worker            torch.cfloat: 1e-4,
876*da0073e9SAndroid Build Coastguard Worker            torch.cdouble: 1e-8,
877*da0073e9SAndroid Build Coastguard Worker        }
878*da0073e9SAndroid Build Coastguard Worker    )
879*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.half, torch.float32)
880*da0073e9SAndroid Build Coastguard Worker    def test_corner_cases_of_cublasltmatmul(self, device, dtype):
881*da0073e9SAndroid Build Coastguard Worker        # common case
882*da0073e9SAndroid Build Coastguard Worker        M = torch.randn(128, device=device).to(dtype)
883*da0073e9SAndroid Build Coastguard Worker        m1 = torch.randn(2048, 2400, device=device).to(dtype)
884*da0073e9SAndroid Build Coastguard Worker        m2 = torch.randn(128, 2400, device=device).to(dtype)
885*da0073e9SAndroid Build Coastguard Worker        torch.nn.functional.linear(m1, m2, M)
886*da0073e9SAndroid Build Coastguard Worker        # Ntrans_B has ld >> rows
887*da0073e9SAndroid Build Coastguard Worker        m1 = torch.rand([128, 2400]).to(dtype).to(device).t()
888*da0073e9SAndroid Build Coastguard Worker        m2 = torch.rand([2048, 25272]).to(dtype).to(device).t()[21940:24340]
889*da0073e9SAndroid Build Coastguard Worker        M = torch.rand([128]).to(dtype).to(device)
890*da0073e9SAndroid Build Coastguard Worker        torch.addmm(M, m2.t(), m1)
891*da0073e9SAndroid Build Coastguard Worker        # trans_A has ld >> rows
892*da0073e9SAndroid Build Coastguard Worker        m1 = torch.rand([128, 25272]).to(dtype).to(device)[:, 21940:24340].t()
893*da0073e9SAndroid Build Coastguard Worker        m2 = torch.randn(2048, 2400, device=device).to(dtype)
894*da0073e9SAndroid Build Coastguard Worker        M = torch.rand([128]).to(dtype).to(device)
895*da0073e9SAndroid Build Coastguard Worker        torch.addmm(M, m2, m1)
896*da0073e9SAndroid Build Coastguard Worker        # large tensor dim > 65535
897*da0073e9SAndroid Build Coastguard Worker        M = torch.randn(16, device=device).to(dtype)
898*da0073e9SAndroid Build Coastguard Worker        m1 = torch.randn(32, 131071, device=device).to(dtype)
899*da0073e9SAndroid Build Coastguard Worker        m2 = torch.randn(16, 131071, device=device).to(dtype)
900*da0073e9SAndroid Build Coastguard Worker        torch.nn.functional.linear(m1, m2, M)
901*da0073e9SAndroid Build Coastguard Worker
902*da0073e9SAndroid Build Coastguard Worker    def test_blas_empty(self, device):
903*da0073e9SAndroid Build Coastguard Worker        def fn(torchfn, *args, test_out=False, **kwargs):
904*da0073e9SAndroid Build Coastguard Worker            def call_torch_fn(*args, **kwargs):
905*da0073e9SAndroid Build Coastguard Worker                return torchfn(
906*da0073e9SAndroid Build Coastguard Worker                    *tuple(
907*da0073e9SAndroid Build Coastguard Worker                        torch.randn(shape, device=device)
908*da0073e9SAndroid Build Coastguard Worker                        if isinstance(shape, tuple)
909*da0073e9SAndroid Build Coastguard Worker                        else shape
910*da0073e9SAndroid Build Coastguard Worker                        for shape in args
911*da0073e9SAndroid Build Coastguard Worker                    ),
912*da0073e9SAndroid Build Coastguard Worker                    **kwargs,
913*da0073e9SAndroid Build Coastguard Worker                )
914*da0073e9SAndroid Build Coastguard Worker
915*da0073e9SAndroid Build Coastguard Worker            result = call_torch_fn(*args, **kwargs)
916*da0073e9SAndroid Build Coastguard Worker            if not test_out:
917*da0073e9SAndroid Build Coastguard Worker                return result
918*da0073e9SAndroid Build Coastguard Worker            else:
919*da0073e9SAndroid Build Coastguard Worker                out = torch.full_like(result, math.nan)
920*da0073e9SAndroid Build Coastguard Worker                out1 = call_torch_fn(*args, **kwargs, out=out)
921*da0073e9SAndroid Build Coastguard Worker                return out
922*da0073e9SAndroid Build Coastguard Worker
923*da0073e9SAndroid Build Coastguard Worker        # mm, addmm
924*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape)
925*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape)
926*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape)
927*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape)
928*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
929*da0073e9SAndroid Build Coastguard Worker            torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6))
930*da0073e9SAndroid Build Coastguard Worker        )
931*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
932*da0073e9SAndroid Build Coastguard Worker            torch.zeros((5, 6), device=device),
933*da0073e9SAndroid Build Coastguard Worker            fn(torch.mm, (5, 0), (0, 6), test_out=True),
934*da0073e9SAndroid Build Coastguard Worker        )
935*da0073e9SAndroid Build Coastguard Worker
936*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape)
937*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 1), fn(torch.addmm, (1,), (0, 17), (17, 1)).shape)
938*da0073e9SAndroid Build Coastguard Worker        t = torch.randn((5, 6), device=device)
939*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6)))
940*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True))
941*da0073e9SAndroid Build Coastguard Worker
942*da0073e9SAndroid Build Coastguard Worker        # mv, addmv
943*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape)
944*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape)
945*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,)))
946*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
947*da0073e9SAndroid Build Coastguard Worker            torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), test_out=True)
948*da0073e9SAndroid Build Coastguard Worker        )
949*da0073e9SAndroid Build Coastguard Worker
950*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape)
951*da0073e9SAndroid Build Coastguard Worker        t = torch.randn((3,), device=device)
952*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,)))
953*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True))
954*da0073e9SAndroid Build Coastguard Worker
955*da0073e9SAndroid Build Coastguard Worker        # bmm, baddbmm
956*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape)
957*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape)
958*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape)
959*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
960*da0073e9SAndroid Build Coastguard Worker            torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6))
961*da0073e9SAndroid Build Coastguard Worker        )
962*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
963*da0073e9SAndroid Build Coastguard Worker            torch.zeros((3, 5, 6), device=device),
964*da0073e9SAndroid Build Coastguard Worker            fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True),
965*da0073e9SAndroid Build Coastguard Worker        )
966*da0073e9SAndroid Build Coastguard Worker
967*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
968*da0073e9SAndroid Build Coastguard Worker            (0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape
969*da0073e9SAndroid Build Coastguard Worker        )
970*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
971*da0073e9SAndroid Build Coastguard Worker            (3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape
972*da0073e9SAndroid Build Coastguard Worker        )
973*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
974*da0073e9SAndroid Build Coastguard Worker            (0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape
975*da0073e9SAndroid Build Coastguard Worker        )
976*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
977*da0073e9SAndroid Build Coastguard Worker            (3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape
978*da0073e9SAndroid Build Coastguard Worker        )
979*da0073e9SAndroid Build Coastguard Worker        c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5)
980*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
981*da0073e9SAndroid Build Coastguard Worker            -2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2)
982*da0073e9SAndroid Build Coastguard Worker        )  # Issue #33467
983*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
984*da0073e9SAndroid Build Coastguard Worker            -2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True)
985*da0073e9SAndroid Build Coastguard Worker        )  # Issue #33467
986*da0073e9SAndroid Build Coastguard Worker
987*da0073e9SAndroid Build Coastguard Worker        # addbmm
988*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape)
989*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape)
990*da0073e9SAndroid Build Coastguard Worker        t = torch.randn((5, 6), device=device)
991*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6)))
992*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True))
993*da0073e9SAndroid Build Coastguard Worker
994*da0073e9SAndroid Build Coastguard Worker        # matmul
995*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor(0.0, device=device), fn(torch.matmul, (0,), (0,)))
996*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
997*da0073e9SAndroid Build Coastguard Worker            torch.tensor(0.0, device=device),
998*da0073e9SAndroid Build Coastguard Worker            fn(torch.matmul, (0,), (0,), test_out=True),
999*da0073e9SAndroid Build Coastguard Worker        )
1000*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape)
1001*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape)
1002*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape)
1003*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1004*da0073e9SAndroid Build Coastguard Worker            torch.zeros((5, 3, 4), device=device),
1005*da0073e9SAndroid Build Coastguard Worker            fn(torch.matmul, (5, 3, 0), (5, 0, 4)),
1006*da0073e9SAndroid Build Coastguard Worker        )
1007*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1008*da0073e9SAndroid Build Coastguard Worker            torch.zeros((5, 3, 4), device=device),
1009*da0073e9SAndroid Build Coastguard Worker            fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True),
1010*da0073e9SAndroid Build Coastguard Worker        )
1011*da0073e9SAndroid Build Coastguard Worker
1012*da0073e9SAndroid Build Coastguard Worker        # dot
1013*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,)))
1014*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1015*da0073e9SAndroid Build Coastguard Worker            torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,), test_out=True)
1016*da0073e9SAndroid Build Coastguard Worker        )
1017*da0073e9SAndroid Build Coastguard Worker
1018*da0073e9SAndroid Build Coastguard Worker    def test_large_bmm_backward(self, device):
1019*da0073e9SAndroid Build Coastguard Worker        A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT
1020*da0073e9SAndroid Build Coastguard Worker        B = torch.randn([1, 1024, 65536], device=device, requires_grad=True)
1021*da0073e9SAndroid Build Coastguard Worker        G = torch.randn([1024, 2, 65536], device=device)
1022*da0073e9SAndroid Build Coastguard Worker
1023*da0073e9SAndroid Build Coastguard Worker        # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
1024*da0073e9SAndroid Build Coastguard Worker        (A @ B).backward(G)
1025*da0073e9SAndroid Build Coastguard Worker
1026*da0073e9SAndroid Build Coastguard Worker    def test_large_bmm_mm_backward(self, device):
1027*da0073e9SAndroid Build Coastguard Worker        A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT
1028*da0073e9SAndroid Build Coastguard Worker        B = torch.randn([1024, 65536], device=device, requires_grad=True)
1029*da0073e9SAndroid Build Coastguard Worker        G = torch.randn([1024, 2, 65536], device=device)
1030*da0073e9SAndroid Build Coastguard Worker
1031*da0073e9SAndroid Build Coastguard Worker        # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
1032*da0073e9SAndroid Build Coastguard Worker        (A @ B).backward(G)
1033*da0073e9SAndroid Build Coastguard Worker
1034*da0073e9SAndroid Build Coastguard Worker    def check_single_matmul(self, x, y):
1035*da0073e9SAndroid Build Coastguard Worker        def assertEqual(answer, expected):
1036*da0073e9SAndroid Build Coastguard Worker            if x.dtype.is_floating_point or x.dtype.is_complex:
1037*da0073e9SAndroid Build Coastguard Worker                k = max(x.shape[-1], 1)  # Scale the atol with the size of the matrix
1038*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
1039*da0073e9SAndroid Build Coastguard Worker                    answer,
1040*da0073e9SAndroid Build Coastguard Worker                    expected,
1041*da0073e9SAndroid Build Coastguard Worker                    msg=f"{x.shape} x {y.shape} = {answer.shape}",
1042*da0073e9SAndroid Build Coastguard Worker                    atol=k * 5e-5,
1043*da0073e9SAndroid Build Coastguard Worker                    rtol=1e-4,
1044*da0073e9SAndroid Build Coastguard Worker                )
1045*da0073e9SAndroid Build Coastguard Worker            else:
1046*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
1047*da0073e9SAndroid Build Coastguard Worker                    answer, expected, msg=f"{x.shape} x {y.shape} = {answer.shape}"
1048*da0073e9SAndroid Build Coastguard Worker                )
1049*da0073e9SAndroid Build Coastguard Worker
1050*da0073e9SAndroid Build Coastguard Worker        # test x @ y
1051*da0073e9SAndroid Build Coastguard Worker        expected = np.matmul(x.cpu(), y.cpu())
1052*da0073e9SAndroid Build Coastguard Worker        ans = torch.matmul(x, y)
1053*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(ans.is_contiguous())
1054*da0073e9SAndroid Build Coastguard Worker        assertEqual(ans, expected)
1055*da0073e9SAndroid Build Coastguard Worker
1056*da0073e9SAndroid Build Coastguard Worker        # test out
1057*da0073e9SAndroid Build Coastguard Worker        out = torch.empty_like(ans)
1058*da0073e9SAndroid Build Coastguard Worker        ans = torch.matmul(x, y, out=out)
1059*da0073e9SAndroid Build Coastguard Worker        self.assertIs(ans, out)
1060*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(ans.is_contiguous())
1061*da0073e9SAndroid Build Coastguard Worker        assertEqual(ans, expected)
1062*da0073e9SAndroid Build Coastguard Worker
1063*da0073e9SAndroid Build Coastguard Worker    def gen_sizes_matmul(self, x_dim, y_dim=4, matrix_size=4, batch_size=3):
1064*da0073e9SAndroid Build Coastguard Worker        """
1065*da0073e9SAndroid Build Coastguard Worker        Generates sequences of tuples (x, y) of with size(x) = x_dim and
1066*da0073e9SAndroid Build Coastguard Worker        size(y) <= y_dim that are compatible wrt. matmul
1067*da0073e9SAndroid Build Coastguard Worker        """
1068*da0073e9SAndroid Build Coastguard Worker        assert x_dim >= 1
1069*da0073e9SAndroid Build Coastguard Worker        assert y_dim >= 2
1070*da0073e9SAndroid Build Coastguard Worker        x = x_dim
1071*da0073e9SAndroid Build Coastguard Worker        for y in range(1, y_dim + 1):
1072*da0073e9SAndroid Build Coastguard Worker            for batch, mn in product(
1073*da0073e9SAndroid Build Coastguard Worker                product(range(batch_size), repeat=max(x - 2, y - 2, 0)),
1074*da0073e9SAndroid Build Coastguard Worker                product(range(matrix_size), repeat=min(y, 2)),
1075*da0073e9SAndroid Build Coastguard Worker            ):
1076*da0073e9SAndroid Build Coastguard Worker                if x == 1:
1077*da0073e9SAndroid Build Coastguard Worker                    size_x = mn[:1]
1078*da0073e9SAndroid Build Coastguard Worker                    size_y = batch + mn
1079*da0073e9SAndroid Build Coastguard Worker                    yield size_x, size_y
1080*da0073e9SAndroid Build Coastguard Worker                else:
1081*da0073e9SAndroid Build Coastguard Worker                    for k in range(matrix_size):
1082*da0073e9SAndroid Build Coastguard Worker                        size_x = (k,) + mn[:1]
1083*da0073e9SAndroid Build Coastguard Worker                        if x > 2:
1084*da0073e9SAndroid Build Coastguard Worker                            size_x = batch[-(x - 2) :] + size_x
1085*da0073e9SAndroid Build Coastguard Worker                        size_y = mn
1086*da0073e9SAndroid Build Coastguard Worker                        if y > 2:
1087*da0073e9SAndroid Build Coastguard Worker                            size_y = batch[-(y - 2) :] + size_y
1088*da0073e9SAndroid Build Coastguard Worker                        yield size_x, size_y
1089*da0073e9SAndroid Build Coastguard Worker
1090*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
1091*da0073e9SAndroid Build Coastguard Worker    def test_matmul_small_brute_force_1d_Nd(self, device, dtype):
1092*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, device=device, dtype=dtype)
1093*da0073e9SAndroid Build Coastguard Worker
1094*da0073e9SAndroid Build Coastguard Worker        for (size_x, size_y), nctg_x, nctg_y in product(
1095*da0073e9SAndroid Build Coastguard Worker            self.gen_sizes_matmul(1), (True, False), (True, False)
1096*da0073e9SAndroid Build Coastguard Worker        ):
1097*da0073e9SAndroid Build Coastguard Worker            x = make_arg(size_x, noncontiguous=nctg_x)
1098*da0073e9SAndroid Build Coastguard Worker            y = make_arg(size_y, noncontiguous=nctg_y)
1099*da0073e9SAndroid Build Coastguard Worker            self.check_single_matmul(x, y)
1100*da0073e9SAndroid Build Coastguard Worker
1101*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
1102*da0073e9SAndroid Build Coastguard Worker    def test_matmul_small_brute_force_2d_Nd(self, device, dtype):
1103*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, device=device, dtype=dtype)
1104*da0073e9SAndroid Build Coastguard Worker
1105*da0073e9SAndroid Build Coastguard Worker        for (size_x, size_y), nctg_x, nctg_y in product(
1106*da0073e9SAndroid Build Coastguard Worker            self.gen_sizes_matmul(2), (True, False), (True, False)
1107*da0073e9SAndroid Build Coastguard Worker        ):
1108*da0073e9SAndroid Build Coastguard Worker            x = make_arg(size_x, noncontiguous=nctg_x)
1109*da0073e9SAndroid Build Coastguard Worker            y = make_arg(size_y, noncontiguous=nctg_y)
1110*da0073e9SAndroid Build Coastguard Worker            self.check_single_matmul(x, y)
1111*da0073e9SAndroid Build Coastguard Worker
1112*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
1113*da0073e9SAndroid Build Coastguard Worker    def test_matmul_small_brute_force_3d_Nd(self, device, dtype):
1114*da0073e9SAndroid Build Coastguard Worker        make_arg = partial(make_tensor, device=device, dtype=dtype)
1115*da0073e9SAndroid Build Coastguard Worker
1116*da0073e9SAndroid Build Coastguard Worker        for (size_x, size_y), nctg_x, nctg_y in product(
1117*da0073e9SAndroid Build Coastguard Worker            self.gen_sizes_matmul(3), (True, False), (True, False)
1118*da0073e9SAndroid Build Coastguard Worker        ):
1119*da0073e9SAndroid Build Coastguard Worker            x = make_arg(size_x, noncontiguous=nctg_x)
1120*da0073e9SAndroid Build Coastguard Worker            y = make_arg(size_y, noncontiguous=nctg_y)
1121*da0073e9SAndroid Build Coastguard Worker            self.check_single_matmul(x, y)
1122*da0073e9SAndroid Build Coastguard Worker
1123*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
1124*da0073e9SAndroid Build Coastguard Worker    def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
1125*da0073e9SAndroid Build Coastguard Worker        a = torch.empty(
1126*da0073e9SAndroid Build Coastguard Worker            (256, 512), device=device, dtype=dtype, requires_grad=True
1127*da0073e9SAndroid Build Coastguard Worker        ).unsqueeze(0)
1128*da0073e9SAndroid Build Coastguard Worker        b = torch.empty(
1129*da0073e9SAndroid Build Coastguard Worker            (4, 128, 512), device=device, dtype=dtype, requires_grad=True
1130*da0073e9SAndroid Build Coastguard Worker        ).transpose(-1, -2)
1131*da0073e9SAndroid Build Coastguard Worker        c = torch.empty((256, 4, 128), device=device, dtype=dtype).movedim(1, 0)
1132*da0073e9SAndroid Build Coastguard Worker
1133*da0073e9SAndroid Build Coastguard Worker        torch.matmul(a.detach(), b.detach(), out=c)
1134*da0073e9SAndroid Build Coastguard Worker
1135*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1136*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1137*da0073e9SAndroid Build Coastguard Worker            "functions with out=... arguments don't support automatic differentiation",
1138*da0073e9SAndroid Build Coastguard Worker        ):
1139*da0073e9SAndroid Build Coastguard Worker            torch.matmul(a, b, out=c)
1140*da0073e9SAndroid Build Coastguard Worker
1141*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
1142*da0073e9SAndroid Build Coastguard Worker            torch.matmul(a, b, out=c)
1143*da0073e9SAndroid Build Coastguard Worker
1144*da0073e9SAndroid Build Coastguard Worker
1145*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True)
1146*da0073e9SAndroid Build Coastguard Worker
1147*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
1148*da0073e9SAndroid Build Coastguard Worker    run_tests()
1149