xref: /aosp_15_r20/external/pytorch/test/cpp_api_parity/functional_impl_check.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# The purpose of this test is to check that we have implementation parity between
# a Python `torch.nn.functional` function and its corresponding C++ `torch::nn::functional`
# function. Concretely, this test does the following:
#
# 1. Get a test params dict from common_nn.py, and run the forward pass on the Python
# functional created using the test params.
#
# 2. Serialize the Python functional's forward input arguments, deserialize them
# in C++ and use them as input for the C++ functional's forward pass.
#
# 3. Run the forward pass on the C++ functional, and serialize the C++ functional's
# forward output.
#
# 4. Compare the Python/C++ functionals' forward outputs. If they are the same, then we
# have implementation parity between the Python/C++ functional.
16
17import os
18import pprint
19import re
20import tempfile
21from string import Template
22
23import torch
24from cpp_api_parity.sample_functional import SAMPLE_FUNCTIONAL_CPP_SOURCE
25from cpp_api_parity.utils import (
26    add_test,
27    compile_cpp_code_inline,
28    compute_arg_dict,
29    compute_cpp_args_construction_stmts_and_forward_arg_symbols,
30    compute_temp_file_path,
31    decorate_test_fn,
32    generate_error_msg,
33    is_torch_nn_functional_test,
34    move_python_tensors_to_device,
35    serialize_arg_dict_as_script_module,
36    set_python_tensors_requires_grad,
37    TORCH_NN_COMMON_TEST_HARNESS,
38    TorchNNFunctionalTestParams,
39    try_remove_folder,
40)
41
42
# C++ source template for a single functional parity test. `build_cpp_tests`
# renders it once per registered test (via `generate_test_cpp_sources`) and
# compiles all rendered sources together into one inline extension.
#
# The generated `<variant>_test_forward(arg_dict_file_path, forward_output_file_path)`
# loads the arguments that Python serialized, seeds the RNG with 0 (matching
# `run_forward` on the Python side), calls the functional and writes its output
# to `forward_output_file_path` so Python can compare it.
# NOTE(review): `load_dict_from_file` / `write_ivalue_to_file` are presumably
# provided by `TORCH_NN_COMMON_TEST_HARNESS`, which `build_cpp_tests` prepends.
#
# Expected substitutions:
#
# ${functional_variant_name}  (e.g. `BCELoss_no_reduce`)
# ${cpp_args_construction_stmts}  (C++ statements joined with ";\n  ")
# ${cpp_function_call}  (e.g. `F::binary_cross_entropy(<args...>, <options>)`)
TORCH_NN_FUNCTIONAL_TEST_FORWARD = Template(
    """
void ${functional_variant_name}_test_forward(
    const std::string& arg_dict_file_path,
    const std::string& forward_output_file_path) {
  pybind11::gil_scoped_release no_gil;

  namespace F = torch::nn::functional;

  // Declare arguments
  auto arg_dict = load_dict_from_file(arg_dict_file_path);
  ${cpp_args_construction_stmts};

  // Some functionals (such as `F::rrelu`) create random tensors in their call path.
  // To make sure the random tensors created are the same in Python/C++, we need
  // to set the RNG seed manually.
  torch::manual_seed(0);

  // Run function with arguments
  auto cpp_output = ${cpp_function_call};

  // Save the output into a file to be compared in Python later
  write_ivalue_to_file(torch::IValue(cpp_output), forward_output_file_path);
}
"""
)
74
75
def run_forward(unit_test_class, test_params):
    """Run the Python functional's forward pass and return its output.

    Positional forward arguments are taken from ``test_params.arg_dict`` in the
    order ``input``, ``target``, ``extra_args`` and moved to
    ``test_params.device``. Only the ``input`` tensors get ``requires_grad``
    set. ``unit_test_class`` is not used by this function.
    """
    target_device = test_params.device
    arg_dict = test_params.arg_dict

    def section_values(section):
        # Each section holds (name, value) pairs; only the values are passed on.
        return [value for _name, value in arg_dict[section]]

    forward_args = set_python_tensors_requires_grad(
        move_python_tensors_to_device(section_values("input"), target_device)
    )
    for section in ("target", "extra_args"):
        forward_args += move_python_tensors_to_device(
            section_values(section), target_device
        )

    # Functionals such as `F.rrelu` draw random tensors; seeding with 0 here
    # mirrors the `torch::manual_seed(0)` done on the C++ side.
    torch.manual_seed(0)
    functional = test_params.test_instance.constructor()
    return functional(*forward_args)
98
99
def test_forward(unit_test_class, test_params):
    """Assert that the C++ functional's forward output equals Python's.

    Runs the Python functional and serializes its arguments into
    ``test_params.cpp_tmp_folder``. It then calls the compiled C++
    ``<functional_variant_name>_test_forward`` function on those arguments and
    loads the output the C++ side wrote. Finally it compares both outputs with
    ``unit_test_class.assertEqual``.

    The temporary folder is removed in a ``finally`` clause. It is therefore
    cleaned up even when the C++ call or the comparison raises, so failing runs
    no longer leave directories behind.
    """
    functional_variant_name = test_params.functional_variant_name
    cpp_tmp_folder = test_params.cpp_tmp_folder
    # Remove the temporary folder if it exists already
    try_remove_folder(cpp_tmp_folder)
    os.mkdir(cpp_tmp_folder)

    try:
        # Run forward on Python functional
        python_output = run_forward(unit_test_class, test_params)

        # Save Python arguments to be used from C++ function
        arg_dict_file_path = compute_temp_file_path(
            cpp_tmp_folder, functional_variant_name, "arg_dict"
        )
        serialize_arg_dict_as_script_module(test_params.arg_dict).save(
            arg_dict_file_path
        )

        # Function name matches the one generated from TORCH_NN_FUNCTIONAL_TEST_FORWARD.
        cpp_test_fn = getattr(
            unit_test_class.functional_impl_check_cpp_module,
            f"{functional_variant_name}_test_forward",
        )

        forward_output_file_path = compute_temp_file_path(
            cpp_tmp_folder, functional_variant_name, "forward_output"
        )
        cpp_test_fn(arg_dict_file_path, forward_output_file_path)
        cpp_output = torch.load(forward_output_file_path)

        # Check that forward outputs are equal
        unit_test_class.assertEqual(
            python_output,
            cpp_output,
            msg=generate_error_msg("forward output", cpp_output, python_output),
        )
    finally:
        # Remove temporary folder that stores C++ outputs, also on failure.
        try_remove_folder(cpp_tmp_folder)
140
141
def compute_functional_name(test_params_dict):
    """Derive the snake_case functional name from a test params dict.

    ``cpp_options_args`` (format ``F::FunctionalFuncOptions(...)``) is checked
    first. Its CamelCase options class name, without ``F::`` and
    ``FuncOptions``, is converted to snake_case. Otherwise ``cpp_function_call``
    (format ``F::functional_name(...)``) is used with the ``F::`` qualifier
    removed. Both forms yield e.g. ``binary_cross_entropy``.

    Raises:
        RuntimeError: if neither entry is present.
    """

    def to_snake_case(name):
        # Put "_" before each uppercase letter except a leading one, then lowercase.
        return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()

    if "cpp_options_args" in test_params_dict:
        options_class = test_params_dict["cpp_options_args"].split("(")[0]
        bare_name = options_class.replace("F::", "").replace("FuncOptions", "")
        return to_snake_case(bare_name)

    if "cpp_function_call" in test_params_dict:
        callee = test_params_dict["cpp_function_call"].split("(")[0]
        return callee.replace("F::", "")

    raise RuntimeError(
        f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}"  # noqa: B950
    )
163
164
def compute_cpp_function_call(test_params_dict, arg_dict, functional_name):
    """Return the C++ expression that invokes the functional under test.

    An explicit ``cpp_function_call`` entry is returned unchanged. Otherwise the
    call is assembled as ``F::<functional_name>(<args...>, <cpp_options_args>)``.
    Here ``<args...>`` are the argument names from ``arg_dict``'s ``input``,
    ``target`` and ``extra_args`` sections, in that order. When there are no
    forward arguments, the options are the sole argument.

    Raises:
        RuntimeError: if neither ``cpp_function_call`` nor ``cpp_options_args``
            is present.
    """
    if "cpp_function_call" in test_params_dict:
        return test_params_dict["cpp_function_call"]
    if "cpp_options_args" in test_params_dict:
        call_args = [
            arg_name
            for section in ("input", "target", "extra_args")
            for arg_name, _ in arg_dict[section]
        ]
        # Options object always goes last. Joining a single list avoids
        # producing `F::f(, options)` (invalid C++) when there are no
        # forward arguments.
        call_args.append(test_params_dict["cpp_options_args"])
        return f"F::{functional_name}({', '.join(call_args)})"
    raise RuntimeError(
        f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}"  # noqa: B950
    )
184
185
def process_test_params_for_functional(test_params_dict, device, test_instance_class):
    """Build the ``TorchNNFunctionalTestParams`` for one functional test on one device.

    ``test_instance_class`` is instantiated with the raw params dict. Its
    ``get_name()`` must start with ``test_``. With that prefix removed, and a
    ``_<device>`` suffix added for non-CPU devices, it becomes the variant name
    (e.g. ``BCELoss_no_reduce_cuda``). A fresh ``tempfile.mkdtemp()`` directory
    is reserved as the C++ scratch folder.
    """
    test_name_prefix = "test_"
    test_instance = test_instance_class(**test_params_dict)
    functional_name = compute_functional_name(test_params_dict)
    instance_name = test_instance.get_name()
    assert instance_name.startswith(test_name_prefix)

    device_suffix = "" if device == "cpu" else "_" + device
    functional_variant_name = instance_name[len(test_name_prefix):] + device_suffix

    arg_dict = compute_arg_dict(test_params_dict, test_instance)
    cpp_function_call = compute_cpp_function_call(
        test_params_dict, arg_dict, functional_name
    )

    return TorchNNFunctionalTestParams(
        functional_name=functional_name,
        functional_variant_name=functional_variant_name,
        test_instance=test_instance,
        cpp_function_call=cpp_function_call,
        arg_dict=arg_dict,
        has_parity=test_params_dict.get("has_parity", True),
        device=device,
        cpp_tmp_folder=tempfile.mkdtemp(),
    )
207    )
208
209
210def write_test_to_test_class(
211    unit_test_class, test_params_dict, test_instance_class, parity_table, devices
212):
213    assert is_torch_nn_functional_test(test_params_dict)
214
215    assert (
216        "cpp_options_args" in test_params_dict
217        or "cpp_function_call" in test_params_dict
218    ), (
219        "To enable C++ API parity test, "
220        f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}. \n"  # noqa: B950
221        "If you are interested in adding the C++ API parity test, please see:\n"
222        "NOTE [How to check NN module / functional API parity between Python and C++ frontends]. \n"
223        "If not, please add `test_cpp_api_parity=False` to the test params dict and file an issue about this."
224    )
225
226    assert not (
227        "cpp_options_args" in test_params_dict
228        and "cpp_function_call" in test_params_dict
229    ), (
230        "Only one of `cpp_options_args` and `cpp_function_call` entries "
231        f"should be present in test params dict:\n{pprint.pformat(test_params_dict)}"
232    )
233
234    functional_name = compute_functional_name(test_params_dict)
235
236    assert hasattr(
237        torch.nn.functional, functional_name
238    ), f"`torch.nn.functional` doesn't have function `{functional_name}`. (Discovered while processing\n{pprint.pformat(test_params_dict)}.)"  # noqa: B950
239
240    functional_full_name = "F::" + functional_name
241
242    assert functional_full_name in parity_table["torch::nn::functional"], (
243        f"Please add `{functional_full_name}` entry to `torch::nn::functional` section of `test/cpp_api_parity/parity-tracker.md`. "
244        f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)"
245    )
246
247    for device in devices:
248        test_params = process_test_params_for_functional(
249            test_params_dict=test_params_dict,
250            device=device,
251            test_instance_class=test_instance_class,
252        )
253        try_remove_folder(test_params.cpp_tmp_folder)
254        unit_test_name = (
255            f"test_torch_nn_functional_{test_params.functional_variant_name}"
256        )
257        unit_test_class.functional_test_params_map[unit_test_name] = test_params
258
259        def test_fn(self):
260            test_forward(
261                unit_test_class=self,
262                test_params=unit_test_class.functional_test_params_map[
263                    self._testMethodName
264                ],
265            )
266
267        test_fn = decorate_test_fn(
268            test_fn=test_fn,
269            test_cuda=test_params_dict.get("test_cuda", True),
270            has_impl_parity=parity_table["torch::nn::functional"][functional_full_name][
271                0
272            ]
273            and test_params_dict.get("has_parity", True),
274            device=device,
275        )
276
277        add_test(unit_test_class, unit_test_name, test_fn)
278
279
280def generate_test_cpp_sources(test_params, template):
281    (
282        cpp_args_construction_stmts,
283        _,
284    ) = compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params)
285
286    test_cpp_sources = template.substitute(
287        functional_variant_name=test_params.functional_variant_name,
288        cpp_args_construction_stmts=";\n  ".join(cpp_args_construction_stmts),
289        cpp_function_call=test_params.cpp_function_call,
290    )
291    return test_cpp_sources
292
293
# Build all C++ tests together, instead of once per test.
def build_cpp_tests(unit_test_class, print_cpp_source=False):
    """Compile every registered functional test into one inline C++ extension.

    The source is the common harness, then the sample functional source, then
    one rendered ``TORCH_NN_FUNCTIONAL_TEST_FORWARD`` per entry in
    ``unit_test_class.functional_test_params_map``. The compiled module is
    stored on ``unit_test_class.functional_impl_check_cpp_module``. Set
    ``print_cpp_source`` to dump the generated source to stdout.
    """
    registered_params = unit_test_class.functional_test_params_map
    assert len(registered_params) > 0

    source_chunks = [TORCH_NN_COMMON_TEST_HARNESS + SAMPLE_FUNCTIONAL_CPP_SOURCE]
    function_names = []
    for params in registered_params.values():
        source_chunks.append(
            generate_test_cpp_sources(
                test_params=params, template=TORCH_NN_FUNCTIONAL_TEST_FORWARD
            )
        )
        # Must match the function name emitted by the template.
        function_names.append(f"{params.functional_variant_name}_test_forward")
    cpp_sources = "".join(source_chunks)

    if print_cpp_source:
        print(cpp_sources)

    unit_test_class.functional_impl_check_cpp_module = compile_cpp_code_inline(
        name="functional_impl_check",
        cpp_sources=cpp_sources,
        functions=function_names,
    )
310    unit_test_class.functional_impl_check_cpp_module = cpp_module
311