# Owner(s): ["module: intel"]

import itertools
import math
import random
from functools import partial
from itertools import product

import numpy as np

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    precisionOverride,
)
from torch.testing._internal.common_utils import iter_indices, run_tests, TestCase


class TestBasicGEMM(TestCase):
    def _test_addmm_addmv(
        self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None
    ):
        dtype = t.dtype
        numpy_dtype = dtype
        if dtype in {torch.bfloat16, torch.half}:
            numpy_dtype = torch.float
        if dtype.is_complex:
            alpha = 0.9 + 0.3j if alpha is None else alpha
            beta = 0.5 + 0.6j if beta is None else beta
        else:
            alpha = 1.2 if alpha is None else alpha
            beta = 0.8 if beta is None else beta
        if activation == "gelu":
            res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True)
        else:
            res1 = f(t, m, v, alpha=alpha, beta=beta)
        res2 = torch.full_like(res1, math.nan)
        if transpose_out:
            res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
        if activation == "gelu":
            f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True)
        else:
            f(t, m, v, alpha=alpha, beta=beta, out=res2)
        m.to(numpy_dtype).cpu().numpy()
        v.to(numpy_dtype).cpu().numpy()
        res3 = alpha * (
            m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy()
        )
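        # res3 is the NumPy reference for the op under test: conceptually
        #     res3 = alpha * (m @ v) + beta * t
        # with the beta term added below only when beta != 0 and the optional
        # activation applied as an epilogue. For half/bfloat16 inputs it is
        # evaluated in float32 (numpy_dtype) to limit rounding error.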
        if beta != 0:
            res3 += (beta * t).to(numpy_dtype).cpu().numpy()
        if activation == "relu":
            res3 = res3 * (res3 > 0)
        elif activation == "gelu":
            res3_t = torch.from_numpy(res3).to(dtype)
            approximate = "tanh" if t.is_cuda else "none"
            res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate)
            res3 = res3_t.to(numpy_dtype).cpu().numpy()
        else:
            assert activation is None, f"unsupported activation {activation}"
        res3 = torch.from_numpy(res3).to(dtype)
        self.assertEqual(res1, res2)
        self.assertEqual(res1, res3)

    def _test_addmm_impl(self, func, activation, device, dtype):
        M = torch.randn(10, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device)
        m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, M, m1, m2, activation=activation)

        # vector-shaped bias and beta=1 result in epilogue fusion in CUDA
        V = torch.randn(25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)

        # Test 0-strided
        M = (
            torch.randn(10, 1, device="cpu", dtype=torch.float32)
            .to(dtype)
            .expand(10, 25)
            .to(device)
        )
        m1 = (
            torch.randn(10, 1, device="cpu", dtype=torch.float32)
            .to(dtype)
            .expand(10, 50)
            .to(device)
        )
        m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, M, m1, m2, activation=activation)

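        # With beta == 0 the bias must be skipped entirely (BLAS-style semantics)
        # rather than multiplied by 0, so a NaN-filled M below may not leak NaNs
        # into the result.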
        # Test beta=0, M=nan
        M = (
            torch.full((10, 25), math.nan, device="cpu", dtype=torch.float32)
            .to(dtype)
            .to(device)
        )
        m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device)
        m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, M, m1, m2, beta=0, activation=activation)

        # Test transpose
        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):

            def maybe_transpose(cond, m):
                if not cond:
                    return m
                return m.t().clone(memory_format=torch.contiguous_format).t()

            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
            self._test_addmm_addmv(
                func, M, m1, m2, transpose_out=t4, activation=activation
            )

            if t1:
                # use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
                self._test_addmm_addmv(
                    func,
                    V,
                    m1,
                    m2,
                    beta=1,
                    transpose_out=t4,
                    activation=activation,
                )

    @precisionOverride(
        {
            torch.float: 1e-4,
            torch.half: 1e-1,
        }
    )
    @dtypes(torch.float32, torch.half)
    def test_addmm(self, device, dtype):
        self._test_addmm_impl(torch.addmm, None, device, dtype)

    @precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4})
    @dtypes(torch.bfloat16, torch.half, torch.float)
    def test_addmv(self, device, dtype):
        # have to use torch.randn(...).to(bfloat16) instead of
        # torch.randn(..., dtype=bfloat16). randn does not support
        # bfloat16 yet.
        # "*0.2" to reduce errors for low precision
        ts = [
            0.2 * torch.randn(50, device=device).to(dtype),
            0.2 * torch.randn(1, device=device).to(dtype).expand(50),
        ]
        vs = [
            0.2 * torch.randn(100, device=device).to(dtype),
            0.2
            * torch.ones(1, device=device)
            .to(dtype)
            .expand(100),  # to reduce errors for low precision
        ]
        ms = [
            # 0d
            0.2
            * torch.ones((), device=device)
            .to(dtype)
            .expand(50, 100),  # to reduce errors for low precision
            # 1d
            0.2 * torch.randn((1, 100), device=device).to(dtype).expand(50, 100),
            # this initialization reduces errors for low precision for broadcasted matrices
            # by making sure that intermediate and result values are exactly representable
            # in low precision type
            0.2
            * torch.randint(3, (50, 1), dtype=torch.float, device=device)
            .to(dtype)
            .expand(50, 100),
            # 2d
            0.2 * torch.randn((50, 100), device=device).to(dtype),
            0.2 * torch.randn((100, 50), device=device).to(dtype).t(),
        ]
        for m, v, t in itertools.product(ms, vs, ts):
            self._test_addmm_addmv(torch.addmv, t, m, v)
        # Test beta=0, t=nan
        t = torch.full((50,), math.nan, device=device).to(dtype)
        for m, v in itertools.product(ms, vs):
            self._test_addmm_addmv(torch.addmv, t, m, v, beta=0)

    @dtypes(
        torch.half,
        torch.float32,
    )
    def test_mm(self, device, dtype):
        def _test_mm(n, m, p, dtype, genf):
            # helper function
            def matrixmultiply(mat1, mat2):
                n = mat1.size(0)
                m = mat1.size(1)
                p = mat2.size(1)
                dtype_ = torch.float if dtype == torch.half else dtype
                if dtype == torch.half:
                    mat1 = mat1.float()
                    mat2 = mat2.float()
                res = torch.zeros(n, p, dtype=dtype_, device=device)
                for i, j in iter_indices(res):
                    res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m))
                return res.half() if dtype == torch.half else res

            # contiguous case
            mat1 = genf(n, m)
            mat2 = genf(m, p)
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # non contiguous case 1
            mat1 = genf(n, m)
            mat2 = genf(p, m).t()
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # non contiguous case 2
            mat1 = genf(m, n).t()
            mat2 = genf(m, p)
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # non contiguous case 3
            mat1 = genf(m, n).t()
            mat2 = genf(p, m).t()
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # test with zero stride
            mat1 = genf(n, m)
            mat2 = genf(m, 1).expand(m, p)
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # explicitly exercise the _out variant in torch.mm().
            # contiguous case
            mat1 = genf(n, m)
            mat2 = genf(m, p)
            res = genf(n, p)
            torch.mm(mat1, mat2, out=res)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # explicitly exercise the _out variant in torch.mm().
            # non contiguous case 3
            mat1 = genf(m, n).t()
            mat2 = genf(p, m).t()
            res = genf(n, p)
            torch.mm(mat1, mat2, out=res)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

        def genf_int(x, y):
            return torch.randint(0, 100, (x, y), dtype=dtype, device=device)

        def genf_bfloat(x, y):
            return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) * 0.1

        def genf_float(x, y):
            return torch.randn(x, y, dtype=dtype, device=device)

        def genf_Half(x, y):
            return torch.randn(x, y, dtype=dtype, device=device)

        for n, m, p in [(20, 10, 15), (15, 20, 10), (25, 18, 10)]:
            if (dtype == torch.int32) or (dtype == torch.int64):
                genf = genf_int
            elif dtype == torch.bfloat16:
                genf = genf_bfloat
            elif dtype == torch.half:
                genf = genf_Half
            else:
                genf = genf_float

            _test_mm(n, m, p, dtype, genf)

    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_bmm(self, device, dtype):
        batch_sizes = [1, 10]
        M, N, O = 23, 15, 12
        numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32

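        # invert_perm(p) returns the permutation that undoes p, so that
        # x.permute(p).contiguous().permute(invert_perm(p)) keeps x's logical
        # values and shape while giving it a transposed, non-contiguous layout.
        # generate_inputs relies on this to cover transposed, broadcast, and
        # zero-sized operands.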
        def invert_perm(p):
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2])

        def generate_inputs(num_batches):
            # transposed tensors
            for perm1, perm2 in itertools.product(
                itertools.permutations((0, 1, 2)), repeat=2
            ):
                b1 = make_tensor(
                    (num_batches, M, N), dtype=dtype, device=device, low=-0.1, high=0.1
                )
                b2 = make_tensor(
                    (num_batches, N, O), dtype=dtype, device=device, low=-0.1, high=0.1
                )
                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
                yield b1, b2
            # broadcasting tensors
            for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6):
                shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1)
                shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1)
                b1 = make_tensor(
                    shape1, dtype=dtype, device=device, low=-0.1, high=0.1
                ).expand(num_batches, M, N)
                b2 = make_tensor(
                    shape2, dtype=dtype, device=device, low=-0.1, high=0.1
                ).expand(num_batches, N, O)
                yield b1, b2
            # zero-sized tensors
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
                b1 = torch.randn(shape1, dtype=dtype, device=device)
                b2 = torch.randn(shape2, dtype=dtype, device=device)
                yield b1, b2

        for num_batches in batch_sizes:
            for (b1, b2), perm3 in itertools.product(
                generate_inputs(num_batches), itertools.permutations((0, 1, 2))
            ):
                res1 = torch.bmm(b1, b2)
                res2 = (
                    torch.full(
                        (num_batches, M, O), math.nan, dtype=dtype, device=device
                    )
                    .permute(perm3)
                    .contiguous()
                    .permute(invert_perm(perm3))
                )
                torch.bmm(b1, b2, out=res2)
                expect = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype)
                self.assertEqual(expect, res1)
                self.assertEqual(expect, res2)

                if self.device_type == "cuda":
                    # check that mixed arguments are rejected
                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu()))
                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2))
                    self.assertRaises(
                        RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu())
                    )

    def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor):
        getattr(out_tensor, func + "_")(b1, b2)
        self.assertEqual(out_tensor, ref)
        res3 = out_tensor.clone()

        with self.assertWarnsOnceRegex(
            UserWarning, f"This overload of {func}_ is deprecated"
        ):
            getattr(out_tensor, func + "_")(1, b1, b2)
        self.assertEqual(out_tensor, ref * 2)
        getattr(res3, func + "_")(b1, b2, beta=1)
        self.assertEqual(out_tensor, res3)

        with self.assertWarnsOnceRegex(
            UserWarning, f"This overload of {func}_ is deprecated"
        ):
            getattr(out_tensor, func + "_")(1.0, 0.5, b1, b2)
        self.assertEqual(out_tensor, ref * 2.5)
        getattr(res3, func + "_")(b1, b2, beta=1.0, alpha=0.5)
        self.assertEqual(out_tensor, res3)

        with self.assertWarnsOnceRegex(
            UserWarning, f"This overload of {func} is deprecated"
        ):
            self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2))

        res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=0.5)
        self.assertEqual(res4, ref * 3)

        nan = torch.full_like(out_tensor, math.nan)
        res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1)
        self.assertEqual(res5, ref)

        if b1.is_complex():
            res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1j, alpha=0.5j)
            self.assertEqual(res6, out_tensor * 0.1j + 0.5j * ref)
        else:
            res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1, alpha=0.5)
            self.assertEqual(res6, out_tensor * 0.1 + 0.5 * ref)

        res7 = torch.full_like(out_tensor, math.nan)
        getattr(torch, func)(nan, b1, b2, beta=0, out=res7)
        self.assertEqual(res7, ref)

    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_addbmm(self, device, dtype):
        num_batches = 2
        M, N, O = 16, 17, 18

        is_supported = True

        if not is_supported:
            b1 = make_tensor(
                (num_batches, M, N), dtype=dtype, device=device, low=-1, high=1
            )
            b2 = make_tensor(
                (num_batches, N, O), dtype=dtype, device=device, low=-1, high=1
            )
            t = make_tensor((M, O), dtype=dtype, device=device, low=-1, high=1)
            self.assertRaisesRegex(
                RuntimeError,
                "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
                lambda: torch.addbmm(t, b1, b2),
            )
            return

        def invert_perm(p):
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2])

        def generate_tensor():
            numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32
            # transposed tensors
            for perm1, perm2 in itertools.product(
                itertools.permutations((0, 1, 2)), repeat=2
            ):
                for perm3 in itertools.permutations((0, 1)):
                    b1 = (
                        make_tensor(
                            (num_batches, M, N),
                            dtype=dtype,
                            device=device,
                            low=-1,
                            high=1,
                        )
                        * 0.1
                    )
                    b2 = (
                        make_tensor(
                            (num_batches, N, O),
                            dtype=dtype,
                            device=device,
                            low=-1,
                            high=1,
                        )
                        * 0.1
                    )
                    b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
                    b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
                    ref = (
                        torch.from_numpy(
                            b1.to(numpy_dtype).cpu().numpy()
                            @ b2.to(numpy_dtype).cpu().numpy()
                        )
                        .to(device=device, dtype=dtype)
                        .sum(0)
                    )
                    out_tensor = (
                        torch.zeros_like(ref).permute(perm3).contiguous().permute(perm3)
                    )
                    yield b1, b2, ref, out_tensor
            # broadcasting tensors
            for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
                shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
                shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
                b1 = (
                    make_tensor(
                        shape1, dtype=dtype, device=device, low=-1, high=1
                    ).expand(num_batches, M, N)
                    * 0.1
                )
                b2 = (
                    make_tensor(
                        shape2, dtype=dtype, device=device, low=-1, high=1
                    ).expand(num_batches, N, O)
                    * 0.1
                )
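                # addbmm reduces over the batch dimension, so the reference below
                # is the batched NumPy matmul followed by .sum(0), i.e. roughly
                # beta * input + alpha * sum_b(b1[b] @ b2[b]).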
                ref = (
                    torch.from_numpy(
                        b1.to(numpy_dtype).cpu().numpy()
                        @ b2.to(numpy_dtype).cpu().numpy()
                    )
                    .to(device=device, dtype=dtype)
                    .sum(0)
                )
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor
            # zero-sized tensors
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
                b1 = (
                    make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1)
                    * 0.1
                )
                b2 = (
                    make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1)
                    * 0.1
                )
                ref = (
                    torch.from_numpy(
                        b1.to(numpy_dtype).cpu().numpy()
                        @ b2.to(numpy_dtype).cpu().numpy()
                    )
                    .to(device=device, dtype=dtype)
                    .sum(0)
                )
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor

        for b1, b2, ref, out_tensor in generate_tensor():
            self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor)

    @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5})
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_baddbmm(self, device, dtype):
        num_batches = 10
        M, N, O = 12, 8, 50

        def invert_perm(p):
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2])

        def generate_tensor():
            numpy_dtype = (
                dtype if dtype not in [torch.bfloat16, torch.half] else torch.float32
            )
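            # Unlike addbmm, baddbmm keeps the batch dimension: the reference is
            # the per-batch product b1 @ b2 with no reduction, i.e. roughly
            # beta * input + alpha * (b1 @ b2) for every batch.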
            # transposed tensors
            for perm1, perm2, perm3 in itertools.product(
                itertools.permutations((0, 1, 2)), repeat=3
            ):
                b1 = make_tensor(
                    (num_batches, M, N), dtype=dtype, device=device, low=-1, high=1
                )
                b2 = make_tensor(
                    (num_batches, N, O), dtype=dtype, device=device, low=-1, high=1
                )
                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
                ref = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype)
                out_tensor = torch.zeros_like(ref)
                out_tensor = (
                    out_tensor.permute(perm3).contiguous().permute(invert_perm(perm3))
                )
                yield b1, b2, ref, out_tensor
            # broadcasting tensors
            for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
                shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
                shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
                b1 = make_tensor(
                    shape1, dtype=dtype, device=device, low=-1, high=1
                ).expand(num_batches, M, N)
                b2 = make_tensor(
                    shape2, dtype=dtype, device=device, low=-1, high=1
                ).expand(num_batches, N, O)
                ref = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype)
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor
            # zero-sized tensors
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-2, high=2)
                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-2, high=2)
                ref = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype)
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor

        for b1, b2, ref, out_tensor in generate_tensor():
            self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor)

    def test_tensordot(self, device):
        a = torch.arange(60.0, device=device).reshape(3, 4, 5)
        b = torch.arange(24.0, device=device).reshape(4, 3, 2)
        c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu()
        cn = torch.from_numpy(
            np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=([1, 0], [0, 1]))
        )
        self.assertEqual(c, cn)

        cout = torch.zeros((5, 2), device=device)
        torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout).cpu()
        self.assertEqual(c, cout)

        a = torch.randn(2, 3, 4, 5, device=device)
        b = torch.randn(4, 5, 6, 7, device=device)
        c = torch.tensordot(a, b, dims=2).cpu()
        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=2))

        with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"):
            torch.tensordot(a, b, dims=-1)

        self.assertEqual(c, cn)
        c = torch.tensordot(a, b).cpu()
        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy()))
        self.assertEqual(c, cn)

        a = torch.tensordot(torch.tensor(0.0), torch.tensor(0.0), 0)
        an = torch.from_numpy(
            np.tensordot(
                np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0
            )
        )
        self.assertEqual(a, an)

    @dtypes(torch.float)
    @precisionOverride({torch.float32: 1e-4})
    def test_1_sized_with_0_strided(self, device, dtype):
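        # A dimension of size 1 may legally have stride 0; the as_strided views
        # below construct exactly that, and bmm has to treat them like any other
        # batch of matrices and match the NumPy matmul reference.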
        a = make_tensor((8, 1, 64), dtype=dtype, device=device)
        a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1])
        b = make_tensor((8, 64, 512), dtype=dtype, device=device)
        b_strided = torch.as_strided(b, size=[8, 64, 512], stride=[64, 1, 512])
        res = torch.bmm(a_strided, b_strided)
        expect = torch.from_numpy(a_strided.cpu().numpy() @ b_strided.cpu().numpy()).to(
            device=device, dtype=dtype
        )
        self.assertEqual(expect, res)

    def _select_broadcastable_dims(self, dims_full=None):
        # select full dimensionality
        if dims_full is None:
            dims_full = []
            ndims = random.randint(1, 4)
            dims_full = [random.randint(1, 8) for _ in range(ndims)]
        else:
            ndims = len(dims_full)

        # select actual dimensions for ops:
        # larger: full ndims, individual sizes may be reduced
        # smaller: possibly reduced ndims, sizes may be reduced
        smaller_ndims = random.randint(1, ndims)
        dims_small = []
        dims_large = []
        for i in range(ndims - 1, -1, -1):
            j = random.randint(1, 3)
            if j == 1:  # no reduced singleton dimension
                ds = dims_full[i]
                dl = dims_full[i]
            elif j == 2:  # larger may have reduced singleton dimension
                ds = dims_full[i]
                dl = 1 if len(dims_small) < smaller_ndims else dims_full[i]
            elif j == 3:  # smaller may have reduced singleton dimension
                ds = 1
                dl = dims_full[i]
            dims_large = [dl] + dims_large
            if len(dims_small) < smaller_ndims:
                dims_small = [ds] + dims_small
        return (dims_small, dims_large, dims_full)

    def test_broadcast_fused_matmul(self, device):
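        # Each fused op broadcasts its first ("self") argument. For every fn the
        # test draws a broadcastable small bias t0_small plus its fully expanded
        # counterpart t0_full and checks fn(t0_small, t1, t2) == fn(t0_full, t1, t2).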
        fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"]

        for fn in fns:
            batch_dim = random.randint(1, 8)
            n_dim = random.randint(1, 8)
            m_dim = random.randint(1, 8)
            p_dim = random.randint(1, 8)

            def dims_full_for_fn():
                if fn == "baddbmm":
                    return (
                        [batch_dim, n_dim, p_dim],
                        [batch_dim, n_dim, m_dim],
                        [batch_dim, m_dim, p_dim],
                    )
                elif fn == "addbmm":
                    return (
                        [n_dim, p_dim],
                        [batch_dim, n_dim, m_dim],
                        [batch_dim, m_dim, p_dim],
                    )
                elif fn == "addmm":
                    return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim])
                elif fn == "addmv":
                    return ([n_dim], [n_dim, m_dim], [m_dim])
                elif fn == "addr":
                    return ([n_dim, m_dim], [n_dim], [m_dim])
                else:
                    raise AssertionError("unknown function")

            (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn()
            (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full)

            t0_small = torch.randn(*t0_dims_small, device=device).float()
            t1 = torch.randn(*t1_dims, device=device).float()
            t2 = torch.randn(*t2_dims, device=device).float()

            t0_full = t0_small.expand(*t0_dims_full).to(device)

            fntorch = getattr(torch, fn)
            r0 = fntorch(t0_small, t1, t2)
            r1 = fntorch(t0_full, t1, t2)
            self.assertEqual(r0, r1)

    @dtypes(torch.float32)
    def test_strided_mm_bmm(self, device, dtype):
        # Tests strided view case with stride smaller than corresponding dimension size
        x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype, device=device)
        new_shape = [2, 2, 2]
        new_stride = [3, 1, 1]
        sx = torch.as_strided(x, size=new_shape, stride=new_stride)
        torch_fn = lambda x: torch.bmm(x, x)  # noqa: E731
        np_fn = lambda x: np.matmul(x, x)  # noqa: E731
        self.compare_with_numpy(torch_fn, np_fn, sx)

        torch_fn = lambda x: torch.mm(x, x)  # noqa: E731
        self.compare_with_numpy(torch_fn, np_fn, sx[0])

    def test_mm_empty_inputs_mixed_dtype_errors(self, device):
        a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device)
        b = torch.randn(10, 20, dtype=torch.float32, device=device)
        with self.assertRaisesRegex(
            RuntimeError, "expected .* and .* to have the same dtype, but got:"
        ):
            torch.mm(a, b)

    def test_matmul_45724(self, device):
        # https://github.com/pytorch/pytorch/issues/45724
        a = torch.rand(65537, 22, 64, device=device, dtype=torch.half)
        b = torch.rand(65537, 64, 22, device=device, dtype=torch.half)
        c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device=device)
        cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).half()
        torch.matmul(a, b, out=c)
        self.assertEqual(c, cpu_result)

    @dtypes(
        torch.int16,
        torch.int32,
        torch.int64,
        torch.float16,
        torch.float32,
        torch.float64,
    )
    def test_baddbmm_input_dtypes_compatibility(self, device, dtype):
        batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
        batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
        input_tensor = torch.rand((1, 2, 2), device=device).to(dtype)
        if dtype != torch.float32:
            with self.assertRaisesRegex(RuntimeError, "Input dtypes must be the same"):
                y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0)
        else:
            out = torch.randn((1, 2, 2), dtype=dtype, device=device).fill_(torch.nan)
            y_ref = torch.bmm(batch1, batch2)
            y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out)
            self.assertEqual(out, y_ref)

    @dtypes(torch.float)
    def test_baddbmm_nan_input_with_zero_beta(self, device, dtype):
        for shape in [[3, 2, 2], [2, 20, 20]]:
            mat1, mat2 = (
                torch.randn(shape, dtype=dtype, device=device) for _ in range(2)
            )
            inputs = [
                torch.randn(shape, dtype=dtype, device=device),
                torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan),
            ]
            outs = [
                None,
                torch.randn(shape, dtype=dtype, device=device),
                torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan),
            ]
            options = itertools.product(inputs, outs)
            for input, out in options:
                y_ref = torch.bmm(mat1, mat2)
                y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out)
                self.assertEqual(y_ref, y)

    @dtypes(torch.float)
    def test_addmm_sizes(self, device, dtype):
        for m in [0, 1, 25]:
            for n in [0, 1, 10]:
                for k in [0, 1, 8]:
                    M = torch.randn(n, m, device=device).to(dtype)
                    m1 = torch.randn(n, k, device=device).to(dtype)
                    m2 = torch.randn(k, m, device=device).to(dtype)
                    self._test_addmm_addmv(torch.addmm, M, m1, m2)

                    m1 = torch.randn(n, k + 1, device=device).to(dtype)
                    m2 = torch.randn(k, m, device=device).to(dtype)
                    self.assertRaisesRegex(
                        RuntimeError,
                        f"{n}x{k + 1}.*{k}x{m}",
                        lambda: torch.addmm(M, m1, m2),
                    )
                    self.assertRaisesRegex(
                        RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2)
                    )

    @precisionOverride(
        {
            torch.double: 1e-8,
            torch.float: 1e-4,
Build Coastguard Worker torch.bfloat16: 5e-2, 810*da0073e9SAndroid Build Coastguard Worker torch.half: 5e-2, 811*da0073e9SAndroid Build Coastguard Worker torch.cfloat: 1e-4, 812*da0073e9SAndroid Build Coastguard Worker torch.cdouble: 1e-8, 813*da0073e9SAndroid Build Coastguard Worker } 814*da0073e9SAndroid Build Coastguard Worker ) 815*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.bfloat16, torch.half) 816*da0073e9SAndroid Build Coastguard Worker def test_addmm_gelu(self, device, dtype): 817*da0073e9SAndroid Build Coastguard Worker self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype) 818*da0073e9SAndroid Build Coastguard Worker 819*da0073e9SAndroid Build Coastguard Worker @precisionOverride( 820*da0073e9SAndroid Build Coastguard Worker { 821*da0073e9SAndroid Build Coastguard Worker torch.double: 1e-8, 822*da0073e9SAndroid Build Coastguard Worker torch.float: 1e-4, 823*da0073e9SAndroid Build Coastguard Worker torch.bfloat16: 5e-2, 824*da0073e9SAndroid Build Coastguard Worker torch.half: 5e-2, 825*da0073e9SAndroid Build Coastguard Worker torch.cfloat: 1e-4, 826*da0073e9SAndroid Build Coastguard Worker torch.cdouble: 1e-8, 827*da0073e9SAndroid Build Coastguard Worker } 828*da0073e9SAndroid Build Coastguard Worker ) 829*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.bfloat16, torch.half) 830*da0073e9SAndroid Build Coastguard Worker def test_addmm_relu(self, device, dtype): 831*da0073e9SAndroid Build Coastguard Worker self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype) 832*da0073e9SAndroid Build Coastguard Worker 833*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.bfloat16, torch.half) 834*da0073e9SAndroid Build Coastguard Worker def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype): 835*da0073e9SAndroid Build Coastguard Worker # tests (o, s)*(s). o is output size, s is summed size. 
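        # torch.addmv computes beta * y + alpha * (a @ x).  The helper below lays
        # the operands out with different strides (incx, incy) and a padded
        # leading dimension (lda_tail) so both row-major and column-major BLAS
        # paths are exercised.  Worked example for the `control` values (my own
        # arithmetic, matching the a_data / x_data / y_data construction below):
        #   a = [[1, 2, 3], [4, 5, 6], ..., [13, 14, 15]], x = [1, 2, 3], y = ones(5)
        #   a @ x = [14, 32, 50, 68, 86]  ->  y + a @ x = [15, 33, 51, 69, 87]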
        o = 5
        s = 3
        a_data = torch.arange(1, o * s + 1, device=device, dtype=dtype).view(o, s)
        x_data = torch.arange(1, s + 1, 1, device=device, dtype=dtype)
        y_data = torch.ones(o, device=device, dtype=dtype)
        control = torch.tensor(
            [15.0, 33.0, 51.0, 69.0, 87.0], device=device, dtype=dtype
        )

        def _test(row_major, incx, incy, lda_tail):
            if row_major:
                a_storage = torch.full(
                    (o, s + lda_tail), float("nan"), device=device, dtype=dtype
                )
            else:
                a_storage = torch.full(
                    (s, o + lda_tail), float("nan"), device=device, dtype=dtype
                ).permute(1, 0)
            a = a_storage[:o, :s].copy_(a_data)

            x_storage = torch.full((s, incx), float("nan"), device=device, dtype=dtype)
            x = x_storage[:, 0].copy_(x_data)

            y_storage = torch.full((o, incy), float("nan"), device=device, dtype=dtype)
            y = y_storage[:, 0].copy_(y_data)

            self._test_addmm_addmv(torch.addmv, y, a, x)

        for row_major, incx, incy, lda_tail in itertools.product(
            (False, True), (1, 2), (1, 2), (0, 1)
        ):
            _test(row_major, incx, incy, lda_tail)

    @precisionOverride(
        {
            torch.double: 1e-8,
            torch.float: 1e-4,
            torch.bfloat16: 0.6,
            torch.half: 1e-1,
            torch.cfloat: 1e-4,
            torch.cdouble: 1e-8,
        }
    )
    @dtypes(torch.bfloat16, torch.half, torch.float32)
    def test_corner_cases_of_cublasltmatmul(self, device, dtype):
        # common case
        M = torch.randn(128, device=device).to(dtype)
        m1 = torch.randn(2048, 2400, device=device).to(dtype)
        m2 = torch.randn(128, 2400, device=device).to(dtype)
        torch.nn.functional.linear(m1, m2, M)
        # Ntrans_B has ld >> rows
        m1 = torch.rand([128, 2400]).to(dtype).to(device).t()
        m2 = torch.rand([2048, 25272]).to(dtype).to(device).t()[21940:24340]
        M = torch.rand([128]).to(dtype).to(device)
        torch.addmm(M, m2.t(), m1)
        # trans_A has ld >> rows
        m1 = torch.rand([128, 25272]).to(dtype).to(device)[:, 21940:24340].t()
        m2 = torch.randn(2048, 2400, device=device).to(dtype)
        M = torch.rand([128]).to(dtype).to(device)
        torch.addmm(M, m2, m1)
        # large tensor dim > 65535
        M = torch.randn(16, device=device).to(dtype)
        m1 = torch.randn(32, 131071, device=device).to(dtype)
        m2 = torch.randn(16, 131071, device=device).to(dtype)
        torch.nn.functional.linear(m1, m2, M)

    def test_blas_empty(self, device):
        def fn(torchfn, *args, test_out=False, **kwargs):
            def call_torch_fn(*args, **kwargs):
                return torchfn(
                    *tuple(
                        torch.randn(shape, device=device)
                        if isinstance(shape, tuple)
                        else shape
                        for shape in args
                    ),
                    **kwargs,
                )

            result = call_torch_fn(*args, **kwargs)
            if not test_out:
                return result
            else:
                out = torch.full_like(result, math.nan)
                out1 = call_torch_fn(*args, **kwargs, out=out)
                return out

        # mm, addmm
        self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape)
        self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape)
        self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape)
        self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape)
        self.assertEqual(
            torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6))
        )
        self.assertEqual(
            torch.zeros((5, 6), device=device),
            fn(torch.mm, (5, 0), (0, 6), test_out=True),
        )

        self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape)
        self.assertEqual((0, 1), fn(torch.addmm, (1,), (0, 17), (17, 1)).shape)
        t = torch.randn((5, 6), device=device)
        self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6)))
        self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True))

        # mv, addmv
        self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape)
        self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape)
        self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,)))
        self.assertEqual(
            torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), test_out=True)
        )

        self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape)
        t = torch.randn((3,), device=device)
        self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,)))
        self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True))

        # bmm, baddbmm
        self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape)
        self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape)
        self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape)
        self.assertEqual(
            torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6))
        )
        self.assertEqual(
            torch.zeros((3, 5, 6), device=device),
            fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True),
        )

        self.assertEqual(
            (0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape
        )
        self.assertEqual(
            (3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape
        )
        self.assertEqual(
            (0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape
        )
        self.assertEqual(
            (3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape
        )
        c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5)
        self.assertEqual(
            -2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2)
        )  # Issue #33467
        self.assertEqual(
            -2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True)
        )  # Issue #33467

        # addbmm
        self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape)
        self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape)
        t = torch.randn((5, 6), device=device)
        self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6)))
        self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True))

        # matmul
        self.assertEqual(torch.tensor(0.0, device=device), fn(torch.matmul, (0,), (0,)))
        self.assertEqual(
            torch.tensor(0.0, device=device),
            fn(torch.matmul, (0,), (0,), test_out=True),
        )
        self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape)
        self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape)
        self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape)
        self.assertEqual(
            torch.zeros((5, 3, 4), device=device),
            fn(torch.matmul, (5, 3, 0), (5, 0, 4)),
        )
        self.assertEqual(
            torch.zeros((5, 3, 4), device=device),
            fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True),
        )

        # dot
        self.assertEqual(torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,)))
        self.assertEqual(
            torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,), test_out=True)
        )

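    # A minimal sketch of the empty-dimension semantics test_blas_empty relies on
    # (illustrative only, not an additional assertion): a matmul whose reduction
    # dimension is zero is an empty sum, so the product term is all zeros and the
    # add* variants degenerate to beta * input, e.g.
    #   lhs = torch.zeros(3, 5, 0)
    #   rhs = torch.zeros(3, 0, 6)
    #   torch.bmm(lhs, rhs)                  # -> zeros of shape (3, 5, 6)
    #   torch.baddbmm(c, lhs, rhs, beta=-2)  # -> -2 * c for a (3, 5, 6) input c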
    def test_large_bmm_backward(self, device):
        A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT
        B = torch.randn([1, 1024, 65536], device=device, requires_grad=True)
        G = torch.randn([1024, 2, 65536], device=device)

        # Should not materialize an intermediate tensor of size
        # [1024, 1024, 65536] (256 GB of memory) and OOM
        (A @ B).backward(G)

    def test_large_bmm_mm_backward(self, device):
        A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT
        B = torch.randn([1024, 65536], device=device, requires_grad=True)
        G = torch.randn([1024, 2, 65536], device=device)

        # Should not materialize an intermediate tensor of size
        # [1024, 1024, 65536] (256 GB of memory) and OOM
        (A @ B).backward(G)

    def check_single_matmul(self, x, y):
        def assertEqual(answer, expected):
            if x.dtype.is_floating_point or x.dtype.is_complex:
                k = max(x.shape[-1], 1)  # Scale the atol with the size of the matrix
                self.assertEqual(
                    answer,
                    expected,
                    msg=f"{x.shape} x {y.shape} = {answer.shape}",
                    atol=k * 5e-5,
                    rtol=1e-4,
                )
            else:
                self.assertEqual(
                    answer, expected, msg=f"{x.shape} x {y.shape} = {answer.shape}"
                )

        # test x @ y
        expected = np.matmul(x.cpu(), y.cpu())
        ans = torch.matmul(x, y)
        self.assertTrue(ans.is_contiguous())
        assertEqual(ans, expected)

        # test out
        out = torch.empty_like(ans)
        ans = torch.matmul(x, y, out=out)
        self.assertIs(ans, out)
        self.assertTrue(ans.is_contiguous())
        assertEqual(ans, expected)

    def gen_sizes_matmul(self, x_dim, y_dim=4, matrix_size=4, batch_size=3):
        """
        Generates pairs of sizes (size_x, size_y) such that len(size_x) == x_dim
        and len(size_y) <= y_dim, and the two shapes are compatible w.r.t. matmul.
        """
        assert x_dim >= 1
        assert y_dim >= 2
        x = x_dim
        for y in range(1, y_dim + 1):
            for batch, mn in product(
                product(range(batch_size), repeat=max(x - 2, y - 2, 0)),
                product(range(matrix_size), repeat=min(y, 2)),
            ):
                if x == 1:
                    size_x = mn[:1]
                    size_y = batch + mn
                    yield size_x, size_y
                else:
                    for k in range(matrix_size):
                        size_x = (k,) + mn[:1]
                        if x > 2:
                            size_x = batch[-(x - 2) :] + size_x
                        size_y = mn
                        if y > 2:
                            size_y = batch[-(y - 2) :] + size_y
                        yield size_x, size_y

    @dtypes(torch.float)
    def test_matmul_small_brute_force_1d_Nd(self, device, dtype):
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for (size_x, size_y), nctg_x, nctg_y in product(
            self.gen_sizes_matmul(1), (True, False), (True, False)
        ):
            x = make_arg(size_x, noncontiguous=nctg_x)
            y = make_arg(size_y, noncontiguous=nctg_y)
            self.check_single_matmul(x, y)

    @dtypes(torch.float)
    def test_matmul_small_brute_force_2d_Nd(self, device, dtype):
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for (size_x, size_y), nctg_x, nctg_y in product(
            self.gen_sizes_matmul(2), (True, False), (True, False)
        ):
            x = make_arg(size_x, noncontiguous=nctg_x)
            y = make_arg(size_y, noncontiguous=nctg_y)
            self.check_single_matmul(x, y)

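    # The 3-D variant below follows the same brute-force pattern.  For reference,
    # gen_sizes_matmul(3) pairs 3-D left operands of shape (batch, k, m) with
    # right operands ranging from 1-D vectors up to 4-D batched matrices, and the
    # generated extents include 0, so degenerate empty shapes are covered too
    # (a description inferred from the generator above, not hard-coded anywhere).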
    @dtypes(torch.float)
    def test_matmul_small_brute_force_3d_Nd(self, device, dtype):
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for (size_x, size_y), nctg_x, nctg_y in product(
            self.gen_sizes_matmul(3), (True, False), (True, False)
        ):
            x = make_arg(size_x, noncontiguous=nctg_x)
            y = make_arg(size_y, noncontiguous=nctg_y)
            self.check_single_matmul(x, y)

    @dtypes(torch.float)
    def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
        a = torch.empty(
            (256, 512), device=device, dtype=dtype, requires_grad=True
        ).unsqueeze(0)
        b = torch.empty(
            (4, 128, 512), device=device, dtype=dtype, requires_grad=True
        ).transpose(-1, -2)
        c = torch.empty((256, 4, 128), device=device, dtype=dtype).movedim(1, 0)

        torch.matmul(a.detach(), b.detach(), out=c)

        with self.assertRaisesRegex(
            RuntimeError,
            "functions with out=... arguments don't support automatic differentiation",
        ):
            torch.matmul(a, b, out=c)

        with torch.no_grad():
            torch.matmul(a, b, out=c)


instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True)

if __name__ == "__main__":
    run_tests()
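# Note: instantiate_device_type_tests() generates a device-specific subclass of
# TestBasicGEMM; for XPU the generated class is typically named TestBasicGEMMXPU,
# with per-device/per-dtype test variants such as test_addmm_sizes_xpu_float32,
# so individual tests are selected via those generated names when running this
# file directly.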