xref: /aosp_15_r20/external/pytorch/tools/test/test_executorch_gen.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import os
4import tempfile
5import unittest
6
7import yaml
8
9from torchgen.executorch.model import ETKernelIndex, ETKernelKey
10from torchgen.gen import LineLoader
11from torchgen.gen_executorch import (
12    ComputeCodegenUnboxedKernels,
13    gen_functions_declarations,
14    parse_yaml_files,
15    translate_native_yaml,
16)
17from torchgen.model import (
18    BackendIndex,
19    BackendMetadata,
20    DispatchKey,
21    Location,
22    NativeFunction,
23    OperatorName,
24)
25from torchgen.selective_build.selector import SelectiveBuilder
26
27
# Fixture mimicking upstream ATen `native_functions.yaml` entries (the
# functional and `out` variants of `add` and `mul`), WITHOUT any
# ExecuTorch-specific kernel metadata.  Used by TestParseNativeYaml below.
TEST_YAML = """
- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck   # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  ufunc_inner_loop:
    Generic: add (AllAndComplex, BFloat16, Half, ComplexHalf)
    ScalarOnly: add (Bool)
  dispatch:
    SparseCPU: add_out_sparse_cpu
    SparseCUDA: add_out_sparse_cuda
    SparseCsrCPU: add_out_sparse_csr_cpu
    SparseCsrCUDA: add_out_sparse_csr_cuda
    MkldnnCPU: mkldnn_add_out
    MPS: add_out_mps

- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  device_check: NoCheck   # TensorIterator
  structured_delegate: add.out
  variants: function, method
  dispatch:
    SparseCPU, SparseCUDA: add_sparse
    SparseCsrCPU, SparseCsrCUDA: add_sparse_csr
    MkldnnCPU: mkldnn_add
    ZeroTensor: add_zerotensor
    NestedTensorCPU, NestedTensorCUDA: NestedTensor_add_Tensor
  tags: core

- func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck   # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: mul_out
    MPS: mul_out_mps
    SparseCPU: mul_out_sparse_cpu
    SparseCUDA: mul_out_sparse_cuda
    SparseCsrCPU, SparseCsrCUDA: mul_out_sparse_csr
    MkldnnCPU: mkldnn_mul_out

- func: mul.Tensor(Tensor self, Tensor other) -> Tensor
  device_check: NoCheck   # TensorIterator
  structured_delegate: mul.out
  variants: function, method
  dispatch:
    SparseCPU, SparseCUDA: mul_sparse
    SparseCsrCPU, SparseCsrCUDA: mul_sparse_csr
    MkldnnCPU: mkldnn_mul
    ZeroTensor: mul_zerotensor
    NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul_Tensor
  tags: core

"""
81
82
# Fixture like TEST_YAML but WITH ExecuTorch kernel metadata
# (`type_alias`, `dim_order_alias`, `kernels`), so parsing should produce
# specialized kernel keys in addition to the default one.  Used by
# TestParseKernelYamlFiles below.
TEST_KERNEL_YAML = """
- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck   # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  ufunc_inner_loop:
    Generic: add (AllAndComplex, BFloat16, Half, ComplexHalf)
    ScalarOnly: add (Bool)
  type_alias:
    T0: [Float, Double]
    T1: [Double, Int]
  dim_order_alias:
    D0: [0, 1, 2, 3]
    D1: [0, 3, 2, 1]
  kernels:
    - arg_meta: null
      kernel_name: default_impl
    - arg_meta:
        self: [T0, D0]
        other: [T1, D0]
        out: [T0, D0]
      kernel_name: test_impl
    - arg_meta:
        self: [T1, D0]
        other: [T1, D1]
        out: [T0, D1]
      kernel_name: test_impl_2

- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  device_check: NoCheck   # TensorIterator
  structured_delegate: add.out
  variants: function, method
  tags: core

- func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck   # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  type_alias:
    T0: [Float]
    T1: [Double]
  dim_order_alias:
    D0: [0, 1, 2, 3]
  kernels:
    - arg_meta: null
      kernel_name: default_impl
    - arg_meta:
        self: [T0, D0]
        other: [T1, D0]
        out: [T0, D0]
      kernel_name: test_impl

- func: mul.Tensor(Tensor self, Tensor other) -> Tensor
  device_check: NoCheck   # TensorIterator
  structured_delegate: mul.out
  variants: function, method
  tags: core

"""
142
143
class TestParseNativeYaml(unittest.TestCase):
    """Tests for translating/parsing ATen-style YAML that carries NO
    ExecuTorch kernel metadata (uses TEST_YAML)."""

    def setUp(self) -> None:
        # Scratch directory holding the three YAML inputs the generator
        # consumes: ATen native functions, tags, and the ExecuTorch ops list.
        self.temp_dir = tempfile.mkdtemp()

        self.aten_yaml_path = os.path.join(self.temp_dir, "test_native_functions.yaml")
        with open(self.aten_yaml_path, "w") as f:
            f.write(TEST_YAML)
        self.ops_yaml_path = os.path.join(self.temp_dir, "test.yaml")
        self.tags_yaml_path = os.path.join(self.temp_dir, "tags.yaml")
        with open(self.tags_yaml_path, "w") as f:
            f.write(
                """
- tag: core
  desc: test
            """
            )
        with open(self.ops_yaml_path, "w") as f:
            f.write(
                """
- op: add.out
  device_check: NoCheck   # TensorIterator
  dispatch:
    CPU: torch::executor::add_out_kernel

- op: mul.out
  device_check: NoCheck   # TensorIterator
  dispatch:
    CPU: torch::executor::mul_out_kernel
                """
            )

    def test_translate_native_yaml_writes_correct_data(self) -> None:
        """translate_native_yaml should emit function-variant entries and
        must not invent ExecuTorch kernel fields that the input lacked."""
        out_yaml_path = os.path.join(self.temp_dir, "out.yaml")
        with open(out_yaml_path, "w") as out_file:
            translate_native_yaml(
                tags_yaml_path=self.tags_yaml_path,
                aten_yaml_path=self.aten_yaml_path,
                native_yaml_path=self.ops_yaml_path,
                use_aten_lib=False,
                out_file=out_file,
            )
        with open(out_yaml_path) as out_file:
            es = yaml.load(out_file, Loader=LineLoader)
        self.assertTrue(all("func" in e for e in es))
        self.assertTrue(all(e.get("variants") == "function" for e in es))

        # Check that kernel fields aren't introduced in yaml.
        for e in es:
            self.assertFalse({"kernels", "type_alias", "dim_order_alias"} < e.keys())

    def test_parse_yaml_files(self) -> None:
        """Without kernel metadata each op should get exactly one (default)
        kernel entry in the parsed kernel index."""
        custom_ops_yaml_path = None
        selector = SelectiveBuilder.get_nop_selector()
        use_aten_lib = False

        parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files(
            aten_yaml_path=self.aten_yaml_path,
            tags_yaml_path=self.tags_yaml_path,
            native_yaml_path=self.ops_yaml_path,
            custom_ops_yaml_path=custom_ops_yaml_path,
            selector=selector,
            use_aten_lib=use_aten_lib,
        )

        # Just the default kernel entry per op.
        expected_kernel_entry = {"add.out": 1, "mul.out": 1}
        self.assertEqual(len(parsed_yaml.native_functions), len(expected_kernel_entry))

        op_entries = parsed_yaml.kernel_index.index
        for op_name, kernel_mapping in op_entries.items():
            self.assertEqual(len(kernel_mapping), expected_kernel_entry.pop(str(op_name)))

        # Every expected op must have been seen exactly once.
        self.assertEqual(expected_kernel_entry, {})

    def tearDown(self) -> None:
        import shutil

        # Best-effort cleanup of the scratch directory.
        shutil.rmtree(self.temp_dir, ignore_errors=True)
227
228
class TestParseKernelYamlFiles(unittest.TestCase):
    """Tests for translating/parsing YAML that DOES carry ExecuTorch kernel
    metadata (uses TEST_KERNEL_YAML)."""

    def setUp(self) -> None:
        # Scratch directory holding the three YAML inputs the generator
        # consumes: ATen native functions (with kernels), tags, and ops list.
        self.temp_dir = tempfile.mkdtemp()

        self.aten_kernel_yaml_path = os.path.join(
            self.temp_dir, "test_kernel_native_functions.yaml"
        )
        with open(self.aten_kernel_yaml_path, "w") as f:
            f.write(TEST_KERNEL_YAML)
        self.ops_yaml_path = os.path.join(self.temp_dir, "test.yaml")
        self.tags_yaml_path = os.path.join(self.temp_dir, "tags.yaml")
        with open(self.tags_yaml_path, "w") as f:
            f.write(
                """
- tag: core
  desc: test
            """
            )
        with open(self.ops_yaml_path, "w") as f:
            f.write(
                """
- op: add.out
  device_check: NoCheck   # TensorIterator
  dispatch:
    CPU: torch::executor::add_out_kernel

- op: mul.out
  device_check: NoCheck   # TensorIterator
  dispatch:
    CPU: torch::executor::mul_out_kernel
                """
            )

    def test_translate_kernel_native_yaml_writes_correct_data(self) -> None:
        """translate_native_yaml must carry the kernel metadata fields
        through to the output yaml untouched."""
        out_yaml_path = os.path.join(self.temp_dir, "out2.yaml")
        with open(out_yaml_path, "w") as out_file:
            translate_native_yaml(
                tags_yaml_path=self.tags_yaml_path,
                aten_yaml_path=self.aten_kernel_yaml_path,
                native_yaml_path=self.ops_yaml_path,
                use_aten_lib=False,
                out_file=out_file,
            )
        with open(out_yaml_path) as out_file:
            es = yaml.load(out_file, Loader=LineLoader)
        self.assertTrue(all("func" in e for e in es))
        self.assertTrue(all(e.get("variants") == "function" for e in es))

        # Check persistence of kernel fields in yaml.
        for e in es:
            self.assertTrue({"kernels", "type_alias", "dim_order_alias"} < e.keys())

    def test_parse_yaml_files(self) -> None:
        """The specialized kernel metadata should expand into multiple
        kernel-key entries per op (9 for add.out, 2 for mul.out)."""
        custom_ops_yaml_path = None
        selector = SelectiveBuilder.get_nop_selector()
        use_aten_lib = False

        parsed_yaml, custom_ops_parsed_yaml = parse_yaml_files(
            aten_yaml_path=self.aten_kernel_yaml_path,
            tags_yaml_path=self.tags_yaml_path,
            native_yaml_path=self.ops_yaml_path,
            custom_ops_yaml_path=custom_ops_yaml_path,
            selector=selector,
            use_aten_lib=use_aten_lib,
        )

        expected_kernel_entry = {"add.out": 9, "mul.out": 2}
        self.assertEqual(len(parsed_yaml.native_functions), len(expected_kernel_entry))

        op_entries = parsed_yaml.kernel_index.index
        for op_name, kernel_mapping in op_entries.items():
            self.assertEqual(len(kernel_mapping), expected_kernel_entry.pop(str(op_name)))

        # Every expected op must have been seen exactly once.
        self.assertEqual(expected_kernel_entry, {})

    def tearDown(self) -> None:
        import shutil

        # Best-effort cleanup of the scratch directory.
        shutil.rmtree(self.temp_dir, ignore_errors=True)
313
314
class TestGenFunctionsDeclarations(unittest.TestCase):
    """Tests that gen_functions_declarations renders per-namespace C++
    wrapper declarations correctly for custom operators."""

    def setUp(self) -> None:
        # Three custom native functions in three different namespaces;
        # custom_3 is a method variant with a mutable `self` argument.
        (
            self.custom_1_native_function,
            custom_1_backend_index,
        ) = NativeFunction.from_yaml(
            {"func": "custom_1::op_1() -> bool", "dispatch": {"CPU": "kernel_1"}},
            loc=Location(__file__, 1),
            valid_tags=set(),
        )
        (
            self.custom_2_native_function,
            custom_2_backend_index,
        ) = NativeFunction.from_yaml(
            {
                "func": "custom_2::op_2() -> bool",
                "dispatch": {"CPU": "kernel_2"},
            },
            loc=Location(__file__, 1),
            valid_tags=set(),
        )
        (
            self.custom_3_native_function,
            custom_3_backend_index,
        ) = NativeFunction.from_yaml(
            {
                "func": "custom_3::op_3(Tensor(a!) self, Tensor x) -> Tensor(a!)",
                "dispatch": {"CPU": "kernel_3"},
                "variants": "method",
            },
            loc=Location(__file__, 1),
            valid_tags=set(),
        )

        backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = {
            DispatchKey.CPU: {},
            DispatchKey.QuantizedCPU: {},
        }
        BackendIndex.grow_index(backend_indices, custom_1_backend_index)
        BackendIndex.grow_index(backend_indices, custom_2_backend_index)
        self.static_dispatch_idx = [
            BackendIndex(
                dispatch_key=k,
                use_out_as_primary=True,
                external=False,
                device_guard=False,
                index=backend_indices[k],
            )
            for k in backend_indices
        ]
        # Kernel index derived from the same backend indices; this is what
        # gen_functions_declarations consumes.
        self.kernel_index = ETKernelIndex.from_backend_indices(backend_indices)

    def test_operators_with_different_namespaces_are_grouped_correctly(self) -> None:
        declarations = gen_functions_declarations(
            native_functions=[
                self.custom_1_native_function,
                self.custom_2_native_function,
            ],
            kernel_index=self.kernel_index,
            selector=SelectiveBuilder.get_nop_selector(),
            use_aten_lib=False,
        )
        # assertIn (rather than assertTrue(x in y)) so a failure prints both
        # the expected snippet and the generated declarations.
        self.assertIn(
            """
namespace custom_1 {

// custom_1::op_1() -> bool
TORCH_API inline bool op_1(torch::executor::KernelRuntimeContext & context) {
    return ::at::native::kernel_1(context);
}

} // namespace custom_1
""",
            declarations,
        )

        self.assertIn(
            """
namespace custom_2 {

// custom_2::op_2() -> bool
TORCH_API inline bool op_2(torch::executor::KernelRuntimeContext & context) {
    return ::at::native::kernel_2(context);
}

} // namespace custom_2
        """,
            declarations,
        )

    def test_aten_lib_has_context_arg(self) -> None:
        # With use_aten_lib=True the wrapper keeps the runtime-context
        # parameter but forwards to the at:: function without it.
        declarations = gen_functions_declarations(
            native_functions=[
                self.custom_1_native_function,
            ],
            kernel_index=self.kernel_index,
            selector=SelectiveBuilder.get_nop_selector(),
            use_aten_lib=True,
        )
        self.assertIn(
            """
namespace custom_1 {

// custom_1::op_1() -> bool
TORCH_API inline bool op_1(torch::executor::KernelRuntimeContext & context) {
    return at::op_1();
}

} // namespace custom_1
        """,
            declarations,
        )

    def test_aten_lib_method_variant(self) -> None:
        # Method variants should be dispatched as `self.op_3(x)` calls.
        declarations = gen_functions_declarations(
            native_functions=[
                self.custom_3_native_function,
            ],
            kernel_index=self.kernel_index,
            selector=SelectiveBuilder.get_nop_selector(),
            use_aten_lib=True,
        )
        self.assertIn(
            """
namespace custom_3 {

// custom_3::op_3(Tensor(a!) self, Tensor x) -> Tensor(a!)
TORCH_API inline at::Tensor & op_3(torch::executor::KernelRuntimeContext & context, at::Tensor & self, const at::Tensor & x) {
    return self.op_3(x);
}

} // namespace custom_3
        """,
            declarations,
        )
450
451
class TestComputeCodegenUnboxedKernels(unittest.TestCase):
    """Tests for ComputeCodegenUnboxedKernels, which renders the C++
    ``Kernel(...)`` registration entries of the unboxing table."""

    # Expected registration entry for the default (un-keyed) kernel.  Two
    # tests below expect byte-identical output, so it is shared here.  The
    # string concatenation preserves the significant trailing whitespace on
    # the lambda's opening line, which editors would otherwise strip.
    _EXPECTED_DEFAULT_ENTRY = (
        """
Kernel(
    "custom_1::op_1",
    [](torch::executor::KernelRuntimeContext & context, EValue** stack) {
        """
        + """

        internal::EventTracerProfileScope event_tracer_scope(context.internal_event_tracer(), "native_call_op_1");
        EXECUTORCH_SCOPE_PROF("native_call_op_1");
        bool result_ = at::native::default_kernel(context, );
        internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]);

        *stack[0] = EValue(result_);
    }
),
"""
    )

    def setUp(self) -> None:
        # A native function with no kernels of its own; each test pairs it
        # with either a specialized or the default kernel entry.
        (
            self.native_function_no_kern,
            _,
        ) = NativeFunction.from_yaml(
            {
                "func": "custom_1::op_1() -> bool",
                "dispatch": {"CPU": "unused_kernel_1"},
            },
            loc=Location(__file__, 1),
            valid_tags=set(),
        )

        self.default_kernel_key = ETKernelKey(default=True)
        self.default_backend_metadata = BackendMetadata(
            "default_kernel", False, "at::native"
        )
        self.default_kernel_entry = (
            [self.default_kernel_key],
            self.default_backend_metadata,
        )

    def test_codegen_unboxed_specialized(self) -> None:
        """A specialized kernel key matching the selector's metadata should
        emit an entry carrying the versioned kernel-key string."""
        specialized_kernel_key = ETKernelKey.gen_from_yaml(
            {"self": ("T0", "D0"), "other": ("T0", "D0"), "out": ("T0", "D0")},
            {"T0": ["Double"]},
            {"D0": [0, 1, 2, 3]},
        )
        selector = SelectiveBuilder.from_yaml_dict(
            {
                "include_all_operators": True,
                "et_kernel_metadata": {
                    "custom_1::op_1": ["v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3"]
                },
            }
        )
        use_aten_lib = False
        entry = (
            self.native_function_no_kern,
            (specialized_kernel_key, self.default_backend_metadata),
        )

        result = ComputeCodegenUnboxedKernels(selector, use_aten_lib)(entry)
        # Concat used to prevent whitespace stripping
        expected_str = (
            """
Kernel(
    "custom_1::op_1",
    "v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3",
    [](torch::executor::KernelRuntimeContext & context, EValue** stack) {
        """
            + """

        internal::EventTracerProfileScope event_tracer_scope(context.internal_event_tracer(), "native_call_op_1");
        EXECUTORCH_SCOPE_PROF("native_call_op_1");
        bool result_ = at::native::default_kernel(context, );
        internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]);

        *stack[0] = EValue(result_);
    }
),
"""
        )

        self.assertEqual(expected_str, result)

    def test_codegen_unboxed_specialized_not_matching(self) -> None:
        """A specialized key that does NOT match the selector's metadata is
        an error, not a silent skip."""
        specialized_kernel_key = ETKernelKey.gen_from_yaml(
            {"self": ("T0", "D0"), "other": ("T0", "D0"), "out": ("T0", "D0")},
            {"T0": ["Double"]},
            {"D0": [0, 1, 2, 3]},
        )
        # Note the mismatched leading "8" in the metadata key below.
        selector = SelectiveBuilder.from_yaml_dict(
            {
                "include_all_operators": True,
                "et_kernel_metadata": {
                    "custom_1::op_1": ["v1/8;0,1,2,3|7;0,1,2,3|7;0,1,2,3"]
                },
            }
        )
        use_aten_lib = False
        entry = (
            self.native_function_no_kern,
            (specialized_kernel_key, self.default_backend_metadata),
        )

        self.assertRaises(
            Exception, ComputeCodegenUnboxedKernels(selector, use_aten_lib), entry
        )

    def test_codegen_unboxed_specialized_missing_root_op(self) -> None:
        """An operator that is not a selected root op produces no entry."""
        specialized_kernel_key = ETKernelKey.gen_from_yaml(
            {"self": ("T0", "D0"), "other": ("T0", "D0"), "out": ("T0", "D0")},
            {"T0": ["Double"]},
            {"D0": [0, 1, 2, 3]},
        )
        # No "include_all_operators" here, so custom_1::op_1 is unselected.
        selector = SelectiveBuilder.from_yaml_dict(
            {
                "et_kernel_metadata": {
                    "custom_1::op_1": ["v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3"]
                }
            }
        )
        use_aten_lib = False
        entry = (
            self.native_function_no_kern,
            (specialized_kernel_key, self.default_backend_metadata),
        )

        result = ComputeCodegenUnboxedKernels(selector, use_aten_lib)(entry)
        # Nothing should be generated for an unselected root op.
        self.assertEqual("", result)

    def test_codegen_unboxed_default(self) -> None:
        """
        This test checks that if there is no specialized kernel, the default kernel is used.
        """
        selector = SelectiveBuilder.from_yaml_dict(
            {
                "include_all_operators": True,
                "et_kernel_metadata": {
                    "custom_1::op_1": ["v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3"]
                },
            }
        )
        use_aten_lib = False
        entry = (self.native_function_no_kern, self.default_kernel_entry)

        result = ComputeCodegenUnboxedKernels(selector, use_aten_lib)(entry)

        self.assertEqual(self._EXPECTED_DEFAULT_ENTRY, result)

    def test_codegen_unboxed_default_kernel_key_selected(self) -> None:
        """
        This test checks that if there is no specialized kernel, the default kernel is used, when the selector only has default key.
        """
        selector = SelectiveBuilder.from_yaml_dict(
            {
                "include_all_operators": True,
                "et_kernel_metadata": {"custom_1::op_1": ["default"]},
            }
        )
        use_aten_lib = False
        entry = (self.native_function_no_kern, self.default_kernel_entry)

        result = ComputeCodegenUnboxedKernels(selector, use_aten_lib)(entry)

        self.assertEqual(self._EXPECTED_DEFAULT_ENTRY, result)
641