xref: /aosp_15_r20/external/pytorch/test/cpp_api_parity/sample_functional.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2import torch.nn.functional as F
3from torch.testing._internal.common_nn import wrap_functional
4
5
6"""
`sample_functional` is used by `test_cpp_api_parity.py` to test that the Python / C++ API
parity test harness works for `torch.nn.functional` functions.

When `has_parity=True` is passed to `sample_functional`, the behavior of `sample_functional`
is the same as that of its C++ equivalent.

When `has_parity=False` is passed to `sample_functional`, the behavior of `sample_functional`
differs from that of its C++ equivalent.
15"""
16
17
def sample_functional(x, has_parity):
    """Scale ``x`` by a factor chosen from ``has_parity``.

    The C++ counterpart in ``SAMPLE_FUNCTIONAL_CPP_SOURCE`` always returns
    ``x * 2``. Doubling when ``has_parity`` is truthy therefore agrees with
    C++. Quadrupling otherwise deliberately breaks parity, so the harness can
    be shown to notice a mismatch.

    Args:
        x: Value to scale; anything supporting multiplication by an ``int``
            (in the parity tests this is presumably a Tensor).
        has_parity: When truthy, mirror the C++ behavior; otherwise diverge.

    Returns:
        ``x * 2`` if ``has_parity`` is truthy, else ``x * 4``.
    """
    factor = 2 if has_parity else 4
    return x * factor


# Publish the function as `torch.nn.functional.sample_functional` so the test
# entries below can refer to it through the `F` alias.
torch.nn.functional.sample_functional = sample_functional
26
# C++ counterpart of `sample_functional`, presumably compiled by the parity
# harness in `test_cpp_api_parity.py` (TODO confirm how it is consumed).
#
# The C++ function ignores `options` and always returns `x * 2`. It therefore
# matches the Python implementation only for `has_parity=True`; the
# `has_parity=False` Python path returns `x * 4` and intentionally diverges.
#
# NOTE(review): the opening `"""\n` is a `\n` escape followed by a real line
# break, so the string starts with two newline characters. This looks
# harmless for C++ compilation but was likely meant to be `"""\`.
SAMPLE_FUNCTIONAL_CPP_SOURCE = """\n
namespace torch {
namespace nn {
namespace functional {

struct C10_EXPORT SampleFunctionalFuncOptions {
  SampleFunctionalFuncOptions(bool has_parity) : has_parity_(has_parity) {}

  TORCH_ARG(bool, has_parity);
};

Tensor sample_functional(Tensor x, SampleFunctionalFuncOptions options) {
    return x * 2;
}

} // namespace functional
} // namespace nn
} // namespace torch
"""
46
# Parity-harness entries for `sample_functional`. Each entry pairs:
# - the Python callable (built with `wrap_functional`);
# - the C++ options expression (`cpp_options_args`) for the C++ counterpart;
# - an input of size (1, 2, 3).
functional_tests = [
    # Python and C++ agree here: both compute `x * 2`.
    {
        "constructor": wrap_functional(F.sample_functional, has_parity=True),
        "cpp_options_args": "F::SampleFunctionalFuncOptions(true)",
        "input_size": (1, 2, 3),
        "fullname": "sample_functional_has_parity",
        "has_parity": True,
    },
    # Python computes `x * 4` while C++ computes `x * 2`. The entry is marked
    # `has_parity=False` to reflect that the two implementations differ.
    {
        "constructor": wrap_functional(F.sample_functional, has_parity=False),
        "cpp_options_args": "F::SampleFunctionalFuncOptions(false)",
        "input_size": (1, 2, 3),
        "fullname": "sample_functional_no_parity",
        "has_parity": False,
    },
    # This is to test that setting the `test_cpp_api_parity=False` flag skips
    # the C++ API parity test accordingly (otherwise this test would run and
    # throw a parity error).
    {
        "constructor": wrap_functional(F.sample_functional, has_parity=False),
        "cpp_options_args": "F::SampleFunctionalFuncOptions(false)",
        "input_size": (1, 2, 3),
        "fullname": "sample_functional_THIS_TEST_SHOULD_BE_SKIPPED",
        "test_cpp_api_parity": False,
    },
]
73