# Owner(s): ["module: nn"]
import math
import copy

import torch
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    onlyCUDA,
    skipMeta,
)
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_WITH_ROCM

class TestMHADeviceType(TestCase):
    @torch.no_grad()
    def _test_transform_bias_rescale_qkv_impl(
        self, device, dtype, use_nt, use_padding=False
    ):
        tests = [
            (64, 4, 16, 8),
            # dim_per_head = 12 does not divide evenly by CPU vectorization length of 8
            (24, 2, 4, 2),
            # Make sure CUDA can handle small input sizes
            (2, 2, 2, 2),
            # dim_per_head = 6 does not divide evenly by CUDA vectorization length of 4,
            # causes alignment issues
            (24, 4, 4, 2),
            (48, 4, 16, 8),
        ]
        for (embed_dim, num_heads, bs, sl) in tests:
            with self.subTest(embed_dim=embed_dim, num_heads=num_heads, bs=bs, sl=sl):
                torch.manual_seed(9343)
                dense_x = x = (
                    torch.randn(bs, sl, 3 * embed_dim, device=device, dtype=dtype) * 10
                )
                if use_padding:
                    x[0][-1] = torch.full(x[0][-1].shape, float("-Inf"))
                if use_nt:
                    xs = list(torch.unbind(x))
                    if use_padding:
                        xs[0] = xs[0][:-1]
                    x = torch.nested.nested_tensor(xs, device=device, dtype=dtype)
                qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)

                # We have to use inference_mode here because q/k/v are
                # all views of the same Tensor, which autograd doesn't
                # like. This is fine because this function is only
                # exposed to Python for purposes of writing this test.
                with torch.inference_mode():
                    (q, k, v) = torch._transform_bias_rescale_qkv(
                        x, qkv.bias, num_heads=num_heads
                    )

                    def simple_transform_bias_rescale_qkv(qkv, bias):
                        (q, k, v) = torch.split(qkv, embed_dim, dim=-1)
                        (q_bias, k_bias, v_bias) = torch.split(bias, embed_dim, dim=-1)

                        def embiggen(x):
                            if not use_nt:
                                return x
                            b, t, d = x.size()
                            t = t + (8 - t % 8) % 8
                            newsize = (b, t, d)
                            new_x = torch.zeros(newsize, device=device, dtype=dtype)
                            new_x[:x.size()[0], :x.size()[1], :x.size()[2]] = x
                            return new_x
                        return tuple(
                            embiggen(x).reshape(
                                (bs, -1, num_heads, embed_dim // num_heads)
                            ).transpose(2, 1)
                            for x in (
                                (q + q_bias) / math.sqrt(embed_dim // num_heads),
                                (k + k_bias),
                                (v + v_bias),
                            )
                        )

                    correct_q, correct_k, correct_v = simple_transform_bias_rescale_qkv(
                        dense_x, qkv.bias
                    )
                    if use_nt and use_padding:
                        for t in (correct_q, correct_k, correct_v):
                            t[t == float("-Inf")] = 0

                    self.assertEqual(q.size(), correct_q.size())
                    torch.testing.assert_close(q, correct_q)
                    torch.testing.assert_close(k, correct_k)
                    torch.testing.assert_close(v, correct_v)

    @dtypesIfCUDA(torch.float)
    @dtypes(torch.float)
    @skipMeta
    def test_transform_bias_rescale_qkv(self, device, dtype):
        for use_padding in (False, True):
            with self.subTest(use_padding=use_padding):
                self._test_transform_bias_rescale_qkv_impl(
                    device, dtype, use_nt=False, use_padding=use_padding
                )

    @dtypesIfCUDA(torch.float)
    @dtypes(torch.float)
    @skipMeta
    @onlyCUDA
    def test_transform_bias_rescale_qkv_nested(self, device, dtype):
        for use_padding in (False, True):
            with self.subTest(use_padding=use_padding):
                self._test_transform_bias_rescale_qkv_impl(
                    device, dtype, use_nt=True, use_padding=use_padding
                )

    def _test_multihead_attention_impl(
        self, device, dtype, mode, use_nt, need_weights, average_attn_weights, use_padding=False, pad_all=False
    ):
        embed_dim = 64
        num_heads = 4
        bs = 16
        sl = 8

        q = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
        if use_padding:
            if pad_all:
                for q_i in q:
                    q_i[-1] = torch.zeros_like(q[0][-1], device=device, dtype=torch.float32)
                mask = torch.zeros(q.shape[:-1], device=device, dtype=torch.bool)
                for mask_i in mask:
                    mask_i[-1] = True
            else:
                q[0][-1] = torch.zeros_like(q[0][-1], device=device, dtype=torch.float32)
                mask = torch.zeros(q.shape[:-1], device=device, dtype=torch.bool)
                mask[0][-1] = True
        if mode == "self":
            k = q
            v = q
        elif mode == "encdec":
            k = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
            v = k
        elif mode == "generic":
            k = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
            v = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
        else:
            self.fail(f"invalid mode `{mode}`!")

        qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=torch.float32)
        native_qkv = copy.deepcopy(qkv).to(dtype=dtype)

        proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=torch.float32)
        native_proj = copy.deepcopy(proj).to(dtype=dtype)

        pt = torch.nn.MultiheadAttention(
            embed_dim, num_heads, batch_first=True, device=device, dtype=torch.float32
        )

        pt.in_proj_weight = qkv.weight
        pt.in_proj_bias = qkv.bias
        pt.out_proj.weight = proj.weight
        pt.out_proj.bias = proj.bias

        class NativeMHA(torch.nn.Module):
            def __init__(self, embed_dim, num_heads, qkv, proj):
                super().__init__()
                self.qkv = qkv
                self.proj = proj
                self.embed_dim = embed_dim
                self.num_heads = num_heads

            def forward(self, q, k, v, key_padding_mask):
                return torch._native_multi_head_attention(
                    q,
                    k,
                    v,
                    self.embed_dim,
                    self.num_heads,
                    self.qkv.weight,
                    self.qkv.bias,
                    self.proj.weight,
                    self.proj.bias,
                    key_padding_mask,
                    need_weights=need_weights,
                    average_attn_weights=average_attn_weights,
                    mask_type=1,  # mask_type = 1 => src_key_padding_mask, mask_type = 0 => src_mask
                )

        npt = NativeMHA(
            embed_dim=embed_dim, num_heads=num_heads, qkv=native_qkv, proj=native_proj
        ).to(dtype)

        if device == "cuda":
            pt = pt.cuda()
            npt = npt.cuda()

        ypt, weight_pt = pt(
            q,
            k,
            v,
            need_weights=need_weights,
            average_attn_weights=average_attn_weights,
            key_padding_mask=mask if use_padding else None,
        )
        if use_nt:
            qs = list(torch.unbind(q))
            if use_padding:
                if pad_all:
                    qs = [x[:-1] for x in qs]
                else:
                    qs[0] = qs[0][:-1]
            q = torch.nested.nested_tensor(qs, device=device, dtype=dtype)
            if mode == "self":
                k = v = q
            elif mode == "encdec":
                k = torch.nested.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
                v = k
            else:
                k = torch.nested.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
                v = torch.nested.nested_tensor(torch.unbind(v), device=device, dtype=dtype)

        native_q = q.to(dtype=dtype)
        native_k = k.to(dtype=dtype)
        native_v = v.to(dtype=dtype)

        ynpt, weight_npt = npt(
            native_q, native_k, native_v, key_padding_mask=mask if use_padding and not use_nt else None
        )
        if use_nt:
            ynpt = ynpt.to_padded_tensor(0)
            if pad_all:
                ynpt_final = torch.zeros_like(ypt)
                ynpt_final[:, :ynpt.shape[1], :] = ynpt
                ynpt = ynpt_final

        def do_pad_all(tensors):
            for t in tensors:
                for t_i in t:
                    t_i[-1] = torch.zeros_like(t_i[-1], device=device, dtype=dtype)

        # PyTorch implementation returns non-zero junk in the padding
        # locations; overwrite it so that the comparison works out.
        if use_padding:
            ypt[0][-1] = torch.zeros_like(ypt[0][-1], device=device, dtype=dtype)
            ynpt[0][-1] = torch.zeros_like(ynpt[0][-1], device=device, dtype=dtype)
            if pad_all:
                do_pad_all((ypt, ynpt))
            # Zero the last row of each TxT weight matrix
            if need_weights:
                if average_attn_weights:
                    weight_pt[0][-1] = torch.zeros_like(weight_pt[0][-1], device=device, dtype=dtype)
                    weight_npt[0][-1] = torch.zeros_like(weight_npt[0][-1], device=device, dtype=dtype)
                    if pad_all:
                        do_pad_all((weight_pt, weight_npt))
                else:
                    for nh in range(num_heads):
                        weight_pt[0][nh][-1] = torch.zeros_like(weight_pt[0][nh][-1], device=device, dtype=dtype)
                        weight_npt[0][nh][-1] = torch.zeros_like(weight_npt[0][nh][-1], device=device, dtype=dtype)

        if dtype == torch.half:
            torch.testing.assert_close(ypt, ynpt.to(torch.float32), atol=1e-3, rtol=1e-3)
        else:
            # High rtol seems necessary for
            # test_native_multihead_attention_cpu_float32 on Windows,
            # otherwise 2e-4 would likely be fine.
            torch.testing.assert_close(ypt, ynpt, atol=2e-5, rtol=2e-3)

        if need_weights:
            torch.testing.assert_close(weight_pt, weight_npt.to(torch.float32), atol=5e-4, rtol=5e-4)
        else:
            self.assertEqual(weight_pt, weight_npt)

    @dtypesIfCUDA(torch.float, torch.half)
    @dtypes(torch.float)
    @skipMeta
    @parametrize("use_nt", [False, True])
    @parametrize("use_padding, pad_all", [(False, False), (True, False), (True, True)])
    @parametrize("need_weights", [False])
    @parametrize("average_attn_weights", [False, True])
    @parametrize("fused", [False, True])
    @torch.no_grad()
    def test_native_multihead_self_attention(self, device, dtype, use_nt,
                                             need_weights, average_attn_weights, use_padding, pad_all, fused):
        if TEST_WITH_ROCM:
            if use_nt:
                self.skipTest("ROCM does not support nested tensors for Flash Attention for now.")
            if use_padding and not pad_all and fused:
                self.skipTest("Large numerical errors on ROCM to investigate.")
        for need_weights in (False, not pad_all):
            with self.subTest(use_padding=use_padding, pad_all=pad_all,
                              use_nt=use_nt, need_weights=need_weights,
                              average_attn_weights=average_attn_weights):
                with torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_mem_efficient=False
                ) if not fused else torch.backends.cuda.sdp_kernel(
                    enable_flash=True, enable_mem_efficient=True
                ):
                    self._test_multihead_attention_impl(
                        device,
                        dtype,
                        "self",
                        use_nt=use_nt,
                        use_padding=use_padding,
                        pad_all=pad_all,
                        need_weights=need_weights,
                        average_attn_weights=average_attn_weights,
                    )

    @dtypesIfCUDA(torch.float, torch.half)
    @dtypes(torch.float)
    @skipMeta
    @torch.no_grad()
    def test_native_multihead_encoder_decoder_attention(self, device, dtype):
        self._test_multihead_attention_impl(
            device,
            dtype,
            "encdec",
            use_nt=False,
            need_weights=False,
            average_attn_weights=False,
        )

    @dtypesIfCUDA(torch.float, torch.half)
    @dtypes(torch.float)
    @skipMeta
    @torch.no_grad()
    def test_native_multihead_attention(self, device, dtype):
        self._test_multihead_attention_impl(
            device,
            dtype,
            "generic",
            use_nt=False,
            need_weights=False,
            average_attn_weights=False,
        )


instantiate_device_type_tests(TestMHADeviceType, globals())

if __name__ == "__main__":
    run_tests()