# The purpose of this test is to check that we have implementation parity between
# a Python `torch.nn.functional` function and its corresponding C++ `torch::nn::functional`
# function. Concretely, this test does the following:
#
# 1. Get a test params dict from common_nn.py, run forward pass on the Python functional
# created using the test params.
#
# 2. Serialize the Python functional's forward input arguments, deserialize them
# in C++ and use them as input for the C++ functional's forward pass.
#
# 3. Run the forward pass on the C++ functional, and serialize the C++ functional's
# forward output.
#
# 4. Compare Python/C++ functional's forward output. If they are the same, then we
# have implementation parity between the Python/C++ functional.

import os
import pprint
import re
import tempfile
from string import Template

import torch
from cpp_api_parity.sample_functional import SAMPLE_FUNCTIONAL_CPP_SOURCE
from cpp_api_parity.utils import (
    add_test,
    compile_cpp_code_inline,
    compute_arg_dict,
    compute_cpp_args_construction_stmts_and_forward_arg_symbols,
    compute_temp_file_path,
    decorate_test_fn,
    generate_error_msg,
    is_torch_nn_functional_test,
    move_python_tensors_to_device,
    serialize_arg_dict_as_script_module,
    set_python_tensors_requires_grad,
    TORCH_NN_COMMON_TEST_HARNESS,
    TorchNNFunctionalTestParams,
    try_remove_folder,
)


# Expected substitutions:
#
# ${functional_variant_name}  (e.g. `BCELoss_no_reduce`)
# ${cpp_args_construction_stmts}
# ${cpp_function_call}
TORCH_NN_FUNCTIONAL_TEST_FORWARD = Template(
    """
void ${functional_variant_name}_test_forward(
    const std::string& arg_dict_file_path,
    const std::string& forward_output_file_path) {
  pybind11::gil_scoped_release no_gil;

  namespace F = torch::nn::functional;

  // Declare arguments
  auto arg_dict = load_dict_from_file(arg_dict_file_path);
  ${cpp_args_construction_stmts};

  // Some functionals (such as `F::rrelu`) create random tensors in their call path.
  // To make sure the random tensors created are the same in Python/C++, we need
  // to set the RNG seed manually.
  torch::manual_seed(0);

  // Run function with arguments
  auto cpp_output = ${cpp_function_call};

  // Save the output into a file to be compared in Python later
  write_ivalue_to_file(torch::IValue(cpp_output), forward_output_file_path);
}
"""
)


def _missing_cpp_entry_error_msg(test_params_dict):
    """Return the error message used when a test params dict has neither
    a `cpp_options_args` nor a `cpp_function_call` entry."""
    return (
        "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n"
        f"{pprint.pformat(test_params_dict)}"
    )


def run_forward(unit_test_class, test_params):
    """Run the Python functional's forward pass and return its output.

    The positional inputs are built in this order: the `input` tensors
    (moved to `test_params.device` and passed through
    `set_python_tensors_requires_grad`), then the `target` tensors, then the
    `extra_args` values (both moved to the device only).

    `unit_test_class` is not used here; it is kept for interface compatibility.
    """
    device = test_params.device

    inputs = set_python_tensors_requires_grad(
        move_python_tensors_to_device(
            [arg_value for _, arg_value in test_params.arg_dict["input"]], device
        )
    )
    inputs += move_python_tensors_to_device(
        [arg_value for _, arg_value in test_params.arg_dict["target"]], device
    )
    inputs += move_python_tensors_to_device(
        [arg_value for _, arg_value in test_params.arg_dict["extra_args"]], device
    )

    # Some functionals (such as `F.rrelu`) create random tensors in their call path.
    # To make sure the random tensors created are the same in Python/C++, we need
    # to set the RNG seed manually.
    torch.manual_seed(0)
    python_output = test_params.test_instance.constructor()(*inputs)

    return python_output


def test_forward(unit_test_class, test_params):
    """Compare the forward output of the Python functional against the C++ one.

    The steps are:
    1. Serialize the Python arguments into `test_params.cpp_tmp_folder`.
    2. Invoke the compiled C++ `<variant>_test_forward` function, which writes
       its output to a file.
    3. Load that output and assert it equals the Python output.

    The temporary folder is always removed afterwards, including when any
    step or the final assertion fails.
    """
    functional_variant_name = test_params.functional_variant_name
    cpp_tmp_folder = test_params.cpp_tmp_folder
    # Remove the temporary folder if it exists already
    try_remove_folder(cpp_tmp_folder)
    os.mkdir(cpp_tmp_folder)

    try:
        # Run forward on Python functional
        python_output = run_forward(unit_test_class, test_params)

        # Save Python arguments to be used from C++ function
        arg_dict_file_path = compute_temp_file_path(
            cpp_tmp_folder, functional_variant_name, "arg_dict"
        )
        serialize_arg_dict_as_script_module(test_params.arg_dict).save(
            arg_dict_file_path
        )

        cpp_test_name = f"{functional_variant_name}_test_forward"
        cpp_test_fn = getattr(
            unit_test_class.functional_impl_check_cpp_module, cpp_test_name
        )

        forward_output_file_path = compute_temp_file_path(
            cpp_tmp_folder, functional_variant_name, "forward_output"
        )

        cpp_test_fn(arg_dict_file_path, forward_output_file_path)
        cpp_output = torch.load(forward_output_file_path)

        # Check that forward outputs are equal
        unit_test_class.assertEqual(
            python_output,
            cpp_output,
            msg=generate_error_msg("forward output", cpp_output, python_output),
        )
    finally:
        # Remove the temporary folder that stores C++ outputs, even on failure,
        # so failing tests do not leak directories.
        try_remove_folder(cpp_tmp_folder)


def compute_functional_name(test_params_dict):
    """Derive the snake_case functional name from a test params dict.

    The name comes from `cpp_options_args` when present (for example
    `F::BinaryCrossEntropyFuncOptions(...)` gives `binary_cross_entropy`).
    Otherwise it comes from `cpp_function_call` (for example
    `F::binary_cross_entropy(...)` gives `binary_cross_entropy`).

    Raises:
        RuntimeError: if neither entry is present.
    """

    def camel_case_to_snake_case(camel_case_str):
        # Insert "_" before every uppercase letter except the first, then lowercase.
        return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_case_str).lower()

    if "cpp_options_args" in test_params_dict:
        # Expected format for `cpp_options_args`: `F::FunctionalFuncOptions(...)`
        # Example output: `binary_cross_entropy`
        return camel_case_to_snake_case(
            test_params_dict["cpp_options_args"]
            .split("(")[0]
            .replace("F::", "")
            .replace("FuncOptions", "")
        )
    elif "cpp_function_call" in test_params_dict:
        # Expected format for `cpp_function_call`: `F::functional_name(...)`
        # Example output: `binary_cross_entropy`
        return test_params_dict["cpp_function_call"].split("(")[0].replace("F::", "")
    else:
        raise RuntimeError(_missing_cpp_entry_error_msg(test_params_dict))


def compute_cpp_function_call(test_params_dict, arg_dict, functional_name):
    """Return the C++ expression that invokes the functional.

    If `cpp_function_call` is given, it is used verbatim. Otherwise the call is
    `F::<functional_name>(<input, target, extra_args symbols>, <cpp_options_args>)`.

    Raises:
        RuntimeError: if neither `cpp_function_call` nor `cpp_options_args`
            is present.
    """
    if "cpp_function_call" in test_params_dict:
        return test_params_dict["cpp_function_call"]
    elif "cpp_options_args" in test_params_dict:
        cpp_forward_args_symbols = [
            arg_name
            for arg_name, _ in arg_dict["input"]
            + arg_dict["target"]
            + arg_dict["extra_args"]
        ]
        return "F::{}({}, {})".format(
            functional_name,
            ", ".join(cpp_forward_args_symbols),
            test_params_dict["cpp_options_args"],
        )
    else:
        raise RuntimeError(_missing_cpp_entry_error_msg(test_params_dict))


def process_test_params_for_functional(test_params_dict, device, test_instance_class):
    """Build a `TorchNNFunctionalTestParams` for one (test params dict, device) pair.

    The variant name is the test instance name without its `test_` prefix.
    A `_<device>` suffix is appended for non-CPU devices, for example
    `BCELoss_no_reduce_cuda`. A fresh temporary directory is created via
    `tempfile.mkdtemp()` to hold the files exchanged with C++.
    """
    test_instance = test_instance_class(**test_params_dict)
    functional_name = compute_functional_name(test_params_dict)
    assert test_instance.get_name().startswith("test_")
    # Example output: `BCELoss_no_reduce_cuda`
    functional_variant_name = test_instance.get_name()[len("test_"):] + (
        ("_" + device) if device != "cpu" else ""
    )
    arg_dict = compute_arg_dict(test_params_dict, test_instance)

    return TorchNNFunctionalTestParams(
        functional_name=functional_name,
        functional_variant_name=functional_variant_name,
        test_instance=test_instance,
        cpp_function_call=compute_cpp_function_call(
            test_params_dict, arg_dict, functional_name
        ),
        arg_dict=arg_dict,
        has_parity=test_params_dict.get("has_parity", True),
        device=device,
        cpp_tmp_folder=tempfile.mkdtemp(),
    )


def write_test_to_test_class(
    unit_test_class, test_params_dict, test_instance_class, parity_table, devices
):
    """Validate a functional test params dict and register one parity test per device.

    The dict is validated first: exactly one of `cpp_options_args` /
    `cpp_function_call` must be present, the functional must exist in
    `torch.nn.functional`, and it must be listed in the parity table.

    For each device, the processed params are stored in
    `unit_test_class.functional_test_params_map` under the test name, and a
    `test_torch_nn_functional_<variant>` method is added to `unit_test_class`.
    """
    assert is_torch_nn_functional_test(test_params_dict)

    assert (
        "cpp_options_args" in test_params_dict
        or "cpp_function_call" in test_params_dict
    ), (
        "To enable C++ API parity test, "
        f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}. \n"  # noqa: B950
        "If you are interested in adding the C++ API parity test, please see:\n"
        "NOTE [How to check NN module / functional API parity between Python and C++ frontends]. \n"
        "If not, please add `test_cpp_api_parity=False` to the test params dict and file an issue about this."
    )

    assert not (
        "cpp_options_args" in test_params_dict
        and "cpp_function_call" in test_params_dict
    ), (
        "Only one of `cpp_options_args` and `cpp_function_call` entries "
        f"should be present in test params dict:\n{pprint.pformat(test_params_dict)}"
    )

    functional_name = compute_functional_name(test_params_dict)

    assert hasattr(
        torch.nn.functional, functional_name
    ), f"`torch.nn.functional` doesn't have function `{functional_name}`. (Discovered while processing\n{pprint.pformat(test_params_dict)}.)"  # noqa: B950

    functional_full_name = "F::" + functional_name

    assert functional_full_name in parity_table["torch::nn::functional"], (
        f"Please add `{functional_full_name}` entry to `torch::nn::functional` section of `test/cpp_api_parity/parity-tracker.md`. "
        f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)"
    )

    for device in devices:
        test_params = process_test_params_for_functional(
            test_params_dict=test_params_dict,
            device=device,
            test_instance_class=test_instance_class,
        )
        # `test_forward` recreates this folder at test time, so the folder
        # made by `mkdtemp()` is not needed now.
        try_remove_folder(test_params.cpp_tmp_folder)
        unit_test_name = (
            f"test_torch_nn_functional_{test_params.functional_variant_name}"
        )
        unit_test_class.functional_test_params_map[unit_test_name] = test_params

        # The params are looked up by the running test's method name rather
        # than captured from the loop, so late binding of `test_params` is not
        # an issue.
        def test_fn(self):
            test_forward(
                unit_test_class=self,
                test_params=unit_test_class.functional_test_params_map[
                    self._testMethodName
                ],
            )

        test_fn = decorate_test_fn(
            test_fn=test_fn,
            test_cuda=test_params_dict.get("test_cuda", True),
            has_impl_parity=parity_table["torch::nn::functional"][functional_full_name][
                0
            ]
            and test_params_dict.get("has_parity", True),
            device=device,
        )

        add_test(unit_test_class, unit_test_name, test_fn)


def generate_test_cpp_sources(test_params, template):
    """Render the C++ forward-test function for `test_params` using `template`."""
    (
        cpp_args_construction_stmts,
        _,
    ) = compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params)

    test_cpp_sources = template.substitute(
        functional_variant_name=test_params.functional_variant_name,
        cpp_args_construction_stmts=";\n  ".join(cpp_args_construction_stmts),
        cpp_function_call=test_params.cpp_function_call,
    )
    return test_cpp_sources


# Build all C++ tests together, instead of once per test.
def build_cpp_tests(unit_test_class, print_cpp_source=False):
    """Compile one C++ extension holding the forward tests of every registered functional.

    The resulting module is stored as
    `unit_test_class.functional_impl_check_cpp_module`. It exposes one
    `<functional_variant_name>_test_forward` function per entry in
    `unit_test_class.functional_test_params_map`, which must not be empty.
    """
    assert len(unit_test_class.functional_test_params_map) > 0
    cpp_sources = TORCH_NN_COMMON_TEST_HARNESS + SAMPLE_FUNCTIONAL_CPP_SOURCE
    functions = []
    for test_params in unit_test_class.functional_test_params_map.values():
        cpp_sources += generate_test_cpp_sources(
            test_params=test_params, template=TORCH_NN_FUNCTIONAL_TEST_FORWARD
        )
        functions.append(f"{test_params.functional_variant_name}_test_forward")
    if print_cpp_source:
        print(cpp_sources)

    cpp_module = compile_cpp_code_inline(
        name="functional_impl_check", cpp_sources=cpp_sources, functions=functions
    )
    unit_test_class.functional_impl_check_cpp_module = cpp_module