1from __future__ import annotations 2 3import os 4import tempfile 5import unittest 6 7import yaml 8 9from torchgen.executorch.model import ETKernelIndex, ETKernelKey 10from torchgen.gen import LineLoader 11from torchgen.gen_executorch import ( 12 ComputeCodegenUnboxedKernels, 13 gen_functions_declarations, 14 parse_yaml_files, 15 translate_native_yaml, 16) 17from torchgen.model import ( 18 BackendIndex, 19 BackendMetadata, 20 DispatchKey, 21 Location, 22 NativeFunction, 23 OperatorName, 24) 25from torchgen.selective_build.selector import SelectiveBuilder 26 27 28TEST_YAML = """ 29- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 30 device_check: NoCheck # TensorIterator 31 structured: True 32 structured_inherits: TensorIteratorBase 33 ufunc_inner_loop: 34 Generic: add (AllAndComplex, BFloat16, Half, ComplexHalf) 35 ScalarOnly: add (Bool) 36 dispatch: 37 SparseCPU: add_out_sparse_cpu 38 SparseCUDA: add_out_sparse_cuda 39 SparseCsrCPU: add_out_sparse_csr_cpu 40 SparseCsrCUDA: add_out_sparse_csr_cuda 41 MkldnnCPU: mkldnn_add_out 42 MPS: add_out_mps 43 44- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor 45 device_check: NoCheck # TensorIterator 46 structured_delegate: add.out 47 variants: function, method 48 dispatch: 49 SparseCPU, SparseCUDA: add_sparse 50 SparseCsrCPU, SparseCsrCUDA: add_sparse_csr 51 MkldnnCPU: mkldnn_add 52 ZeroTensor: add_zerotensor 53 NestedTensorCPU, NestedTensorCUDA: NestedTensor_add_Tensor 54 tags: core 55 56- func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 57 device_check: NoCheck # TensorIterator 58 structured: True 59 structured_inherits: TensorIteratorBase 60 dispatch: 61 CPU, CUDA: mul_out 62 MPS: mul_out_mps 63 SparseCPU: mul_out_sparse_cpu 64 SparseCUDA: mul_out_sparse_cuda 65 SparseCsrCPU, SparseCsrCUDA: mul_out_sparse_csr 66 MkldnnCPU: mkldnn_mul_out 67 68- func: mul.Tensor(Tensor self, Tensor other) -> Tensor 69 device_check: NoCheck # TensorIterator 70 structured_delegate: mul.out 71 variants: function, method 72 dispatch: 73 SparseCPU, SparseCUDA: mul_sparse 74 SparseCsrCPU, SparseCsrCUDA: mul_sparse_csr 75 MkldnnCPU: mkldnn_mul 76 ZeroTensor: mul_zerotensor 77 NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul_Tensor 78 tags: core 79 80""" 81 82 83TEST_KERNEL_YAML = """ 84- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 85 device_check: NoCheck # TensorIterator 86 structured: True 87 structured_inherits: TensorIteratorBase 88 ufunc_inner_loop: 89 Generic: add (AllAndComplex, BFloat16, Half, ComplexHalf) 90 ScalarOnly: add (Bool) 91 type_alias: 92 T0: [Float, Double] 93 T1: [Double, Int] 94 dim_order_alias: 95 D0: [0, 1, 2, 3] 96 D1: [0, 3, 2, 1] 97 kernels: 98 - arg_meta: null 99 kernel_name: default_impl 100 - arg_meta: 101 self: [T0, D0] 102 other: [T1, D0] 103 out: [T0, D0] 104 kernel_name: test_impl 105 - arg_meta: 106 self: [T1, D0] 107 other: [T1, D1] 108 out: [T0, D1] 109 kernel_name: test_impl_2 110 111- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor 112 device_check: NoCheck # TensorIterator 113 structured_delegate: add.out 114 variants: function, method 115 tags: core 116 117- func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 118 device_check: NoCheck # TensorIterator 119 structured: True 120 structured_inherits: TensorIteratorBase 121 type_alias: 122 T0: [Float] 123 T1: [Double] 124 dim_order_alias: 125 D0: [0, 1, 2, 3] 126 kernels: 127 - arg_meta: null 128 kernel_name: default_impl 129 - arg_meta: 130 self: [T0, D0] 131 other: [T1, D0] 132 out: [T0, D0] 133 kernel_name: test_impl 134 135- func: mul.Tensor(Tensor self, Tensor other) -> Tensor 136 device_check: NoCheck # TensorIterator 137 structured_delegate: mul.out 138 variants: function, method 139 tags: core 140 141""" 142 143 144class TestParseNativeYaml(unittest.TestCase): 145 def setUp(self) -> None: 146 self.temp_dir = tempfile.mkdtemp() 147 148 self.aten_yaml_path = os.path.join(self.temp_dir, "test_native_functions.yaml") 149 with open(self.aten_yaml_path, "w") as f: 150 f.write(TEST_YAML) 151 self.ops_yaml_path = os.path.join(self.temp_dir, "test.yaml") 152 self.tags_yaml_path = os.path.join(self.temp_dir, "tags.yaml") 153 with open(self.tags_yaml_path, "w") as f: 154 f.write( 155 """ 156- tag: core 157 desc: test 158 """ 159 ) 160 with open(self.ops_yaml_path, "w") as f: 161 f.write( 162 """ 163- op: add.out 164 device_check: NoCheck # TensorIterator 165 dispatch: 166 CPU: torch::executor::add_out_kernel 167 168- op: mul.out 169 device_check: NoCheck # TensorIterator 170 dispatch: 171 CPU: torch::executor::mul_out_kernel 172 """ 173 ) 174 175 def test_translate_native_yaml_writes_correct_data(self) -> None: 176 out_yaml_path = os.path.join(self.temp_dir, "out.yaml") 177 with open(out_yaml_path, "w") as out_file: 178 translate_native_yaml( 179 tags_yaml_path=self.tags_yaml_path, 180 aten_yaml_path=self.aten_yaml_path, 181 native_yaml_path=self.ops_yaml_path, 182 use_aten_lib=False, 183 out_file=out_file, 184 ) 185 with open(out_yaml_path) as out_file: 186 es = yaml.load(out_file, Loader=LineLoader) 187 self.assertTrue(all("func" in e for e in es)) 188 self.assertTrue(all(e.get("variants") == "function" for e in es)) 189 190 # Check that kernel fields aren't introduced in yaml 191 for e in es: 192 self.assertFalse({"kernels", "type_alias", "dim_order_alias"} < e.keys()) 193 194 def test_parse_yaml_files(self) -> None: 195 custom_ops_yaml_path = None 196 selector = SelectiveBuilder.get_nop_selector() 197 use_aten_lib = False 198 199 parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files( 200 aten_yaml_path=self.aten_yaml_path, 201 tags_yaml_path=self.tags_yaml_path, 202 native_yaml_path=self.ops_yaml_path, 203 custom_ops_yaml_path=custom_ops_yaml_path, 204 selector=selector, 205 use_aten_lib=use_aten_lib, 206 ) 207 208 # Just the default kernel entry 209 expected_kernel_entry = {"add.out": 1, "mul.out": 1} 210 self.assertTrue(len(parsed_yaml.native_functions) == len(expected_kernel_entry)) 211 212 op_entries = parsed_yaml.kernel_index.index 213 for op_name, kernel_mapping in op_entries.items(): 214 self.assertTrue( 215 len(kernel_mapping) == expected_kernel_entry.pop(str(op_name)) 216 ) 217 218 self.assertTrue(len(expected_kernel_entry) == 0) 219 220 def tearDown(self) -> None: 221 import shutil 222 223 try: 224 shutil.rmtree(self.temp_dir) 225 except OSError: 226 pass 227 228 229class TestParseKernelYamlFiles(unittest.TestCase): 230 def setUp(self) -> None: 231 self.temp_dir = tempfile.mkdtemp() 232 233 self.aten_kernel_yaml_path = os.path.join( 234 self.temp_dir, "test_kernel_native_functions.yaml" 235 ) 236 with open(self.aten_kernel_yaml_path, "w") as f: 237 f.write(TEST_KERNEL_YAML) 238 self.ops_yaml_path = os.path.join(self.temp_dir, "test.yaml") 239 self.tags_yaml_path = os.path.join(self.temp_dir, "tags.yaml") 240 with open(self.tags_yaml_path, "w") as f: 241 f.write( 242 """ 243- tag: core 244 desc: test 245 """ 246 ) 247 with open(self.ops_yaml_path, "w") as f: 248 f.write( 249 """ 250- op: add.out 251 device_check: NoCheck # TensorIterator 252 dispatch: 253 CPU: torch::executor::add_out_kernel 254 255- op: mul.out 256 device_check: NoCheck # TensorIterator 257 dispatch: 258 CPU: torch::executor::mul_out_kernel 259 """ 260 ) 261 262 def test_translate_kernel_native_yaml_writes_correct_data(self) -> None: 263 out_yaml_path = os.path.join(self.temp_dir, "out2.yaml") 264 with open(out_yaml_path, "w") as out_file: 265 translate_native_yaml( 266 tags_yaml_path=self.tags_yaml_path, 267 aten_yaml_path=self.aten_kernel_yaml_path, 268 native_yaml_path=self.ops_yaml_path, 269 use_aten_lib=False, 270 out_file=out_file, 271 ) 272 with open(out_yaml_path) as out_file: 273 es = yaml.load(out_file, Loader=LineLoader) 274 self.assertTrue(all("func" in e for e in es)) 275 self.assertTrue(all(e.get("variants") == "function" for e in es)) 276 277 # Check persistence of kernel fields in yaml 278 for e in es: 279 self.assertTrue({"kernels", "type_alias", "dim_order_alias"} < e.keys()) 280 281 def test_parse_yaml_files(self) -> None: 282 custom_ops_yaml_path = None 283 selector = SelectiveBuilder.get_nop_selector() 284 use_aten_lib = False 285 286 parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files( 287 aten_yaml_path=self.aten_kernel_yaml_path, 288 tags_yaml_path=self.tags_yaml_path, 289 native_yaml_path=self.ops_yaml_path, 290 custom_ops_yaml_path=custom_ops_yaml_path, 291 selector=selector, 292 use_aten_lib=use_aten_lib, 293 ) 294 295 expected_kernel_entry = {"add.out": 9, "mul.out": 2} 296 self.assertTrue(len(parsed_yaml.native_functions) == len(expected_kernel_entry)) 297 298 op_entries = parsed_yaml.kernel_index.index 299 for op_name, kernel_mapping in op_entries.items(): 300 self.assertTrue( 301 len(kernel_mapping) == expected_kernel_entry.pop(str(op_name)) 302 ) 303 304 self.assertTrue(len(expected_kernel_entry) == 0) 305 306 def tearDown(self) -> None: 307 import shutil 308 309 try: 310 shutil.rmtree(self.temp_dir) 311 except OSError: 312 pass 313 314 315class TestGenFunctionsDeclarations(unittest.TestCase): 316 def setUp(self) -> None: 317 ( 318 self.custom_1_native_function, 319 custom_1_backend_index, 320 ) = NativeFunction.from_yaml( 321 {"func": "custom_1::op_1() -> bool", "dispatch": {"CPU": "kernel_1"}}, 322 loc=Location(__file__, 1), 323 valid_tags=set(), 324 ) 325 ( 326 self.custom_2_native_function, 327 custom_2_backend_index, 328 ) = NativeFunction.from_yaml( 329 { 330 "func": "custom_2::op_2() -> bool", 331 "dispatch": {"CPU": "kernel_2"}, 332 }, 333 loc=Location(__file__, 1), 334 valid_tags=set(), 335 ) 336 ( 337 self.custom_3_native_function, 338 custom_3_backend_index, 339 ) = NativeFunction.from_yaml( 340 { 341 "func": "custom_3::op_3(Tensor(a!) self, Tensor x) -> Tensor(a!)", 342 "dispatch": {"CPU": "kernel_3"}, 343 "variants": "method", 344 }, 345 loc=Location(__file__, 1), 346 valid_tags=set(), 347 ) 348 349 backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = { 350 DispatchKey.CPU: {}, 351 DispatchKey.QuantizedCPU: {}, 352 } 353 BackendIndex.grow_index(backend_indices, custom_1_backend_index) 354 BackendIndex.grow_index(backend_indices, custom_2_backend_index) 355 self.static_dispatch_idx = [ 356 BackendIndex( 357 dispatch_key=k, 358 use_out_as_primary=True, 359 external=False, 360 device_guard=False, 361 index=backend_indices[k], 362 ) 363 for k in backend_indices 364 ] 365 self.kernel_index = ETKernelIndex.from_backend_indices(backend_indices) 366 367 def test_operators_with_different_namespaces_are_grouped_correctly(self) -> None: 368 declarations = gen_functions_declarations( 369 native_functions=[ 370 self.custom_1_native_function, 371 self.custom_2_native_function, 372 ], 373 kernel_index=self.kernel_index, 374 selector=SelectiveBuilder.get_nop_selector(), 375 use_aten_lib=False, 376 ) 377 self.assertTrue( 378 """ 379namespace custom_1 { 380 381// custom_1::op_1() -> bool 382TORCH_API inline bool op_1(torch::executor::KernelRuntimeContext & context) { 383 return ::at::native::kernel_1(context); 384} 385 386} // namespace custom_1 387""" 388 in declarations 389 ) 390 391 self.assertTrue( 392 """ 393namespace custom_2 { 394 395// custom_2::op_2() -> bool 396TORCH_API inline bool op_2(torch::executor::KernelRuntimeContext & context) { 397 return ::at::native::kernel_2(context); 398} 399 400} // namespace custom_2 401 """ 402 in declarations 403 ) 404 405 def test_aten_lib_has_context_arg(self) -> None: 406 declarations = gen_functions_declarations( 407 native_functions=[ 408 self.custom_1_native_function, 409 ], 410 kernel_index=self.kernel_index, 411 selector=SelectiveBuilder.get_nop_selector(), 412 use_aten_lib=True, 413 ) 414 self.assertTrue( 415 """ 416namespace custom_1 { 417 418// custom_1::op_1() -> bool 419TORCH_API inline bool op_1(torch::executor::KernelRuntimeContext & context) { 420 return at::op_1(); 421} 422 423} // namespace custom_1 424 """ 425 in declarations 426 ) 427 428 def test_aten_lib_method_variant(self) -> None: 429 declarations = gen_functions_declarations( 430 native_functions=[ 431 self.custom_3_native_function, 432 ], 433 kernel_index=self.kernel_index, 434 selector=SelectiveBuilder.get_nop_selector(), 435 use_aten_lib=True, 436 ) 437 self.assertTrue( 438 """ 439namespace custom_3 { 440 441// custom_3::op_3(Tensor(a!) self, Tensor x) -> Tensor(a!) 442TORCH_API inline at::Tensor & op_3(torch::executor::KernelRuntimeContext & context, at::Tensor & self, const at::Tensor & x) { 443 return self.op_3(x); 444} 445 446} // namespace custom_3 447 """ 448 in declarations 449 ) 450 451 452class TestComputeCodegenUnboxedKernels(unittest.TestCase): 453 def setUp(self) -> None: 454 ( 455 self.native_function_no_kern, 456 _, 457 ) = NativeFunction.from_yaml( 458 { 459 "func": "custom_1::op_1() -> bool", 460 "dispatch": {"CPU": "unused_kernel_1"}, 461 }, 462 loc=Location(__file__, 1), 463 valid_tags=set(), 464 ) 465 466 self.default_kernel_key = ETKernelKey(default=True) 467 self.default_backend_metadata = BackendMetadata( 468 "default_kernel", False, "at::native" 469 ) 470 self.default_kernel_entry = ( 471 [self.default_kernel_key], 472 self.default_backend_metadata, 473 ) 474 475 def test_codegen_unboxed_specialized(self) -> None: 476 specialized_kernel_key = ETKernelKey.gen_from_yaml( 477 {"self": ("T0", "D0"), "other": ("T0", "D0"), "out": ("T0", "D0")}, 478 {"T0": ["Double"]}, 479 {"D0": [0, 1, 2, 3]}, 480 ) 481 selector = SelectiveBuilder.from_yaml_dict( 482 { 483 "include_all_operators": True, 484 "et_kernel_metadata": { 485 "custom_1::op_1": ["v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3"] 486 }, 487 } 488 ) 489 use_aten_lib = False 490 entry = ( 491 self.native_function_no_kern, 492 (specialized_kernel_key, self.default_backend_metadata), 493 ) 494 495 result = ComputeCodegenUnboxedKernels(selector, use_aten_lib)(entry) 496 # Concat used to prevent whitespace stripping 497 expected_str = ( 498 """ 499Kernel( 500 "custom_1::op_1", 501 "v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3", 502 [](torch::executor::KernelRuntimeContext & context, EValue** stack) { 503 """ 504 + """ 505 506 internal::EventTracerProfileScope event_tracer_scope(context.internal_event_tracer(), "native_call_op_1"); 507 EXECUTORCH_SCOPE_PROF("native_call_op_1"); 508 bool result_ = at::native::default_kernel(context, ); 509 internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]); 510 511 *stack[0] = EValue(result_); 512 } 513), 514""" 515 ) 516 517 self.assertEqual(expected_str, result) 518 519 def test_codegen_unboxed_specialized_not_matching(self) -> None: 520 specialized_kernel_key = ETKernelKey.gen_from_yaml( 521 {"self": ("T0", "D0"), "other": ("T0", "D0"), "out": ("T0", "D0")}, 522 {"T0": ["Double"]}, 523 {"D0": [0, 1, 2, 3]}, 524 ) 525 selector = SelectiveBuilder.from_yaml_dict( 526 { 527 "include_all_operators": True, 528 "et_kernel_metadata": { 529 "custom_1::op_1": ["v1/8;0,1,2,3|7;0,1,2,3|7;0,1,2,3"] 530 }, 531 } 532 ) 533 use_aten_lib = False 534 entry = ( 535 self.native_function_no_kern, 536 (specialized_kernel_key, self.default_backend_metadata), 537 ) 538 539 self.assertRaises( 540 Exception, ComputeCodegenUnboxedKernels(selector, use_aten_lib), entry 541 ) 542 543 def test_codegen_unboxed_specialized_missing_root_op(self) -> None: 544 specialized_kernel_key = ETKernelKey.gen_from_yaml( 545 {"self": ("T0", "D0"), "other": ("T0", "D0"), "out": ("T0", "D0")}, 546 {"T0": ["Double"]}, 547 {"D0": [0, 1, 2, 3]}, 548 ) 549 selector = SelectiveBuilder.from_yaml_dict( 550 { 551 "et_kernel_metadata": { 552 "custom_1::op_1": ["v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3"] 553 } 554 } 555 ) 556 use_aten_lib = False 557 entry = ( 558 self.native_function_no_kern, 559 (specialized_kernel_key, self.default_backend_metadata), 560 ) 561 562 result = ComputeCodegenUnboxedKernels(selector, use_aten_lib)(entry) 563 # Concat used to prevent whitespace stripping 564 expected_str = """""" 565 566 self.assertEqual(expected_str, result) 567 568 def test_codegen_unboxed_default(self) -> None: 569 """ 570 This test checks that if there is no specialized kernel, the default kernel is used. 571 """ 572 selector = SelectiveBuilder.from_yaml_dict( 573 { 574 "include_all_operators": True, 575 "et_kernel_metadata": { 576 "custom_1::op_1": ["v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3"] 577 }, 578 } 579 ) 580 use_aten_lib = False 581 entry = (self.native_function_no_kern, self.default_kernel_entry) 582 583 result = ComputeCodegenUnboxedKernels(selector, use_aten_lib)(entry) 584 # Concat used to prevent whitespace stripping 585 expected_str = ( 586 """ 587Kernel( 588 "custom_1::op_1", 589 [](torch::executor::KernelRuntimeContext & context, EValue** stack) { 590 """ 591 + """ 592 593 internal::EventTracerProfileScope event_tracer_scope(context.internal_event_tracer(), "native_call_op_1"); 594 EXECUTORCH_SCOPE_PROF("native_call_op_1"); 595 bool result_ = at::native::default_kernel(context, ); 596 internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]); 597 598 *stack[0] = EValue(result_); 599 } 600), 601""" 602 ) 603 604 self.assertEqual(expected_str, result) 605 606 def test_codegen_unboxed_default_kernel_key_selected(self) -> None: 607 """ 608 This test checks that if there is no specialized kernel, the default kernel is used, when the selector only has default key. 609 """ 610 selector = SelectiveBuilder.from_yaml_dict( 611 { 612 "include_all_operators": True, 613 "et_kernel_metadata": {"custom_1::op_1": ["default"]}, 614 } 615 ) 616 use_aten_lib = False 617 entry = (self.native_function_no_kern, self.default_kernel_entry) 618 619 result = ComputeCodegenUnboxedKernels(selector, use_aten_lib)(entry) 620 # Concat used to prevent whitespace stripping 621 expected_str = ( 622 """ 623Kernel( 624 "custom_1::op_1", 625 [](torch::executor::KernelRuntimeContext & context, EValue** stack) { 626 """ 627 + """ 628 629 internal::EventTracerProfileScope event_tracer_scope(context.internal_event_tracer(), "native_call_op_1"); 630 EXECUTORCH_SCOPE_PROF("native_call_op_1"); 631 bool result_ = at::native::default_kernel(context, ); 632 internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]); 633 634 *stack[0] = EValue(result_); 635 } 636), 637""" 638 ) 639 640 self.assertEqual(expected_str, result) 641