# Owner(s): ["module: unknown"]

import platform
from functools import partial
from unittest import skipIf as skipif

import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
    IS_MACOS,
    run_tests,
    skipIfTorchInductor,
    TestCase,
    TestGradients,
    unMarkDynamoStrictTest,
)


# TODO: mitigate flaky issue on macOS https://github.com/pytorch/pytorch/issues/66033
# AFAIK, c10::ThreadPool looks correct in the way it uses condition_variable wait. The
# issue seems to point to macOS itself https://github.com/graphia-app/graphia/issues/33
if IS_MACOS:
    torch.set_num_threads(1)

# gradcheck requires double precision
_gradcheck_ops = partial(
    ops, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, torch.cdouble]
)

# Error expected from ops that do not declare forward AD support. Shared by every
# negative check below so the tests cannot drift apart on the expected wording.
_FWD_AD_UNSUPPORTED_REGEX = r"Trying to use forward AD with .* that does not support it"


@unMarkDynamoStrictTest
class TestFwdGradients(TestGradients):
    """Forward-mode AD checks run over every OpInfo in ``op_db``.

    Each test either runs the gradient check (when the OpInfo declares support)
    or asserts that attempting it raises ``NotImplementedError``, so an OpInfo
    whose support flag is stale is reported as a failure.

    NOTE(review): ``_skip_helper``, ``_check_helper``, ``_grad_test_helper`` and
    ``_get_safe_inplace`` are not defined in this file; presumably they are
    inherited from ``TestGradients`` -- confirm there.
    """

    # Test that forward-over-reverse gradgrad is computed correctly
    @_gradcheck_ops(op_db)
    def test_fn_fwgrad_bwgrad(self, device, dtype, op):
        """Check forward-over-reverse gradgrad, or that it is rejected.

        When ``op.supports_fwgrad_bwgrad`` is false, the same check must raise
        ``NotImplementedError`` matching ``_FWD_AD_UNSUPPORTED_REGEX``.
        """
        self._skip_helper(op, device, dtype)

        if op.supports_fwgrad_bwgrad:
            self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")
        else:
            hint_msg = (
                "Running forward-over-backward gradgrad for an OP that does not support it did not "
                "raise any error. If your op supports forward AD, you should set supports_fwgrad_bwgrad=True."
            )
            # hint_msg is only shown when no matching error was raised, i.e.
            # when the OpInfo flag is out of date.
            with self.assertRaisesRegex(
                NotImplementedError, _FWD_AD_UNSUPPORTED_REGEX, msg=hint_msg
            ):
                self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")

    def _forward_grad_helper(self, device, dtype, op, variant, is_inplace):
        """Run the forward-AD gradient check for ``variant``, or expect rejection.

        Args:
            device: Device forwarded to ``_grad_test_helper``.
            dtype: Dtype forwarded to ``_grad_test_helper``.
            op: OpInfo whose ``supports_forward_ad`` flag selects the positive
                or the negative check.
            variant: Callable under test (out-of-place or in-place form).
            is_inplace: Whether ``variant`` is the in-place form; selects which
                batched-forward-grad flag of ``op`` is consulted.
        """

        # TODO: clean up how attributes are passed to gradcheck from OpInfos
        def call_grad_test_helper():
            # Out-of-place and in-place variants each have their own OpInfo
            # switch for the batched forward-grad check.
            check_batched_forward_grad = (
                op.check_batched_forward_grad and not is_inplace
            ) or (op.check_inplace_batched_forward_grad and is_inplace)
            self._grad_test_helper(
                device,
                dtype,
                op,
                variant,
                check_forward_ad=True,
                check_backward_ad=False,
                check_batched_grad=False,
                check_batched_forward_grad=check_batched_forward_grad,
            )

        if op.supports_forward_ad:
            call_grad_test_helper()
        else:
            hint_msg = (
                "Running forward AD for an OP that does not support it did not "
                "raise any error. If your op supports forward AD, you should set supports_forward_ad=True"
            )
            with self.assertRaisesRegex(
                NotImplementedError, _FWD_AD_UNSUPPORTED_REGEX, msg=hint_msg
            ):
                call_grad_test_helper()

    @_gradcheck_ops(op_db)
    @skipif(
        platform.machine() == "s390x",
        reason="Different precision of openblas functions: https://github.com/OpenMathLib/OpenBLAS/issues/4194",
    )
    def test_forward_mode_AD(self, device, dtype, op):
        """Check forward-mode AD for the out-of-place variant of ``op``."""
        self._skip_helper(op, device, dtype)

        self._forward_grad_helper(device, dtype, op, op.get_op(), is_inplace=False)

    @_gradcheck_ops(op_db)
    @skipIfTorchInductor("to be fixed")
    def test_inplace_forward_mode_AD(self, device, dtype, op):
        """Check forward-mode AD for the in-place variant of ``op``.

        Skipped when ``op`` has no in-place variant or does not support
        in-place autograd.
        """
        self._skip_helper(op, device, dtype)

        if not op.inplace_variant or not op.supports_inplace_autograd:
            self.skipTest("Skipped! Operation does not support inplace autograd.")

        self._forward_grad_helper(
            device, dtype, op, self._get_safe_inplace(op.get_inplace()), is_inplace=True
        )


instantiate_device_type_tests(TestFwdGradients, globals())

if __name__ == "__main__":
    TestCase._default_dtype_check_enabled = True
    run_tests()