xref: /aosp_15_r20/external/pytorch/tools/test/test_codegen.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
from __future__ import annotations

import dataclasses
import typing
import unittest
from collections import Counter, defaultdict

import yaml

from tools.autograd import gen_autograd_functions, load_derivatives
from torchgen import dest
from torchgen.api.types import CppSignatureGroup, DispatcherSignature
from torchgen.context import native_function_manager
from torchgen.gen import (
    get_native_function_declarations,
    get_native_function_schema_registrations,
    LineLoader,
    static_dispatch,
)
from torchgen.model import (
    BackendIndex,
    BackendMetadata,
    DispatchKey,
    FunctionSchema,
    Location,
    NativeFunction,
    OperatorName,
)
from torchgen.native_function_generation import add_generated_native_functions
from torchgen.selective_build.selector import SelectiveBuilder
31
32
class TestCreateDerivative(unittest.TestCase):
    """Tests for tools.autograd.load_derivatives: building Derivative and
    DifferentiabilityInfo objects from derivatives.yaml-style definitions."""

    def test_named_grads(self) -> None:
        # A formula that refers to gradients by output name (grad_x,
        # grad_y) should record those names on the derivative.
        schema = FunctionSchema.parse(
            "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        )
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        derivative = load_derivatives.create_derivative(
            native_function,
            formula="func_backward(grad_x, grad_y)",
            var_names=(),
            available_named_gradients=["grad_x", "grad_y"],
        )
        self.assertSetEqual(derivative.named_gradients, {"grad_x", "grad_y"})

    def test_non_differentiable_output(self) -> None:
        # Non-differentiable outputs (here the bool `y`) must not
        # contribute a named gradient.
        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        schema = FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        _, differentiability_info = load_derivatives.create_differentiability_info(
            defn_dict={
                "name": specification,
                "dispatch": {"Default": {"a": "grads[0]", "b": "grads[2]"}},
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            # collections.Counter rather than the deprecated typing.Counter
            # alias (PEP 585).
            op_counter=Counter[str](),
            used_dispatch_keys=set(),
        )

        self.assertSequenceEqual(
            differentiability_info["Default"].available_named_gradients,
            # grad_y is not present because y is a
            # bool and thus not differentiable.
            ["grad_x", "grad_z"],
        )

    def test_indexed_grads(self) -> None:
        # A formula that indexes grads[i] directly produces no named
        # gradients.
        schema = FunctionSchema.parse(
            "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        )
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        derivative = load_derivatives.create_derivative(
            native_function,
            formula="func_backward(grads[0], grads[1])",
            var_names=(),
            available_named_gradients=["grad_x", "grad_y"],
        )
        self.assertSetEqual(derivative.named_gradients, set())

    def test_named_grads_and_indexed_grads(self) -> None:
        # Mixing grad_<name> and grads[i] references in one definition is
        # ambiguous and must be rejected.
        specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        schema = FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        with self.assertRaisesRegex(
            RuntimeError, 'illegally mixes use of "grad_RETURN_NAME"'
        ):
            load_derivatives.create_differentiability_info(
                defn_dict={
                    "name": specification,
                    # Uh-oh, the derivatives reference gradients by
                    # name and by index.
                    "dispatch": {
                        "Default": {
                            "a": "grad_x",
                            "b": "grads[1]",
                        }
                    },
                },
                functions_by_signature={schema.signature(): [native_function]},
                functions_by_schema={specification: native_function},
                op_counter=Counter[str](),
                used_dispatch_keys=set(),
            )
111
class TestGenAutogradFunctions(unittest.TestCase):
    """Tests for the generated autograd Function definitions produced by
    tools.autograd.gen_autograd_functions.process_function."""

    def test_non_differentiable_output_invalid_type(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        schema = FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        _, differentiability_info = load_derivatives.create_differentiability_info(
            defn_dict={
                "name": specification,
                "dispatch": {
                    "Default": {
                        "a": "grad_x",
                        "b": "grad_z",
                    }
                },
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            # collections.Counter rather than the deprecated typing.Counter
            # alias (PEP 585).
            op_counter=Counter[str](),
            used_dispatch_keys=set(),
        )
        definition = gen_autograd_functions.process_function(
            differentiability_info["Default"],
            gen_autograd_functions.FUNCTION_DEFINITION,
        )
        # grad_z should map to grads[1], not grads[2] because output 1
        # (y) is not differentiable. unittest assertions (rather than bare
        # asserts) survive `python -O` and give useful failure messages.
        self.assertNotIn("grad_z = grads[2]", definition)
        self.assertIn("grad_z = grads[1]", definition)

    def test_non_differentiable_output_output_differentiability(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)"
        schema = FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        _, differentiability_info = load_derivatives.create_differentiability_info(
            defn_dict={
                "name": specification,
                "dispatch": {
                    "Default": {
                        "a": "grad_x",
                        "b": "grad_z",
                    },
                    "AutogradNestedTensor": {
                        "a": "grad_z",
                        "b": "grad_x",
                    },
                },
                "output_differentiability": [True, False, True],
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            op_counter=Counter[str](),
            used_dispatch_keys=set(),
        )
        default_definition = gen_autograd_functions.process_function(
            differentiability_info["Default"],
            gen_autograd_functions.FUNCTION_DEFINITION,
        )
        # grad_z should map to grads[1], not grads[2] because output 1
        # (y) is marked non-differentiable via output_differentiability.
        self.assertNotIn("grad_z = grads[2]", default_definition)
        self.assertIn("grad_z = grads[1]", default_definition)

        # The same remapping must hold for the AutogradNestedTensor key.
        nested_tensor_definition = gen_autograd_functions.process_function(
            differentiability_info["AutogradNestedTensor"],
            gen_autograd_functions.FUNCTION_DEFINITION,
        )
        self.assertNotIn("grad_z = grads[2]", nested_tensor_definition)
        self.assertIn("grad_z = grads[1]", nested_tensor_definition)

    def test_register_bogus_dispatch_key(self) -> None:
        # Dispatch keys in derivatives.yaml are validated; an unknown
        # autograd key must raise.
        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        schema = FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        with self.assertRaisesRegex(
            RuntimeError,
            "Invalid dispatch key AutogradRandomTensor in derivatives.yaml for",
        ):
            load_derivatives.create_differentiability_info(
                defn_dict={
                    "name": specification,
                    "dispatch": {
                        "Default": {
                            "a": "grad_x",
                            "b": "grad_z",
                        },
                        "AutogradRandomTensor": {
                            "a": "grad_x",
                            "b": "grad_z",
                        },
                    },
                },
                functions_by_signature={schema.signature(): [native_function]},
                functions_by_schema={specification: native_function},
                op_counter=Counter[str](),
                used_dispatch_keys=set(),
            )
211
212
class TestGenSchemaRegistration(unittest.TestCase):
    """Tests for the C++ schema-registration code emitted by
    get_native_function_schema_registrations, covering the default (aten),
    custom, and fragment namespaces."""

    def setUp(self) -> None:
        self.selector = SelectiveBuilder.get_nop_selector()
        # from_yaml returns (NativeFunction, backend_index); only the
        # function itself is needed here.
        self.custom_native_function = NativeFunction.from_yaml(
            {"func": "custom::func() -> bool"},
            loc=Location(__file__, 1),
            valid_tags=set(),
        )[0]
        self.fragment_custom_native_function = NativeFunction.from_yaml(
            {"func": "quantized_decomposed::func() -> bool"},
            loc=Location(__file__, 1),
            valid_tags=set(),
        )[0]

    def test_default_namespace_schema_registration_code_valid(self) -> None:
        registrations, _ = get_native_function_schema_registrations(
            native_functions=[DEFAULT_NATIVE_FUNCTION],
            schema_selector=self.selector,
        )
        self.assertEqual(registrations, ['m.def("func() -> bool", {});\n'])

    def test_custom_namespace_schema_registration_code_valid(self) -> None:
        _, registrations = get_native_function_schema_registrations(
            native_functions=[self.custom_native_function],
            schema_selector=self.selector,
        )
        expected = """
TORCH_LIBRARY(custom, m) {
  m.def("func() -> bool", {});

};"""
        self.assertEqual(registrations, expected)

    def test_fragment_custom_namespace_schema_registration_code_valid(self) -> None:
        """Sometimes we want to extend an existing namespace, for example quantized
        namespace, which is already defined in native/quantized/library.cpp
        """
        _, registrations = get_native_function_schema_registrations(
            native_functions=[self.fragment_custom_native_function],
            schema_selector=self.selector,
        )
        expected = """
TORCH_LIBRARY_FRAGMENT(quantized_decomposed, m) {
  m.def("func() -> bool", {});

};"""
        self.assertEqual(registrations, expected)

    def test_mixed_namespace_schema_registration_code_valid(self) -> None:
        aten_registrations, custom_registrations = (
            get_native_function_schema_registrations(
                native_functions=[
                    DEFAULT_NATIVE_FUNCTION,
                    self.custom_native_function,
                ],
                schema_selector=self.selector,
            )
        )
        self.assertEqual(aten_registrations, ['m.def("func() -> bool", {});\n'])
        expected = """
TORCH_LIBRARY(custom, m) {
  m.def("func() -> bool", {});

};"""
        self.assertEqual(custom_registrations, expected)

    def test_3_namespaces_schema_registration_code_valid(self) -> None:
        custom2_native_function = NativeFunction.from_yaml(
            {"func": "custom2::func() -> bool"},
            loc=Location(__file__, 1),
            valid_tags=set(),
        )[0]
        aten_registrations, custom_registrations = (
            get_native_function_schema_registrations(
                native_functions=[
                    DEFAULT_NATIVE_FUNCTION,
                    self.custom_native_function,
                    custom2_native_function,
                ],
                schema_selector=self.selector,
            )
        )
        self.assertEqual(aten_registrations, ['m.def("func() -> bool", {});\n'])
        expected = """
TORCH_LIBRARY(custom, m) {
  m.def("func() -> bool", {});

};
TORCH_LIBRARY(custom2, m) {
  m.def("func() -> bool", {});

};"""
        self.assertEqual(custom_registrations, expected)
317
318
class TestGenNativeFunctionDeclaration(unittest.TestCase):
    """Tests for get_native_function_declarations: emitting TORCH_API
    kernel declarations and validating namespace consistency."""

    def setUp(self) -> None:
        self.op_1_native_function, op_1_backend_index = NativeFunction.from_yaml(
            {"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}},
            loc=Location(__file__, 1),
            valid_tags=set(),
        )
        # op_2 registers kernels in two namespaces ("native" implicit for
        # kernel_2, explicit "custom" for kernel_3).
        self.op_2_native_function, op_2_backend_index = NativeFunction.from_yaml(
            {
                "func": "op_2() -> bool",
                "dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"},
            },
            loc=Location(__file__, 1),
            valid_tags=set(),
        )

        raw_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = {
            DispatchKey.CPU: {},
            DispatchKey.QuantizedCPU: {},
        }
        BackendIndex.grow_index(raw_indices, op_1_backend_index)
        BackendIndex.grow_index(raw_indices, op_2_backend_index)
        self.backend_indices = {
            key: BackendIndex(
                dispatch_key=key,
                use_out_as_primary=True,
                external=False,
                device_guard=False,
                index=entry,
            )
            for key, entry in raw_indices.items()
        }

    def test_native_function_declaration_1_op_2_ns_error(self) -> None:
        # Kernels spanning two namespaces in one declaration request must
        # be rejected.
        with self.assertRaises(AssertionError):
            get_native_function_declarations(
                grouped_native_functions=[
                    self.op_1_native_function,
                    self.op_2_native_function,
                ],
                backend_indices=self.backend_indices,
                native_function_decl_gen=dest.compute_native_function_declaration,
            )

    def test_native_function_declaration_1_op_1_ns_valid(self) -> None:
        self.assertIsInstance(self.op_1_native_function, NativeFunction)
        declaration = get_native_function_declarations(
            grouped_native_functions=[self.op_1_native_function],
            backend_indices=self.backend_indices,
            native_function_decl_gen=dest.compute_native_function_declaration,
        )
        expected = """
namespace at {
namespace native {
TORCH_API bool kernel_1();
} // namespace native
} // namespace at
        """
        self.assertEqual("\n".join(declaration), expected)
380
381
# Test for native_function_generation
# NOTE(review): class name has a typo ("Generatrion" -> "Generation");
# kept as-is so external test selection by name is not broken — confirm
# before renaming.
class TestNativeFunctionGeneratrion(unittest.TestCase):
    """Tests for add_generated_native_functions: auto-generating out=
    variants from functional ops declared with `autogen`."""

    def setUp(self) -> None:
        self.native_functions: list[NativeFunction] = []
        self.backend_indices: dict[
            DispatchKey, dict[OperatorName, BackendMetadata]
        ] = defaultdict(dict)
        yaml_entry = """
- func: op(Tensor self) -> Tensor
  dispatch:
    CompositeExplicitAutograd: op
  autogen: op.out
        """
        entries = yaml.load(yaml_entry, Loader=LineLoader)
        self.one_return_func, metadata = NativeFunction.from_yaml(
            entries[0], loc=Location(__file__, 1), valid_tags=set()
        )
        BackendIndex.grow_index(self.backend_indices, metadata)

        self.two_returns_func, two_returns_metadata = NativeFunction.from_yaml(
            {
                "func": "op_2() -> (Tensor, Tensor)",
                "dispatch": {"CPU": "kernel_1"},
                "autogen": "op_2.out",
            },
            loc=Location(__file__, 1),
            valid_tags=set(),
        )
        BackendIndex.grow_index(self.backend_indices, two_returns_metadata)

    def test_functional_variant_autogen_out_variant(self) -> None:
        funcs = [self.one_return_func]
        add_generated_native_functions(funcs, self.backend_indices)
        # The out= variant is appended after the functional one.
        self.assertEqual(len(funcs), 2)
        out_variant = funcs[1]
        self.assertEqual(
            str(out_variant.func),
            "op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)",
        )
        backend_metadata = self.backend_indices[
            DispatchKey.CompositeExplicitAutograd
        ][out_variant.func.name]
        self.assertEqual(backend_metadata.kernel, "op_out")

    def test_functional_variant_autogen_out_variant_two_returns(self) -> None:
        funcs = [self.two_returns_func]
        add_generated_native_functions(funcs, self.backend_indices)
        self.assertEqual(len(funcs), 2)
        out_variant = funcs[1]
        # Multiple returns get numbered out arguments (out0, out1).
        self.assertEqual(
            str(out_variant.func),
            "op_2.out(*, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))",
        )
        backend_metadata = self.backend_indices[
            DispatchKey.CompositeExplicitAutograd
        ][out_variant.func.name]
        self.assertEqual(backend_metadata.kernel, "op_2_out")
440
441
# Test for static_dispatch
# NOTE(review): class name has a typo ("Generatrion" -> "Generation");
# kept as-is so external test selection by name is not broken — confirm
# before renaming.
class TestStaticDispatchGeneratrion(unittest.TestCase):
    """Tests for torchgen.gen.static_dispatch with a single
    CompositeExplicitAutograd backend, via both the dispatcher and the
    C++ signature."""

    def setUp(self) -> None:
        self.backend_indices: dict[
            DispatchKey, dict[OperatorName, BackendMetadata]
        ] = defaultdict(dict)
        yaml_entry = """
- func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: op
        """
        es = yaml.load(yaml_entry, Loader=LineLoader)
        self.one_return_func, m = NativeFunction.from_yaml(
            es[0], loc=Location(__file__, 1), valid_tags=set()
        )

        BackendIndex.grow_index(self.backend_indices, m)
        dispatch_key = DispatchKey.CompositeExplicitAutograd
        # assertIn (not assertTrue(x in y)) reports the actual container
        # contents on failure.
        self.assertIn(dispatch_key, self.backend_indices)
        self.indices = [
            BackendIndex(
                dispatch_key=dispatch_key,
                use_out_as_primary=True,
                external=False,
                device_guard=False,
                index=self.backend_indices[dispatch_key],
            )
        ]

    def test_op_with_1_backend_generates_static_dispatch(self) -> None:
        disp_sig = DispatcherSignature.from_schema(self.one_return_func.func)
        with native_function_manager(self.one_return_func):
            out = static_dispatch(
                sig=disp_sig,
                f=self.one_return_func,
                backend_indices=self.indices,
            )
        self.assertEqual(
            out, "return at::compositeexplicitautograd::op_out(out, self);"
        )

    def test_op_with_cpp_sig_generates_static_dispatch(self) -> None:
        sig_group = CppSignatureGroup.from_native_function(
            self.one_return_func,
            method=False,
            fallback_binding=self.one_return_func.manual_cpp_binding,
        )
        # cpp signature puts out at the front
        with native_function_manager(self.one_return_func):
            out = static_dispatch(
                sig=sig_group.signature,
                f=self.one_return_func,
                backend_indices=self.indices,
            )
        self.assertEqual(
            out, "return at::compositeexplicitautograd::op_out(out, self);"
        )
499
500
# The simplest possible NativeFunction, shared as a template by the tests
# above. Derive variants with dataclasses.replace() rather than mutating it.
DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
    {"func": "func() -> bool"},
    valid_tags=set(),
    loc=Location(__file__, 1),
)
508
509
# Support running this file directly (e.g. `python test_codegen.py`).
if __name__ == "__main__":
    unittest.main()
512