# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest

import torch
import torch.nn.functional as F

from .sdpa_with_kv_cache import custom_ops_lib  # noqa
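

# Reference implementation used as ground truth in the tests below. It mirrors what
# the tests expect from the custom op: q/k/v arrive as [batch, seq_len, n_heads,
# head_dim], k/v are written into the caches in place at start_pos, KV heads are
# expanded via repeat_interleave when n_heads_q > n_heads_kv (MQA/GQA), and plain
# F.scaled_dot_product_attention is applied to the populated cache prefix.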
def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq_len):
    q = q.transpose(1, 2)
    k_cache[:, start_pos : start_pos + seq_len, :, :] = k
    v_cache[:, start_pos : start_pos + seq_len, :, :] = v
    sliced_k_cache = k_cache[:, : start_pos + seq_len, :, :]
    sliced_v_cache = v_cache[:, : start_pos + seq_len, :, :]
    sliced_k_cache = sliced_k_cache.transpose(1, 2)
    sliced_v_cache = sliced_v_cache.transpose(1, 2)

    num_heads_q = q.size(1)
    num_heads_kv = sliced_k_cache.size(1)
    if num_heads_q != num_heads_kv:
        assert (
            num_heads_q % num_heads_kv == 0
        ), f"{num_heads_q} not divisible by {num_heads_kv}"
    n_reps = num_heads_q // num_heads_kv
    if n_reps > 1:
        sliced_k_cache = sliced_k_cache.repeat_interleave(n_reps, dim=1)
        sliced_v_cache = sliced_v_cache.repeat_interleave(n_reps, dim=1)
    out = F.scaled_dot_product_attention(
        q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask
    )
    out = out.transpose(1, 2)
    return out


class SDPATest(unittest.TestCase):

    def setUp(self):
        torch.manual_seed(42)
        self.k_cache = torch.zeros((1, 10, 8, 4))
        self.v_cache = torch.zeros((1, 10, 8, 4))
        self.mask = torch.full(
            (10, 10),
            float("-inf"),
        )
        self.mask = torch.triu(self.mask, diagonal=1)
        self.use_mask_with_custom_op = False
        self.is_causal = False

    def test_sdpa_with_cache_no_mqa_1(self):
        q = torch.rand((1, 1, 8, 4))
        k = torch.rand((1, 1, 8, 4))
        v = torch.rand((1, 1, 8, 4))
        start_pos = 0
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        if self.use_mask_with_custom_op:
            attn_mask = attn_mask.contiguous()
            op_output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                self.k_cache,
                self.v_cache,
                start_pos,
                seq_len,
                attn_mask,
                0,
                False,
            )
        else:
            op_output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                self.k_cache,
                self.v_cache,
                start_pos,
                seq_len,
                None,
                0,
                self.is_causal,
            )
        self.assertTrue(torch.allclose(ref_output, op_output))

    def test_sdpa_with_cache_no_mqa_2(self):
        q = torch.rand((1, 1, 8, 4))
        k = torch.rand((1, 1, 8, 4))
        v = torch.rand((1, 1, 8, 4))
        start_pos = 1
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]

        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        if self.use_mask_with_custom_op:
            attn_mask = attn_mask.contiguous()
            op_output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                self.k_cache,
                self.v_cache,
                start_pos,
                seq_len,
                attn_mask,
                0,
                False,
            )
        else:
            op_output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                self.k_cache,
                self.v_cache,
                start_pos,
                seq_len,
                None,
                0,
                self.is_causal,
            )

        self.assertTrue(torch.allclose(ref_output, op_output))

    def test_sdpa_with_cache_no_mqa_3(self):
        q = torch.rand((1, 1, 8, 4))
        k = torch.rand((1, 1, 8, 4))
        v = torch.rand((1, 1, 8, 4))
        start_pos = 2
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]

        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        if self.use_mask_with_custom_op:
            attn_mask = attn_mask.contiguous()
            op_output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                self.k_cache,
                self.v_cache,
                start_pos,
                seq_len,
                attn_mask,
                0,
                False,
            )
        else:
            op_output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                self.k_cache,
                self.v_cache,
                start_pos,
                seq_len,
                None,
                0,
                self.is_causal,
            )
        self.assertTrue(torch.allclose(ref_output, op_output))

    def test_sdpa_with_cache_no_mqa_4(self):
        q = torch.rand((1, 1, 8, 4))
        k = torch.rand((1, 1, 8, 4))
        v = torch.rand((1, 1, 8, 4))
        start_pos = 3
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]

        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        if self.use_mask_with_custom_op:
            attn_mask = attn_mask.contiguous()
            op_output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                self.k_cache,
                self.v_cache,
                start_pos,
                seq_len,
                attn_mask,
                0,
                False,
            )
        else:
            op_output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                self.k_cache,
                self.v_cache,
                start_pos,
                seq_len,
                None,
                0,
                self.is_causal,
            )
        self.assertTrue(torch.allclose(ref_output, op_output))


class SDPAWithAttentionMaskTest(SDPATest):

    def setUp(self):
        SDPATest.setUp(self)
        self.mask = torch.full(
            (10, 10),
            100.642,
        )
        self.use_mask_with_custom_op = True


class SDPAWithCausalTest(SDPATest):

    def setUp(self):
        SDPATest.setUp(self)
        self.is_causal = True
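

# Dynamic-shape tests: unlike the single-token decode tests above, these feed
# multi-token q/k/v chunks (seq_len > 1) at varying start_pos with is_causal=True,
# reusing the same 10-slot caches; the skipped case intentionally overflows the cache.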
class SDPAWithDynamicShape(unittest.TestCase):

    def setUp(self):
        torch.manual_seed(42)
        self.k_cache = torch.zeros((1, 10, 8, 4))
        self.v_cache = torch.zeros((1, 10, 8, 4))
        self.mask = torch.full(
            (10, 10),
            float("-inf"),
        )
        self.mask = torch.triu(self.mask, diagonal=1)
        self.use_mask_with_custom_op = False
        self.is_causal = False

    def test_sdpa_with_cache_dynamic_shape_0(self):
        q = torch.rand((1, 4, 8, 4))
        k = torch.rand((1, 4, 8, 4))
        v = torch.rand((1, 4, 8, 4))
        seq_len = q.size(1)
        start_pos = 0
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]

        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )

        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )
        self.assertTrue(torch.allclose(ref_output, op_output))

    def test_sdpa_with_cache_dynamic_shape_2(self):
        q = torch.rand((1, 3, 8, 4))
        k = torch.rand((1, 3, 8, 4))
        v = torch.rand((1, 3, 8, 4))
        seq_len = q.size(1)
        start_pos = 2
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]

        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )

        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )
        self.assertTrue(torch.allclose(ref_output, op_output))

    @unittest.skip("This test should fail, but the runtime does not bubble the error up.")
    def test_sdpa_with_cache_dynamic_shape_4(self):
        q = torch.rand((1, 11, 8, 4))
        k = torch.rand((1, 11, 8, 4))
        v = torch.rand((1, 11, 8, 4))
        seq_len = q.size(1)
        start_pos = 4

        torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )


class SDPATestWithMQA(unittest.TestCase):

    def setup_caches(self):
        self.k_cache = torch.zeros((1, 5, self.n_heads_kv, 4))
        self.v_cache = torch.zeros((1, 5, self.n_heads_kv, 4))

    def setUp(self):
        torch.manual_seed(42)
        self.n_heads_kv = 4
        self.n_heads_q = 8
        self.setup_caches()
        self.mask = torch.full(
            (5, 5),
            float("-inf"),
        )
        self.mask = torch.triu(self.mask, diagonal=1)

    def test_sdpa_with_cache_mqa_1(self):
        q = torch.rand((1, 1, self.n_heads_q, 4))
        k = torch.rand((1, 1, self.n_heads_kv, 4))
        v = torch.rand((1, 1, self.n_heads_kv, 4))
        start_pos = 0
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, 0, 1, None, 0, False
        )
        self.assertTrue(torch.allclose(ref_output, op_output))

    def test_sdpa_with_cache_mqa_2(self):
        q = torch.rand((1, 1, self.n_heads_q, 4))
        k = torch.rand((1, 1, self.n_heads_kv, 4))
        v = torch.rand((1, 1, self.n_heads_kv, 4))
        start_pos = 1
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
        )
        self.assertTrue(torch.allclose(ref_output, op_output))

    def test_sdpa_with_cache_mqa_3(self):
        self.n_heads_q = 14
        self.n_heads_kv = 7
        self.setup_caches()
        q = torch.rand((1, 1, self.n_heads_q, 4))
        k = torch.rand((1, 1, self.n_heads_kv, 4))
        v = torch.rand((1, 1, self.n_heads_kv, 4))
        start_pos = 1
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
        )
        self.assertTrue(torch.allclose(ref_output, op_output))
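

# Shared harness for the long-context and speculative-decode tests below.
# _test_sdpa_common runs two rounds against the same caches: a first call at
# start_pos=0 with `seq_len` tokens, then a follow-up call at start_pos=seq_len
# with `next_iter_seq_len` tokens, comparing the custom op to the reference each
# time. Inputs can optionally be rescaled to the [-15, 15] range via
# `scale_tensors` to reproduce a numerical error seen on x86.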
class SDPATestCommon(unittest.TestCase):

    def setup_caches(self):
        self.k_cache = torch.zeros(
            (self.n_batch, self.max_seq_len, self.n_heads_kv, self.head_dim)
        )
        self.v_cache = torch.zeros(
            (self.n_batch, self.max_seq_len, self.n_heads_kv, self.head_dim)
        )
        self.mask = torch.full(
            (self.max_seq_len, self.max_seq_len),
            float("-inf"),
        )
        self.mask = torch.triu(self.mask, diagonal=1)

    def setUp(self):
        torch.manual_seed(42)
        self.n_batch = 5
        self.n_heads_kv = 32
        self.n_heads_q = 32
        self.head_dim = 128
        self.max_seq_len = 2048
        self.setup_caches()

    def _scale_tensor(self, tensor, min_value, max_value, scale=True):
        normalized_tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())

        scaled_tensor = normalized_tensor * (max_value - min_value) + min_value

        return scaled_tensor if scale else tensor

    def _test_sdpa_common(
        self,
        n_heads_kv,
        n_heads_q,
        head_dim,
        max_seq_len,
        seq_len,
        next_iter_seq_len=1,
        scale_tensors=False,
    ):
        # Range arbitrarily chosen to reproduce a numerical error on x86 in some of the long context tests
        tensor_scale_max = 15
        tensor_scale_min = -15
        self.n_heads_kv = n_heads_kv
        self.n_heads_q = n_heads_q
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.setup_caches()
        q = self._scale_tensor(
            torch.rand((self.n_batch, seq_len, self.n_heads_q, self.head_dim)),
            tensor_scale_max,
            tensor_scale_min,
            scale_tensors,
        )
        k = self._scale_tensor(
            torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
            tensor_scale_max,
            tensor_scale_min,
            scale_tensors,
        )
        v = self._scale_tensor(
            torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
            tensor_scale_max,
            tensor_scale_min,
            scale_tensors,
        )

        start_pos = 0
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )
        self.assertTrue(torch.allclose(ref_output, op_output, atol=1e-6))

        q = self._scale_tensor(
            torch.rand(
                (self.n_batch, next_iter_seq_len, self.n_heads_q, self.head_dim)
            ),
            tensor_scale_max,
            tensor_scale_min,
            scale_tensors,
        )
        k = self._scale_tensor(
            torch.rand(
                (self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
            ),
            tensor_scale_max,
            tensor_scale_min,
            scale_tensors,
        )
        v = self._scale_tensor(
            torch.rand(
                (self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
            ),
            tensor_scale_max,
            tensor_scale_min,
            scale_tensors,
        )

        start_pos = seq_len
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )
        self.assertTrue(torch.allclose(ref_output, op_output, atol=1e-6))
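

# Long-context coverage: a large prefill (seq_len 130 or 634) followed by a
# one-token step, in both MHA (n_heads_kv == n_heads_q) and GQA
# (n_heads_kv < n_heads_q) configurations, plus one small smoke-test configuration.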
class SDPATestForLargeSeqLength(SDPATestCommon):

    def test_sdpa_with_cache_seq_len_130(self):
        n_heads_kv = 32
        n_heads_q = 32
        head_dim = 128
        max_seq_len = 2048
        seq_len = 130
        self._test_sdpa_common(
            n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, scale_tensors=True
        )

    def test_sdpa_with_cache_seq_len_small(self):
        n_heads_kv = 4
        n_heads_q = 4
        head_dim = 4
        max_seq_len = 8
        seq_len = 4
        self._test_sdpa_common(n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len)

    def test_sdpa_with_cache_seq_len_llava_example(self):
        n_heads_kv = 32
        n_heads_q = 32
        head_dim = 128
        max_seq_len = 2048
        seq_len = 634
        self._test_sdpa_common(n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len)

    def test_sdpa_with_cache_seq_len_130_gqa(self):
        n_heads_kv = 8
        n_heads_q = 32
        head_dim = 128
        max_seq_len = 2048
        seq_len = 130
        self._test_sdpa_common(
            n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, scale_tensors=True
        )

    def test_sdpa_with_cache_seq_len_llava_example_gqa(self):
        n_heads_kv = 16
        n_heads_q = 32
        head_dim = 128
        max_seq_len = 2048
        seq_len = 634
        self._test_sdpa_common(n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len)


# Speculative-decode style coverage: the follow-up call appends several tokens at
# once (next_iter_seq_len > 1) instead of a single token.
class SDPATestForSpeculativeDecode(SDPATestCommon):

    def test_sdpa_with_cache_seq_len_130(self):
        n_heads_kv = 32
        n_heads_q = 32
        head_dim = 128
        max_seq_len = 2048
        seq_len = 130
        next_iter_seq_len = 17
        self._test_sdpa_common(
            n_heads_kv,
            n_heads_q,
            head_dim,
            max_seq_len,
            seq_len,
            next_iter_seq_len,
            True,
        )

    def test_sdpa_with_cache_seq_len_llava_example(self):
        n_heads_kv = 32
        n_heads_q = 32
        head_dim = 128
        max_seq_len = 2048
        seq_len = 634
        next_iter_seq_len = 64
        self._test_sdpa_common(
            n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len
        )

    def test_sdpa_with_cache_seq_len_130_gqa(self):
        n_heads_kv = 8
        n_heads_q = 32
        head_dim = 128
        max_seq_len = 2048
        seq_len = 130
        next_iter_seq_len = 33
        self._test_sdpa_common(
            n_heads_kv,
            n_heads_q,
            head_dim,
            max_seq_len,
            seq_len,
            next_iter_seq_len,
            True,
        )

    def test_sdpa_with_cache_seq_len_llava_example_gqa(self):
        n_heads_kv = 16
        n_heads_q = 32
        head_dim = 128
        max_seq_len = 2048
        seq_len = 634
        next_iter_seq_len = 117
        self._test_sdpa_common(
            n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len
        )
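

# Standard unittest entry point; assumes direct invocation of this file is wanted.
# Test runners that use discovery (e.g. pytest) do not need it.
if __name__ == "__main__":
    unittest.main()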