# xref: /aosp_15_r20/external/pytorch/test/test_ops_gradients.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: unknown"]
2
3from functools import partial
4
5import torch
6from torch.testing._internal.common_device_type import (
7    instantiate_device_type_tests,
8    OpDTypes,
9    ops,
10)
11from torch.testing._internal.common_methods_invocations import op_db
12from torch.testing._internal.common_utils import (
13    run_tests,
14    TestCase,
15    TestGradients,
16    unMarkDynamoStrictTest,
17)
18from torch.testing._internal.custom_op_db import custom_op_db
19from torch.testing._internal.hop_db import hop_db
20
21
22# gradcheck requires double precision
23_gradcheck_ops = partial(
24    ops, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, torch.cdouble]
25)
26
27
@unMarkDynamoStrictTest
class TestBwdGradients(TestGradients):
    """Backward-gradient correctness tests driven by gradcheck.

    Each ``test_*`` method below is expanded per device/dtype/op by
    ``instantiate_device_type_tests``; the ``_gradcheck_ops`` decorator
    limits dtypes to double precision, which gradcheck requires.
    """

    # Verifies first-order gradients of the functional variant of each op.
    @_gradcheck_ops(op_db + hop_db + custom_op_db)
    def test_fn_grad(self, device, dtype, op):
        # Whether the dtype claim itself is accurate is verified by
        # test_dtypes in test_ops.py, so an unsupported dtype is a skip here.
        supported = op.supported_backward_dtypes(torch.device(device).type)
        if dtype not in supported:
            self.skipTest("Skipped! Dtype is not in supported backward dtypes!")
            return
        self._grad_test_helper(device, dtype, op, op.get_op())

    # Method-variant grad (and gradgrad, see below) tests stay disabled:
    # they are costly and redundant with the function-variant gradgrad
    # (and grad) tests.
    # @_gradcheck_ops(op_db)
    # def test_method_grad(self, device, dtype, op):
    #     self._skip_helper(op, device, dtype)
    #     self._grad_test_helper(device, dtype, op, op.get_method())

    # Verifies first-order gradients of the inplace variant, and that ops
    # claiming no inplace autograd support actually raise.
    @_gradcheck_ops(op_db + custom_op_db)
    def test_inplace_grad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not op.inplace_variant:
            self.skipTest("Op has no inplace variant!")

        safe_inplace = self._get_safe_inplace(op.get_inplace())
        if op.supports_inplace_autograd:
            self._grad_test_helper(device, dtype, op, safe_inplace)
            return

        # The op claims inplace autograd is unsupported: every
        # non-broadcasting sample must fail when backprop is attempted.
        for sample in op.sample_inputs(device, dtype, requires_grad=True):
            if sample.broadcasts_input:
                continue
            with self.assertRaises(Exception):
                out = safe_inplace(sample)
                out.sum().backward()

    # Verifies second-order gradients (gradgrad) of the functional variant.
    @_gradcheck_ops(op_db + hop_db + custom_op_db)
    def test_fn_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not op.supports_gradgrad:
            self.skipTest(
                "Op claims it doesn't support gradgrad. This is not verified."
            )
            return
        self._check_helper(device, dtype, op, op.get_op(), "bwgrad_bwgrad")

    # Verifies that ops without gradgrad support raise the expected error.
    @_gradcheck_ops(op_db + custom_op_db)
    def test_fn_fail_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if op.supports_gradgrad:
            self.skipTest("Skipped! Operation does support gradgrad")

        expected = r"derivative for .* is not implemented"
        with self.assertRaisesRegex(RuntimeError, expected):
            self._check_helper(device, dtype, op, op.get_op(), "bwgrad_bwgrad")

    # Method-variant gradgrad (and grad, see above) tests stay disabled for
    # the same cost/redundancy reason.
    # @_gradcheck_ops(op_db)
    # def test_method_gradgrad(self, device, dtype, op):
    #     self._skip_helper(op, device, dtype)
    #     self._gradgrad_test_helper(device, dtype, op, op.get_method())

    # Verifies second-order gradients of the inplace variant.
    @_gradcheck_ops(op_db)
    def test_inplace_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not op.inplace_variant or not op.supports_inplace_autograd:
            self.skipTest("Skipped! Operation does not support inplace autograd.")
        safe_inplace = self._get_safe_inplace(op.get_inplace())
        self._check_helper(device, dtype, op, safe_inplace, "bwgrad_bwgrad")
103
104
# Expand the TestBwdGradients template into concrete per-device test classes
# (e.g. CPU/CUDA variants) registered in this module's globals so the test
# runner can discover them.
instantiate_device_type_tests(TestBwdGradients, globals())
106
if __name__ == "__main__":
    # Enable TestCase's default-dtype checking before launching the suite.
    # NOTE(review): flag is defined in common_utils; presumably it asserts the
    # global default dtype is unchanged across tests — confirm there.
    TestCase._default_dtype_check_enabled = True
    run_tests()
110