1import torch 2import torch.nn.functional as F 3from torch.testing._internal.common_nn import wrap_functional 4 5 6""" 7`sample_functional` is used by `test_cpp_api_parity.py` to test that Python / C++ API 8parity test harness works for `torch.nn.functional` functions. 9 10When `has_parity=true` is passed to `sample_functional`, behavior of `sample_functional` 11is the same as the C++ equivalent. 12 13When `has_parity=false` is passed to `sample_functional`, behavior of `sample_functional` 14is different from the C++ equivalent. 15""" 16 17 18def sample_functional(x, has_parity): 19 if has_parity: 20 return x * 2 21 else: 22 return x * 4 23 24 25torch.nn.functional.sample_functional = sample_functional 26 27SAMPLE_FUNCTIONAL_CPP_SOURCE = """\n 28namespace torch { 29namespace nn { 30namespace functional { 31 32struct C10_EXPORT SampleFunctionalFuncOptions { 33 SampleFunctionalFuncOptions(bool has_parity) : has_parity_(has_parity) {} 34 35 TORCH_ARG(bool, has_parity); 36}; 37 38Tensor sample_functional(Tensor x, SampleFunctionalFuncOptions options) { 39 return x * 2; 40} 41 42} // namespace functional 43} // namespace nn 44} // namespace torch 45""" 46 47functional_tests = [ 48 dict( 49 constructor=wrap_functional(F.sample_functional, has_parity=True), 50 cpp_options_args="F::SampleFunctionalFuncOptions(true)", 51 input_size=(1, 2, 3), 52 fullname="sample_functional_has_parity", 53 has_parity=True, 54 ), 55 dict( 56 constructor=wrap_functional(F.sample_functional, has_parity=False), 57 cpp_options_args="F::SampleFunctionalFuncOptions(false)", 58 input_size=(1, 2, 3), 59 fullname="sample_functional_no_parity", 60 has_parity=False, 61 ), 62 # This is to test that setting the `test_cpp_api_parity=False` flag skips 63 # the C++ API parity test accordingly (otherwise this test would run and 64 # throw a parity error). 65 dict( 66 constructor=wrap_functional(F.sample_functional, has_parity=False), 67 cpp_options_args="F::SampleFunctionalFuncOptions(false)", 68 input_size=(1, 2, 3), 69 fullname="sample_functional_THIS_TEST_SHOULD_BE_SKIPPED", 70 test_cpp_api_parity=False, 71 ), 72] 73