1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: autograd"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport types 4*da0073e9SAndroid Build Coastguard Workerimport unittest 5*da0073e9SAndroid Build Coastguard Workerimport warnings 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport torch 8*da0073e9SAndroid Build Coastguard Workerimport torch.autograd.functional as autogradF 9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_CUDA 10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 11*da0073e9SAndroid Build Coastguard Worker gradcheck, 12*da0073e9SAndroid Build Coastguard Worker gradgradcheck, 13*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests, 14*da0073e9SAndroid Build Coastguard Worker parametrize, 15*da0073e9SAndroid Build Coastguard Worker run_tests, 16*da0073e9SAndroid Build Coastguard Worker subtest, 17*da0073e9SAndroid Build Coastguard Worker TestCase, 18*da0073e9SAndroid Build Coastguard Worker) 19*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.logging_tensor import LoggingTensor 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Worker# Utilities for parametrizing the tensor constructors used in autograd tests 23*da0073e9SAndroid Build Coastguard Worker# 24*da0073e9SAndroid Build Coastguard Worker# TODO: maybe move somewhere so other tests can also use 25*da0073e9SAndroid Build Coastguard Worker# 26*da0073e9SAndroid Build Coastguard Worker# NB: Not all factory functions included. A complete(?) 
list can be found here: 27*da0073e9SAndroid Build Coastguard Worker# https://pytorch.org/cppdocs/notes/tensor_creation.html 28*da0073e9SAndroid Build Coastguard Workerbase_ctors_dict = { 29*da0073e9SAndroid Build Coastguard Worker "ones": torch.ones, 30*da0073e9SAndroid Build Coastguard Worker "zeros": torch.zeros, 31*da0073e9SAndroid Build Coastguard Worker "randn": torch.randn, 32*da0073e9SAndroid Build Coastguard Worker "rand": torch.rand, 33*da0073e9SAndroid Build Coastguard Worker "tensor": torch.tensor, 34*da0073e9SAndroid Build Coastguard Worker} 35*da0073e9SAndroid Build Coastguard Workerbase_ctors = types.SimpleNamespace(**base_ctors_dict) 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Workerdef wrap_with_logging_tensor(ctor): 39*da0073e9SAndroid Build Coastguard Worker def wrapper(*args, **kwargs): 40*da0073e9SAndroid Build Coastguard Worker requires_grad = kwargs.pop("requires_grad", False) 41*da0073e9SAndroid Build Coastguard Worker return LoggingTensor(ctor(*args, **kwargs), requires_grad=requires_grad) 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker return wrapper 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Workerlogging_tensor_ctors_dict = { 47*da0073e9SAndroid Build Coastguard Worker k: wrap_with_logging_tensor(ctor) for (k, ctor) in base_ctors_dict.items() 48*da0073e9SAndroid Build Coastguard Worker} 49*da0073e9SAndroid Build Coastguard Workerlogging_tensor_ctors = types.SimpleNamespace(**logging_tensor_ctors_dict) 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Workerbase_and_logging_tensor = parametrize( 52*da0073e9SAndroid Build Coastguard Worker "ctors", 53*da0073e9SAndroid Build Coastguard Worker [ 54*da0073e9SAndroid Build Coastguard Worker subtest(base_ctors, name="base_tensor"), 55*da0073e9SAndroid Build 
Coastguard Worker subtest(logging_tensor_ctors, name="logging_tensor"), 56*da0073e9SAndroid Build Coastguard Worker ], 57*da0073e9SAndroid Build Coastguard Worker) 58*da0073e9SAndroid Build Coastguard Worker 59*da0073e9SAndroid Build Coastguard WorkerFIXME_base_and_xfail_logging_tensor = parametrize( 60*da0073e9SAndroid Build Coastguard Worker "ctors", 61*da0073e9SAndroid Build Coastguard Worker [ 62*da0073e9SAndroid Build Coastguard Worker subtest(base_ctors, name="base_tensor"), 63*da0073e9SAndroid Build Coastguard Worker subtest( 64*da0073e9SAndroid Build Coastguard Worker logging_tensor_ctors, 65*da0073e9SAndroid Build Coastguard Worker name="logging_tensor", 66*da0073e9SAndroid Build Coastguard Worker decorators=[unittest.expectedFailure], 67*da0073e9SAndroid Build Coastguard Worker ), 68*da0073e9SAndroid Build Coastguard Worker ], 69*da0073e9SAndroid Build Coastguard Worker) 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Worker# NB: This is equivalent to having both @parametrize("vectorized", [True, False]) and 72*da0073e9SAndroid Build Coastguard Worker# FIXME_base_and_xfail_logging_tensor, except the non-vectorized logging_tensor case is 73*da0073e9SAndroid Build Coastguard Worker# actually expected to succeed 74*da0073e9SAndroid Build Coastguard WorkerFIXME_xfail_vectorized_logging_tensor = parametrize( 75*da0073e9SAndroid Build Coastguard Worker "vectorize,ctors", 76*da0073e9SAndroid Build Coastguard Worker [ 77*da0073e9SAndroid Build Coastguard Worker subtest((True, base_ctors), name="vectorized_base_tensor"), 78*da0073e9SAndroid Build Coastguard Worker subtest((False, base_ctors), name="base_tensor"), 79*da0073e9SAndroid Build Coastguard Worker subtest( 80*da0073e9SAndroid Build Coastguard Worker (True, logging_tensor_ctors), 81*da0073e9SAndroid Build Coastguard Worker name="vectorized_logging_tensor", 82*da0073e9SAndroid Build Coastguard Worker decorators=[unittest.expectedFailure], 83*da0073e9SAndroid Build Coastguard 
Worker ), 84*da0073e9SAndroid Build Coastguard Worker subtest((False, logging_tensor_ctors), name="logging_tensor"), 85*da0073e9SAndroid Build Coastguard Worker ], 86*da0073e9SAndroid Build Coastguard Worker) 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Workervectorized_logging_tensor = parametrize( 89*da0073e9SAndroid Build Coastguard Worker "vectorize,ctors", 90*da0073e9SAndroid Build Coastguard Worker [ 91*da0073e9SAndroid Build Coastguard Worker subtest((True, base_ctors), name="vectorized_base_tensor"), 92*da0073e9SAndroid Build Coastguard Worker subtest((False, base_ctors), name="base_tensor"), 93*da0073e9SAndroid Build Coastguard Worker subtest((True, logging_tensor_ctors), name="vectorized_logging_tensor"), 94*da0073e9SAndroid Build Coastguard Worker subtest((False, logging_tensor_ctors), name="logging_tensor"), 95*da0073e9SAndroid Build Coastguard Worker ], 96*da0073e9SAndroid Build Coastguard Worker) 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Workerclass TestAutogradFunctional(TestCase): 100*da0073e9SAndroid Build Coastguard Worker def _assert_same_struct(self, res, base): 101*da0073e9SAndroid Build Coastguard Worker # base and res should be Tensors or tuple of Tensors with the same size 102*da0073e9SAndroid Build Coastguard Worker if isinstance(base, torch.Tensor): 103*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(res, torch.Tensor)) 104*da0073e9SAndroid Build Coastguard Worker self.assertEqual(base.size(), res.size()) 105*da0073e9SAndroid Build Coastguard Worker elif isinstance(base, tuple): 106*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(res, tuple)) 107*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(base), len(res)) 108*da0073e9SAndroid Build Coastguard Worker for el_base, el_res in zip(base, res): 109*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(isinstance(el_base, torch.Tensor)) 110*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(el_res, torch.Tensor)) 111*da0073e9SAndroid Build Coastguard Worker self.assertEqual(el_base.size(), el_res.size()) 112*da0073e9SAndroid Build Coastguard Worker else: 113*da0073e9SAndroid Build Coastguard Worker # Wrong base 114*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 115*da0073e9SAndroid Build Coastguard Worker "The base given to `_assert_same_struct` doesn't have" 116*da0073e9SAndroid Build Coastguard Worker " the right structure." 117*da0073e9SAndroid Build Coastguard Worker ) 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker def _assert_interleaved_struct(self, res, base1, base2): 120*da0073e9SAndroid Build Coastguard Worker # base1 and base2 can be Tensors or tuples of Tensors. 121*da0073e9SAndroid Build Coastguard Worker # If they are tuples, res should be a tuple as well. 122*da0073e9SAndroid Build Coastguard Worker # The indexing works as follows for base1, base2 being 123*da0073e9SAndroid Build Coastguard Worker # - tuple, tuple: res[i][j][k][l] = (base1[i][k], base2[j][l]) 124*da0073e9SAndroid Build Coastguard Worker # - tuple, Tensor: res[i][k][l] = (base1[i][k], base2[l]) 125*da0073e9SAndroid Build Coastguard Worker # - Tensor, tuple: res[i][j][l] = (base1[i], base2[j][l]) 126*da0073e9SAndroid Build Coastguard Worker # - Tensor, Tensor: res[k][l] = (base1[k], base2[l]) 127*da0073e9SAndroid Build Coastguard Worker if isinstance(base1, torch.Tensor) and isinstance(base2, torch.Tensor): 128*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(res, torch.Tensor)) 129*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.size(), base1.size() + base2.size()) 130*da0073e9SAndroid Build Coastguard Worker elif isinstance(base1, tuple) and isinstance(base2, torch.Tensor): 131*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(res, tuple)) 
132*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(res), len(base1)) 133*da0073e9SAndroid Build Coastguard Worker for el_res, el_base1 in zip(res, base1): 134*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(el_res, torch.Tensor)) 135*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(el_base1, torch.Tensor)) 136*da0073e9SAndroid Build Coastguard Worker self.assertEqual(el_res.size(), el_base1.size() + base2.size()) 137*da0073e9SAndroid Build Coastguard Worker elif isinstance(base1, torch.Tensor) and isinstance(base2, tuple): 138*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(res, tuple)) 139*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(res), len(base2)) 140*da0073e9SAndroid Build Coastguard Worker for el_res, el_base2 in zip(res, base2): 141*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(el_res, torch.Tensor)) 142*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(el_base2, torch.Tensor)) 143*da0073e9SAndroid Build Coastguard Worker self.assertEqual(el_res.size(), base1.size() + el_base2.size()) 144*da0073e9SAndroid Build Coastguard Worker elif isinstance(base1, tuple) and isinstance(base2, tuple): 145*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(res, tuple)) 146*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(res), len(base1)) 147*da0073e9SAndroid Build Coastguard Worker for el_res, el_base1 in zip(res, base1): 148*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(el_res, tuple)) 149*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(res), len(base2)) 150*da0073e9SAndroid Build Coastguard Worker for el_el_res, el_base2 in zip(el_res, base2): 151*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(el_el_res, torch.Tensor)) 152*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(el_base2, torch.Tensor)) 153*da0073e9SAndroid Build Coastguard 
Worker self.assertEqual( 154*da0073e9SAndroid Build Coastguard Worker el_el_res.size(), el_base1.size() + el_base2.size() 155*da0073e9SAndroid Build Coastguard Worker ) 156*da0073e9SAndroid Build Coastguard Worker else: 157*da0073e9SAndroid Build Coastguard Worker # Wrong bases 158*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 159*da0073e9SAndroid Build Coastguard Worker "The bases given to `_assert_interleaved_struct` don't have" 160*da0073e9SAndroid Build Coastguard Worker " the right structure." 161*da0073e9SAndroid Build Coastguard Worker ) 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 164*da0073e9SAndroid Build Coastguard Worker def test_vjp_err_check(self, ctors): 165*da0073e9SAndroid Build Coastguard Worker def foo(a): 166*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3) 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker def bar(a): 169*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3), "bar" 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker inp = ctors.rand(4) 172*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(3) 173*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 174*da0073e9SAndroid Build Coastguard Worker TypeError, "The inputs given to vjp must be either a Tensor" 175*da0073e9SAndroid Build Coastguard Worker ): 176*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(foo, (inp, 2), v) 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 179*da0073e9SAndroid Build Coastguard Worker TypeError, "The outputs of the user-provided function given to vjp must" 180*da0073e9SAndroid Build Coastguard Worker ): 181*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(bar, inp, v) 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build 
Coastguard Worker with self.assertRaisesRegex( 184*da0073e9SAndroid Build Coastguard Worker RuntimeError, 185*da0073e9SAndroid Build Coastguard Worker "The vector v can only be None if the user-provided function returns", 186*da0073e9SAndroid Build Coastguard Worker ): 187*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(foo, inp) 188*da0073e9SAndroid Build Coastguard Worker 189*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 190*da0073e9SAndroid Build Coastguard Worker RuntimeError, "The given v should contain a single Tensor." 191*da0073e9SAndroid Build Coastguard Worker ): 192*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(foo, inp, (torch.ones_like(inp), torch.ones_like(inp))) 193*da0073e9SAndroid Build Coastguard Worker 194*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 195*da0073e9SAndroid Build Coastguard Worker RuntimeError, "v has invalid size: should be torch.Size" 196*da0073e9SAndroid Build Coastguard Worker ): 197*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(foo, inp, v[:2]) 198*da0073e9SAndroid Build Coastguard Worker 199*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(foo, inp, v)[1] 200*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res, inp) 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 203*da0073e9SAndroid Build Coastguard Worker def test_vjp_err_check_strict(self, ctors): 204*da0073e9SAndroid Build Coastguard Worker def foo(a): 205*da0073e9SAndroid Build Coastguard Worker return a.detach() 206*da0073e9SAndroid Build Coastguard Worker 207*da0073e9SAndroid Build Coastguard Worker def bar(a): 208*da0073e9SAndroid Build Coastguard Worker # Make a non-leaf Tensor that requires_grad but that is not connected to the input 209*da0073e9SAndroid Build Coastguard Worker return a.long().float().requires_grad_().clone() 210*da0073e9SAndroid Build Coastguard Worker 
211*da0073e9SAndroid Build Coastguard Worker inp = ctors.rand(4) 212*da0073e9SAndroid Build Coastguard Worker v = ctors.rand(4) 213*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 214*da0073e9SAndroid Build Coastguard Worker RuntimeError, 215*da0073e9SAndroid Build Coastguard Worker "Output 0 of the user-provided function does not require gradients.", 216*da0073e9SAndroid Build Coastguard Worker ): 217*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(foo, inp, v, strict=True) 218*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(foo, inp, v, strict=False) 219*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inp) 220*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res[1].abs().sum(), 0.0) 221*da0073e9SAndroid Build Coastguard Worker 222*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 223*da0073e9SAndroid Build Coastguard Worker RuntimeError, 224*da0073e9SAndroid Build Coastguard Worker "The output of the user-provided function is independent of input 0", 225*da0073e9SAndroid Build Coastguard Worker ): 226*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(bar, inp, v, strict=True) 227*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(bar, inp, v, strict=False) 228*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inp) 229*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res[1].abs().sum(), 0.0) 230*da0073e9SAndroid Build Coastguard Worker 231*da0073e9SAndroid Build Coastguard Worker # The Jacobian does not depend on the input 232*da0073e9SAndroid Build Coastguard Worker def foo(a): 233*da0073e9SAndroid Build Coastguard Worker return a.clone() 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Worker inp.requires_grad_() 236*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 237*da0073e9SAndroid Build Coastguard Worker RuntimeError, 238*da0073e9SAndroid Build 
Coastguard Worker "jacobian of the user-provided function is independent of input 0.", 239*da0073e9SAndroid Build Coastguard Worker ): 240*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(foo, inp, v, create_graph=True, strict=True) 241*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(foo, inp, v, create_graph=True, strict=False) 242*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inp) 243*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res[1], v) 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 246*da0073e9SAndroid Build Coastguard Worker def test_vjp_no_grad(self, ctors): 247*da0073e9SAndroid Build Coastguard Worker def reducer(x): 248*da0073e9SAndroid Build Coastguard Worker return x.sum(dim=1) 249*da0073e9SAndroid Build Coastguard Worker 250*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4) 251*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(4) 252*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 253*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(reducer, inputs, v) 254*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[0].grad_fn) 255*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[1].grad_fn) 256*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(res[1], ctors.zeros(4, 4)) 257*da0073e9SAndroid Build Coastguard Worker 258*da0073e9SAndroid Build Coastguard Worker inputs.requires_grad_() 259*da0073e9SAndroid Build Coastguard Worker v.requires_grad_() 260*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 261*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(reducer, inputs, v, create_graph=True) 262*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[0].grad_fn) 263*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[1].grad_fn) 264*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(res[1], 
ctors.zeros(4, 4)) 265*da0073e9SAndroid Build Coastguard Worker 266*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 267*da0073e9SAndroid Build Coastguard Worker def test_vjp_output(self, ctors): 268*da0073e9SAndroid Build Coastguard Worker def reducer(x): 269*da0073e9SAndroid Build Coastguard Worker return x.sum(dim=1) 270*da0073e9SAndroid Build Coastguard Worker 271*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4) 272*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(4) 273*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(reducer, inputs, v) 274*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inputs) 275*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[0].grad_fn) 276*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[1].grad_fn) 277*da0073e9SAndroid Build Coastguard Worker 278*da0073e9SAndroid Build Coastguard Worker def adder(x, y): 279*da0073e9SAndroid Build Coastguard Worker return 2 * x + 3 * y 280*da0073e9SAndroid Build Coastguard Worker 281*da0073e9SAndroid Build Coastguard Worker inputs = (ctors.rand(2), ctors.rand(2)) 282*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(2) 283*da0073e9SAndroid Build Coastguard Worker out, vjp_val = autogradF.vjp(adder, inputs, v) 284*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(vjp_val, inputs) 285*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(out.grad_fn) 286*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(vjp_val[0].grad_fn) 287*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(vjp_val[1].grad_fn) 288*da0073e9SAndroid Build Coastguard Worker 289*da0073e9SAndroid Build Coastguard Worker def adder(x, y): 290*da0073e9SAndroid Build Coastguard Worker return 2 * x + 3 * y, x + y 291*da0073e9SAndroid Build Coastguard Worker 292*da0073e9SAndroid Build Coastguard Worker inputs = (ctors.rand(2), ctors.rand(2)) 293*da0073e9SAndroid Build Coastguard Worker v = 
(ctors.tensor([1.0, 0.0]), ctors.tensor([1.0, 0.0])) 294*da0073e9SAndroid Build Coastguard Worker out, vjp_val = autogradF.vjp(adder, inputs, v) 295*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(vjp_val, inputs) 296*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(out[0].grad_fn) 297*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(out[1].grad_fn) 298*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(vjp_val[0].grad_fn) 299*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(vjp_val[1].grad_fn) 300*da0073e9SAndroid Build Coastguard Worker 301*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 302*da0073e9SAndroid Build Coastguard Worker def test_vjp_scalar(self, ctors): 303*da0073e9SAndroid Build Coastguard Worker def reducer(x): 304*da0073e9SAndroid Build Coastguard Worker return x.sum() 305*da0073e9SAndroid Build Coastguard Worker 306*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4) 307*da0073e9SAndroid Build Coastguard Worker v = ctors.ones([]) 308*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(reducer, inputs, v) 309*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[0], v) 310*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inputs) 311*da0073e9SAndroid Build Coastguard Worker 312*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(reducer, inputs) 313*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[0], v) 314*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inputs) 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker def expander(x): 317*da0073e9SAndroid Build Coastguard Worker return x.unsqueeze(0).repeat(4) 318*da0073e9SAndroid Build Coastguard Worker 319*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand([]) 320*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(4) 321*da0073e9SAndroid Build Coastguard 
Worker res = autogradF.vjp(expander, inputs, v) 322*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[0], v) 323*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inputs) 324*da0073e9SAndroid Build Coastguard Worker 325*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 326*da0073e9SAndroid Build Coastguard Worker def test_vjp_create_graph(self, ctors): 327*da0073e9SAndroid Build Coastguard Worker def reducer(x): 328*da0073e9SAndroid Build Coastguard Worker return x.sum(dim=1) 329*da0073e9SAndroid Build Coastguard Worker 330*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(2, 2, dtype=torch.double) 331*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(2, dtype=torch.double) 332*da0073e9SAndroid Build Coastguard Worker 333*da0073e9SAndroid Build Coastguard Worker inputs.requires_grad_() 334*da0073e9SAndroid Build Coastguard Worker v.requires_grad_() 335*da0073e9SAndroid Build Coastguard Worker res = autogradF.vjp(reducer, inputs, v, create_graph=True) 336*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inputs) 337*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[0].grad_fn) 338*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[1].grad_fn) 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker gradcheck( 341*da0073e9SAndroid Build Coastguard Worker lambda inp, v: autogradF.vjp(reducer, inputs, v, create_graph=True), 342*da0073e9SAndroid Build Coastguard Worker (inputs, v), 343*da0073e9SAndroid Build Coastguard Worker ) 344*da0073e9SAndroid Build Coastguard Worker gradgradcheck( 345*da0073e9SAndroid Build Coastguard Worker lambda inp, v: autogradF.vjp(reducer, inputs, v, create_graph=True), 346*da0073e9SAndroid Build Coastguard Worker (inputs, v), 347*da0073e9SAndroid Build Coastguard Worker ) 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker def adder(x, y): 
350*da0073e9SAndroid Build Coastguard Worker return 2 * x + 3 * y, x * y 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker inputs = ( 353*da0073e9SAndroid Build Coastguard Worker ctors.rand(2, dtype=torch.double, requires_grad=True), 354*da0073e9SAndroid Build Coastguard Worker ctors.rand(2, dtype=torch.double, requires_grad=True), 355*da0073e9SAndroid Build Coastguard Worker ) 356*da0073e9SAndroid Build Coastguard Worker v = ( 357*da0073e9SAndroid Build Coastguard Worker ctors.tensor([1.0, 0.0], dtype=torch.double, requires_grad=True), 358*da0073e9SAndroid Build Coastguard Worker ctors.tensor([1.0, 0.0], dtype=torch.double, requires_grad=True), 359*da0073e9SAndroid Build Coastguard Worker ) 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker gradcheck( 362*da0073e9SAndroid Build Coastguard Worker lambda *args: autogradF.vjp(adder, args[:2], args[2:], create_graph=True)[ 363*da0073e9SAndroid Build Coastguard Worker 1 364*da0073e9SAndroid Build Coastguard Worker ], 365*da0073e9SAndroid Build Coastguard Worker inputs + v, 366*da0073e9SAndroid Build Coastguard Worker ) 367*da0073e9SAndroid Build Coastguard Worker gradgradcheck( 368*da0073e9SAndroid Build Coastguard Worker lambda *args: autogradF.vjp(adder, args[:2], args[2:], create_graph=True)[ 369*da0073e9SAndroid Build Coastguard Worker 1 370*da0073e9SAndroid Build Coastguard Worker ], 371*da0073e9SAndroid Build Coastguard Worker inputs + v, 372*da0073e9SAndroid Build Coastguard Worker ) 373*da0073e9SAndroid Build Coastguard Worker 374*da0073e9SAndroid Build Coastguard Worker def foo(*args): 375*da0073e9SAndroid Build Coastguard Worker x, y = args[:2] 376*da0073e9SAndroid Build Coastguard Worker v = args[2:] 377*da0073e9SAndroid Build Coastguard Worker 378*da0073e9SAndroid Build Coastguard Worker x = x.cos() 379*da0073e9SAndroid Build Coastguard Worker val, grad = autogradF.vjp(adder, (x, y), v, create_graph=True) 380*da0073e9SAndroid 
Build Coastguard Worker 381*da0073e9SAndroid Build Coastguard Worker return ( 382*da0073e9SAndroid Build Coastguard Worker val[0].exp() 383*da0073e9SAndroid Build Coastguard Worker + val[1].exp() 384*da0073e9SAndroid Build Coastguard Worker + grad[0].exp() 385*da0073e9SAndroid Build Coastguard Worker + grad[1].exp() 386*da0073e9SAndroid Build Coastguard Worker + x.exp() 387*da0073e9SAndroid Build Coastguard Worker + y.exp() 388*da0073e9SAndroid Build Coastguard Worker ) 389*da0073e9SAndroid Build Coastguard Worker 390*da0073e9SAndroid Build Coastguard Worker gradcheck(foo, inputs + v) 391*da0073e9SAndroid Build Coastguard Worker gradgradcheck(foo, inputs + v) 392*da0073e9SAndroid Build Coastguard Worker 393*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 394*da0073e9SAndroid Build Coastguard Worker def test_jvp_err_check(self, ctors): 395*da0073e9SAndroid Build Coastguard Worker def foo(a): 396*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3) 397*da0073e9SAndroid Build Coastguard Worker 398*da0073e9SAndroid Build Coastguard Worker def bar(a): 399*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3), "bar" 400*da0073e9SAndroid Build Coastguard Worker 401*da0073e9SAndroid Build Coastguard Worker inp = ctors.rand(4) 402*da0073e9SAndroid Build Coastguard Worker v = ctors.rand(4) 403*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 404*da0073e9SAndroid Build Coastguard Worker TypeError, "The inputs given to jvp must be either a Tensor" 405*da0073e9SAndroid Build Coastguard Worker ): 406*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(foo, (inp, 2), v) 407*da0073e9SAndroid Build Coastguard Worker 408*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 409*da0073e9SAndroid Build Coastguard Worker TypeError, "The outputs of the user-provided function given to jvp must" 410*da0073e9SAndroid Build Coastguard Worker ): 411*da0073e9SAndroid Build Coastguard Worker res = 
autogradF.jvp(bar, inp, v) 412*da0073e9SAndroid Build Coastguard Worker 413*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 414*da0073e9SAndroid Build Coastguard Worker RuntimeError, 415*da0073e9SAndroid Build Coastguard Worker "The vector v can only be None if the input to the user-provided function", 416*da0073e9SAndroid Build Coastguard Worker ): 417*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(foo, inp) 418*da0073e9SAndroid Build Coastguard Worker 419*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 420*da0073e9SAndroid Build Coastguard Worker RuntimeError, "The given v should contain a single Tensor." 421*da0073e9SAndroid Build Coastguard Worker ): 422*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(foo, inp, (v, v)) 423*da0073e9SAndroid Build Coastguard Worker 424*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 425*da0073e9SAndroid Build Coastguard Worker RuntimeError, "v has invalid size: should be torch.Size" 426*da0073e9SAndroid Build Coastguard Worker ): 427*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(foo, inp, v[:2]) 428*da0073e9SAndroid Build Coastguard Worker 429*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(foo, inp, v)[1] 430*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res, foo(inp)) 431*da0073e9SAndroid Build Coastguard Worker 432*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 433*da0073e9SAndroid Build Coastguard Worker def test_jvp_err_check_strict(self, ctors): 434*da0073e9SAndroid Build Coastguard Worker def foo(a): 435*da0073e9SAndroid Build Coastguard Worker return a.detach() 436*da0073e9SAndroid Build Coastguard Worker 437*da0073e9SAndroid Build Coastguard Worker def bar(a): 438*da0073e9SAndroid Build Coastguard Worker # Make a non-leaf Tensor that requires_grad but that is not connected to the input 439*da0073e9SAndroid Build Coastguard Worker return 
a.long().float().requires_grad_().clone() 440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker inp = ctors.rand(4) 442*da0073e9SAndroid Build Coastguard Worker v = ctors.rand(4) 443*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 444*da0073e9SAndroid Build Coastguard Worker RuntimeError, 445*da0073e9SAndroid Build Coastguard Worker "Output 0 of the user-provided function does not require gradients.", 446*da0073e9SAndroid Build Coastguard Worker ): 447*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(foo, inp, v, strict=True) 448*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(foo, inp, v, strict=False) 449*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], res[0]) 450*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res[1].abs().sum(), 0.0) 451*da0073e9SAndroid Build Coastguard Worker 452*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 453*da0073e9SAndroid Build Coastguard Worker RuntimeError, 454*da0073e9SAndroid Build Coastguard Worker "The output of the user-provided function is independent of input 0", 455*da0073e9SAndroid Build Coastguard Worker ): 456*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(bar, inp, v, strict=True) 457*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(bar, inp, v, strict=False) 458*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], res[0]) 459*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res[1].abs().sum(), 0.0) 460*da0073e9SAndroid Build Coastguard Worker 461*da0073e9SAndroid Build Coastguard Worker # The Jacobian does not depend on the input 462*da0073e9SAndroid Build Coastguard Worker def foo(a): 463*da0073e9SAndroid Build Coastguard Worker return a.clone() 464*da0073e9SAndroid Build Coastguard Worker 465*da0073e9SAndroid Build Coastguard Worker inp.requires_grad_() 466*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex( 467*da0073e9SAndroid Build Coastguard Worker RuntimeError, 468*da0073e9SAndroid Build Coastguard Worker "jacobian of the user-provided function is independent of input 0.", 469*da0073e9SAndroid Build Coastguard Worker ): 470*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(foo, inp, v, create_graph=True, strict=True) 471*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(foo, inp, v, create_graph=True, strict=False) 472*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inp) 473*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res[1], v) 474*da0073e9SAndroid Build Coastguard Worker 475*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 476*da0073e9SAndroid Build Coastguard Worker def test_jvp_no_grad(self, ctors): 477*da0073e9SAndroid Build Coastguard Worker def reducer(x): 478*da0073e9SAndroid Build Coastguard Worker return x.sum(dim=1) 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4) 481*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(4, 4) 482*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 483*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(reducer, inputs, v) 484*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[0].grad_fn) 485*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[1].grad_fn) 486*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(res[1], ctors.zeros(4, 4)) 487*da0073e9SAndroid Build Coastguard Worker 488*da0073e9SAndroid Build Coastguard Worker inputs.requires_grad_() 489*da0073e9SAndroid Build Coastguard Worker v.requires_grad_() 490*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 491*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(reducer, inputs, v, create_graph=True) 492*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[0].grad_fn) 493*da0073e9SAndroid Build Coastguard 
Worker self.assertIsNotNone(res[1].grad_fn) 494*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(res[1], ctors.zeros(4, 4)) 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 497*da0073e9SAndroid Build Coastguard Worker def test_jvp_output(self, ctors): 498*da0073e9SAndroid Build Coastguard Worker def reducer(x): 499*da0073e9SAndroid Build Coastguard Worker return x.sum(dim=1) 500*da0073e9SAndroid Build Coastguard Worker 501*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4) 502*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(4, 4) 503*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(reducer, inputs, v) 504*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], res[0]) 505*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[0].grad_fn) 506*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[1].grad_fn) 507*da0073e9SAndroid Build Coastguard Worker 508*da0073e9SAndroid Build Coastguard Worker def adder(x, y): 509*da0073e9SAndroid Build Coastguard Worker return 2 * x + 3 * y 510*da0073e9SAndroid Build Coastguard Worker 511*da0073e9SAndroid Build Coastguard Worker inputs = (ctors.rand(2), ctors.rand(2)) 512*da0073e9SAndroid Build Coastguard Worker v = (ctors.ones(2), ctors.ones(2)) 513*da0073e9SAndroid Build Coastguard Worker out, jvp_val = autogradF.jvp(adder, inputs, v) 514*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(jvp_val, out) 515*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(out.grad_fn) 516*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(jvp_val[0].grad_fn) 517*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(jvp_val[1].grad_fn) 518*da0073e9SAndroid Build Coastguard Worker 519*da0073e9SAndroid Build Coastguard Worker def adder(x, y): 520*da0073e9SAndroid Build Coastguard Worker return 2 * x + 3 * y, x + y 521*da0073e9SAndroid Build Coastguard Worker 
522*da0073e9SAndroid Build Coastguard Worker inputs = (ctors.rand(2), ctors.rand(2)) 523*da0073e9SAndroid Build Coastguard Worker v = (ctors.tensor([1.0, 0.0]), ctors.tensor([1.0, 0.0])) 524*da0073e9SAndroid Build Coastguard Worker out, jvp_val = autogradF.jvp(adder, inputs, v) 525*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(jvp_val, out) 526*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(out[0].grad_fn) 527*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(out[1].grad_fn) 528*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(jvp_val[0].grad_fn) 529*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(jvp_val[1].grad_fn) 530*da0073e9SAndroid Build Coastguard Worker 531*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 532*da0073e9SAndroid Build Coastguard Worker def test_jvp_scalar(self, ctors): 533*da0073e9SAndroid Build Coastguard Worker def reducer(x): 534*da0073e9SAndroid Build Coastguard Worker return x.sum() 535*da0073e9SAndroid Build Coastguard Worker 536*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4) 537*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(4, 4) 538*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(reducer, inputs, v) 539*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[0], ctors.zeros([])) 540*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], res[0]) 541*da0073e9SAndroid Build Coastguard Worker 542*da0073e9SAndroid Build Coastguard Worker def expander(x): 543*da0073e9SAndroid Build Coastguard Worker return x.unsqueeze(0).repeat(4) 544*da0073e9SAndroid Build Coastguard Worker 545*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand([]) 546*da0073e9SAndroid Build Coastguard Worker v = ctors.ones([]) 547*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(expander, inputs, v) 548*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[0], ctors.zeros(4)) 
549*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], res[0]) 550*da0073e9SAndroid Build Coastguard Worker 551*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(expander, inputs) 552*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[0], ctors.zeros(4)) 553*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], res[0]) 554*da0073e9SAndroid Build Coastguard Worker 555*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 556*da0073e9SAndroid Build Coastguard Worker def test_jvp_create_graph(self, ctors): 557*da0073e9SAndroid Build Coastguard Worker def reducer(x): 558*da0073e9SAndroid Build Coastguard Worker return x.sum(dim=1) 559*da0073e9SAndroid Build Coastguard Worker 560*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(2, 2, dtype=torch.double) 561*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(2, 2, dtype=torch.double) 562*da0073e9SAndroid Build Coastguard Worker 563*da0073e9SAndroid Build Coastguard Worker inputs.requires_grad_() 564*da0073e9SAndroid Build Coastguard Worker v.requires_grad_() 565*da0073e9SAndroid Build Coastguard Worker res = autogradF.jvp(reducer, inputs, v, create_graph=True) 566*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], res[0]) 567*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[0].grad_fn) 568*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[1].grad_fn) 569*da0073e9SAndroid Build Coastguard Worker 570*da0073e9SAndroid Build Coastguard Worker gradcheck( 571*da0073e9SAndroid Build Coastguard Worker lambda inp, v: autogradF.jvp(reducer, inp, v, create_graph=True), 572*da0073e9SAndroid Build Coastguard Worker (inputs, v), 573*da0073e9SAndroid Build Coastguard Worker ) 574*da0073e9SAndroid Build Coastguard Worker gradgradcheck( 575*da0073e9SAndroid Build Coastguard Worker lambda inp, v: autogradF.jvp(reducer, inp, v, create_graph=True), 576*da0073e9SAndroid Build Coastguard 
Worker (inputs, v), 577*da0073e9SAndroid Build Coastguard Worker ) 578*da0073e9SAndroid Build Coastguard Worker 579*da0073e9SAndroid Build Coastguard Worker def adder(x, y): 580*da0073e9SAndroid Build Coastguard Worker return 2 * x + 3 * y, x * y 581*da0073e9SAndroid Build Coastguard Worker 582*da0073e9SAndroid Build Coastguard Worker inputs = ( 583*da0073e9SAndroid Build Coastguard Worker ctors.rand(2, dtype=torch.double, requires_grad=True), 584*da0073e9SAndroid Build Coastguard Worker ctors.rand(2, dtype=torch.double, requires_grad=True), 585*da0073e9SAndroid Build Coastguard Worker ) 586*da0073e9SAndroid Build Coastguard Worker v = ( 587*da0073e9SAndroid Build Coastguard Worker ctors.tensor([1.0, 0.0], dtype=torch.double, requires_grad=True), 588*da0073e9SAndroid Build Coastguard Worker ctors.tensor([1.0, 0.0], dtype=torch.double, requires_grad=True), 589*da0073e9SAndroid Build Coastguard Worker ) 590*da0073e9SAndroid Build Coastguard Worker 591*da0073e9SAndroid Build Coastguard Worker gradcheck( 592*da0073e9SAndroid Build Coastguard Worker lambda *args: autogradF.jvp(adder, args[:2], args[2:], create_graph=True)[ 593*da0073e9SAndroid Build Coastguard Worker 1 594*da0073e9SAndroid Build Coastguard Worker ], 595*da0073e9SAndroid Build Coastguard Worker inputs + v, 596*da0073e9SAndroid Build Coastguard Worker ) 597*da0073e9SAndroid Build Coastguard Worker gradgradcheck( 598*da0073e9SAndroid Build Coastguard Worker lambda *args: autogradF.jvp(adder, args[:2], args[2:], create_graph=True)[ 599*da0073e9SAndroid Build Coastguard Worker 1 600*da0073e9SAndroid Build Coastguard Worker ], 601*da0073e9SAndroid Build Coastguard Worker inputs + v, 602*da0073e9SAndroid Build Coastguard Worker ) 603*da0073e9SAndroid Build Coastguard Worker 604*da0073e9SAndroid Build Coastguard Worker def foo(*args): 605*da0073e9SAndroid Build Coastguard Worker x, y = args[:2] 606*da0073e9SAndroid Build Coastguard Worker v = args[2:] 607*da0073e9SAndroid Build Coastguard Worker 
608*da0073e9SAndroid Build Coastguard Worker x = x.cos() 609*da0073e9SAndroid Build Coastguard Worker val, grad = autogradF.jvp(adder, (x, y), v, create_graph=True) 610*da0073e9SAndroid Build Coastguard Worker 611*da0073e9SAndroid Build Coastguard Worker return ( 612*da0073e9SAndroid Build Coastguard Worker val[0].exp() 613*da0073e9SAndroid Build Coastguard Worker + val[1].exp() 614*da0073e9SAndroid Build Coastguard Worker + grad[0].exp() 615*da0073e9SAndroid Build Coastguard Worker + grad[1].exp() 616*da0073e9SAndroid Build Coastguard Worker + x.exp() 617*da0073e9SAndroid Build Coastguard Worker + y.exp() 618*da0073e9SAndroid Build Coastguard Worker ) 619*da0073e9SAndroid Build Coastguard Worker 620*da0073e9SAndroid Build Coastguard Worker gradcheck(foo, inputs + v) 621*da0073e9SAndroid Build Coastguard Worker gradgradcheck(foo, inputs + v) 622*da0073e9SAndroid Build Coastguard Worker 623*da0073e9SAndroid Build Coastguard Worker def _test_construct_standard_basis_for(self, inputs): 624*da0073e9SAndroid Build Coastguard Worker numels = tuple(tensor.numel() for tensor in inputs) 625*da0073e9SAndroid Build Coastguard Worker results = autogradF._construct_standard_basis_for(inputs, numels) 626*da0073e9SAndroid Build Coastguard Worker for result, inp in zip(results, inputs): 627*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.dtype, inp.dtype) 628*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.device, inp.device) 629*da0073e9SAndroid Build Coastguard Worker results = torch.cat( 630*da0073e9SAndroid Build Coastguard Worker [result.to(device="cpu", dtype=torch.float) for result in results], dim=1 631*da0073e9SAndroid Build Coastguard Worker ) 632*da0073e9SAndroid Build Coastguard Worker expected = torch.eye(results[0].shape[0], dtype=torch.float) 633*da0073e9SAndroid Build Coastguard Worker self.assertEqual(results, expected) 634*da0073e9SAndroid Build Coastguard Worker 635*da0073e9SAndroid Build Coastguard Worker 
@base_and_logging_tensor
def test_construct_standard_basis_for(self, ctors):
    # Every entry is one tuple of tensors handed to the shared checker
    # `_test_construct_standard_basis_for`. Tensors are created in the
    # same order as before so the random stream is consumed identically.
    single_tensor_groups = [
        (ctors.randn(2, 3),),
        (ctors.randn(1),),
        (ctors.randn([]),),
    ]
    multi_tensor_groups = [
        (ctors.randn(1), ctors.randn([]), ctors.randn([])),
        (ctors.randn(2), ctors.randn(3), ctors.randn([])),
        (ctors.randn(2), ctors.randn([]), ctors.randn(3)),
        (ctors.randn(2, 3), ctors.randn(3), ctors.randn(3, 4, 2)),
        # Mixed dtypes within one group.
        (ctors.randn(2, dtype=torch.float64), ctors.randn(3, dtype=torch.float32)),
    ]

    for group in single_tensor_groups + multi_tensor_groups:
        self._test_construct_standard_basis_for(group)

@unittest.skipIf(not TEST_CUDA, "test requires CUDA")
@base_and_logging_tensor
def test_construct_standard_basis_for_cuda(self, ctors):
    # Pair a tensor built without an explicit device with one built on
    # "cuda", in both orders.
    default_then_cuda = (ctors.randn(2), ctors.randn(3, device="cuda"))
    cuda_then_default = (ctors.randn(3, device="cuda"), ctors.randn(2))

    for group in (default_then_cuda, cuda_then_default):
        self._test_construct_standard_basis_for(group)
def _test_vectorize_raises_no_warnings(self, api, ctors):
    # vmap is an experimental prototype. When someone calls torch.vmap,
    # it raises a python warning. This test checks that
    # autogradF.{jacobian, hessian} don't raise that experimental prototype
    # warning; it is not nice for a public-facing API to raise a warning
    # no matter how it is called.
    #
    # `api` is the functional entry point under test; it is called as
    # `api(fn, tensor, vectorize=True)` and its return value is ignored.
    def foo(a):
        return (a**2).sum()

    x = ctors.randn(3)
    # NOTE(review): only warnings that pass the currently active filters
    # are recorded here -- confirm the test harness does not filter the
    # vmap warning before it reaches this context manager.
    with warnings.catch_warnings(record=True) as wa:
        api(foo, x, vectorize=True)
    self.assertEqual(len(wa), 0)

@base_and_logging_tensor
def test_jacobian_vectorize_raises_no_warnings(self, ctors):
    return self._test_vectorize_raises_no_warnings(autogradF.jacobian, ctors)

@base_and_logging_tensor
def test_hessian_vectorize_raises_no_warnings(self, ctors):
    return self._test_vectorize_raises_no_warnings(autogradF.hessian, ctors)

@parametrize("vectorize", [True, False])
@base_and_logging_tensor
def test_jacobian_err_check(self, vectorize, ctors):
    def foo(a):
        return 3 * a.narrow(0, 0, 3)

    def bar(a):
        # Second output is a str, which jacobian must reject.
        return 3 * a.narrow(0, 0, 3), "bar"

    inp = ctors.rand(4)

    # A non-Tensor element (the int 2) in the inputs tuple is a TypeError.
    with self.assertRaisesRegex(
        TypeError, "The inputs given to jacobian must be either a Tensor"
    ):
        autogradF.jacobian(foo, (inp, 2), vectorize=vectorize)

    # A non-Tensor element in the function outputs is a TypeError.
    with self.assertRaisesRegex(
        TypeError,
        "The outputs of the user-provided function given to jacobian must",
    ):
        autogradF.jacobian(bar, inp, vectorize=vectorize)

    # Valid single-input / single-output call.
    res = autogradF.jacobian(foo, inp, vectorize=vectorize)
    self._assert_interleaved_struct(res, foo(inp), inp)

    # Valid two-input / two-output call.
    def foo(a, b):
        return b, 3 * a.narrow(0, 0, 3)

    inp = (ctors.rand(4), ctors.rand(5))

    res = autogradF.jacobian(foo, inp, vectorize=vectorize)
    self._assert_interleaved_struct(res, foo(*inp), inp)

@base_and_logging_tensor
def test_jacobian_err_check_strict(self, ctors):
    def foo(a):
        return a.detach()

    def bar(a):
        # Make a non-leaf Tensor that requires_grad but that is not connected to the input
        return a.long().float().requires_grad_().clone()

    inp = ctors.rand(4)

    # Detached output: strict mode raises, non-strict yields all zeros.
    with self.assertRaisesRegex(
        RuntimeError,
        "Output 0 of the user-provided function does not require gradients.",
    ):
        autogradF.jacobian(foo, inp, strict=True)
    res = autogradF.jacobian(foo, inp, strict=False)
    self._assert_interleaved_struct(res, foo(inp), inp)
    self.assertEqual(res.abs().sum(), 0.0)

    # Output disconnected from the input: strict mode raises, non-strict
    # yields all zeros.
    with self.assertRaisesRegex(
        RuntimeError,
        "Output 0 of the user-provided function is independent of input 0.",
    ):
        autogradF.jacobian(bar, inp, strict=True)
    res = autogradF.jacobian(bar, inp, strict=False)
    # Check the structure against the function that was differentiated
    # (previously this compared against foo(inp) by copy-paste).
    self._assert_interleaved_struct(res, bar(inp), inp)
    self.assertEqual(res.abs().sum(), 0.0)

    # The Jacobian does not depend on the input
    def foo(a):
        return a.clone()

    inp.requires_grad_()
    with self.assertRaisesRegex(
        RuntimeError,
        "jacobian of the user-provided function is independent of input 0.",
    ):
        autogradF.jacobian(foo, inp, create_graph=True, strict=True)
    res = autogradF.jacobian(foo, inp, create_graph=True, strict=False)
    self._assert_interleaved_struct(res, inp, inp)
    self.assertEqual(res, torch.eye(4))

@base_and_logging_tensor
def test_jacobian_err_check_strict_vectorize(self, ctors):
    def foo(x):
        return x

    inp = ctors.rand(4)
    # strict=True combined with vectorize=True must be rejected.
    with self.assertRaisesRegex(RuntimeError, "not supported together"):
        autogradF.jacobian(foo, inp, strict=True, vectorize=True)
@base_and_logging_tensor
def test_jacobian_no_grad(self, ctors):
    def exp_reducer(x):
        return x.exp().sum(dim=1)

    inputs = ctors.rand(4, 4)
    with torch.no_grad():
        res = autogradF.jacobian(exp_reducer, inputs)
    self.assertIsNone(res.grad_fn)
    # The Jacobian of a (4,) output w.r.t. a (4, 4) input has shape
    # (4, 4, 4). Compare against zeros of that shape so the check is about
    # values; a (4, 4) reference can differ on shape alone.
    self.assertNotEqual(res, ctors.zeros(4, 4, 4))

    with torch.no_grad():
        res = autogradF.jacobian(exp_reducer, inputs, create_graph=True)
    self.assertIsNotNone(res.grad_fn)
    self.assertNotEqual(res, ctors.zeros(4, 4, 4))

@vectorized_logging_tensor
def test_jacobian_output(self, vectorize, ctors):
    def exp_reducer(x):
        return x.exp().sum(dim=1)

    # Single input, reduced output.
    inputs = ctors.rand(4, 4)
    res = autogradF.jacobian(exp_reducer, inputs, vectorize=vectorize)
    self._assert_interleaved_struct(res, exp_reducer(inputs), inputs)
    self.assertIsNone(res.grad_fn)

    def identity(x):
        return x.clone()

    # Identity map: the Jacobian is the identity matrix.
    inputs = ctors.rand(4)
    res = autogradF.jacobian(identity, inputs, vectorize=vectorize)
    self._assert_interleaved_struct(res, identity(inputs), inputs)
    self.assertIsNone(res.grad_fn)
    self.assertEqual(res, torch.eye(4))

    def add_exp_reducer(x, y):
        return (x + y.exp()).sum(dim=1)

    # Two inputs: the result is indexed per input.
    inputs = (ctors.rand(4, 4), ctors.rand(4, 4))
    res = autogradF.jacobian(add_exp_reducer, inputs, vectorize=vectorize)
    self._assert_interleaved_struct(res, add_exp_reducer(*inputs), inputs)
    self.assertIsNone(res[0].grad_fn)
    self.assertIsNone(res[1].grad_fn)

@vectorized_logging_tensor
def test_jacobian_scalar(self, vectorize, ctors):
    def reducer(x):
        return x.sum()

    # Zero-dim output: the Jacobian has the structure of the input.
    inputs = ctors.rand(4, 4)
    res = autogradF.jacobian(reducer, inputs, vectorize=vectorize)
    self._assert_same_struct(res, inputs)

    def expander(x):
        return x.unsqueeze(0).repeat(4)

    # Zero-dim input: the Jacobian has the structure of the output.
    inputs = ctors.rand([])
    res = autogradF.jacobian(expander, inputs, vectorize=vectorize)
    self._assert_same_struct(res, ctors.zeros(4))

@parametrize("vectorize", [True, False])
@base_and_logging_tensor
def test_jacobian_create_graph(self, vectorize, ctors):
    def exp_reducer(x):
        return x.exp().sum(dim=1)

    inputs = ctors.rand(4, 4, dtype=torch.double, requires_grad=True)
    res = autogradF.jacobian(
        exp_reducer, inputs, create_graph=True, vectorize=vectorize
    )
    self._assert_interleaved_struct(res, exp_reducer(inputs), inputs)
    self.assertIsNotNone(res.grad_fn)

    # The Jacobian itself must be differentiable (first and second order).
    gradcheck(
        lambda inp: autogradF.jacobian(
            exp_reducer, inp, create_graph=True, vectorize=vectorize
        ),
        inputs,
    )
    gradgradcheck(
        lambda inp: autogradF.jacobian(
            exp_reducer, inp, create_graph=True, vectorize=vectorize
        ),
        inputs,
    )

    def add_exp_reducer(x, y):
        return (x + y).exp().sum(dim=1)

    inputs = (
        ctors.rand(4, 4, dtype=torch.double, requires_grad=True),
        ctors.rand(4, 4, dtype=torch.double, requires_grad=True),
    )
    res = autogradF.jacobian(
        add_exp_reducer, inputs, create_graph=True, vectorize=vectorize
    )
    self._assert_interleaved_struct(res, add_exp_reducer(*inputs), inputs)
    self.assertIsNotNone(res[0].grad_fn)
    self.assertIsNotNone(res[1].grad_fn)

    gradcheck(
        lambda *inp: autogradF.jacobian(
            add_exp_reducer, inp, create_graph=True, vectorize=vectorize
        ),
        inputs,
    )
    gradgradcheck(
        lambda *inp: autogradF.jacobian(
            add_exp_reducer, inp, create_graph=True, vectorize=vectorize
        ),
        inputs,
    )

    def foo(x, y):
        x = x.cos()
        # With two inputs, jacobian returns one Jacobian per input; the
        # pair is unpacked here (these were previously named `val, jac`).
        jac_x, jac_y = autogradF.jacobian(
            add_exp_reducer, (x, y), create_graph=True, vectorize=vectorize
        )

        # Same terms and summation order as before.
        res = jac_x[0].exp().sum() + jac_x[1].exp().sum() + jac_y[0].exp().sum()
        res = res + jac_y[1].exp().sum() + x.exp().sum() + y.exp().sum()
        return res

    gradcheck(foo, inputs)
    gradgradcheck(foo, inputs)

def _check_jacobian_vectorize_correctness(self, f, inputs, test_forward_ad=True):
    # Reference result: non-vectorized jacobian. The vectorized result
    # (default strategy) must match it; when `test_forward_ad` is true the
    # vectorized "forward-mode" strategy must match it as well.
    expected = autogradF.jacobian(f, inputs, vectorize=False)
    result_backward_mode = autogradF.jacobian(f, inputs, vectorize=True)
    self.assertEqual(result_backward_mode, expected)

    if test_forward_ad:
        result_forward_mode = autogradF.jacobian(
            f, inputs, strategy="forward-mode", vectorize=True
        )
        self.assertEqual(result_forward_mode, expected)
Coastguard Worker 904*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 905*da0073e9SAndroid Build Coastguard Worker def test_jacobian_vectorize_correctness_simple(self, ctors): 906*da0073e9SAndroid Build Coastguard Worker def f(x): 907*da0073e9SAndroid Build Coastguard Worker return 3 * x**2 908*da0073e9SAndroid Build Coastguard Worker 909*da0073e9SAndroid Build Coastguard Worker x = ctors.randn(2, 3, 5) 910*da0073e9SAndroid Build Coastguard Worker self._check_jacobian_vectorize_correctness(f, x) 911*da0073e9SAndroid Build Coastguard Worker 912*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 913*da0073e9SAndroid Build Coastguard Worker def test_jacobian_vectorize_correctness_multi_input(self, ctors): 914*da0073e9SAndroid Build Coastguard Worker def f(x, y): 915*da0073e9SAndroid Build Coastguard Worker return (x.cos() * x) @ y.sin() 916*da0073e9SAndroid Build Coastguard Worker 917*da0073e9SAndroid Build Coastguard Worker x = ctors.randn(2, 3) 918*da0073e9SAndroid Build Coastguard Worker y = ctors.randn(3, 5) 919*da0073e9SAndroid Build Coastguard Worker self._check_jacobian_vectorize_correctness(f, (x, y)) 920*da0073e9SAndroid Build Coastguard Worker 921*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 922*da0073e9SAndroid Build Coastguard Worker def test_jacobian_vectorize_correctness_multi_input_multi_output(self, ctors): 923*da0073e9SAndroid Build Coastguard Worker def f(x, y): 924*da0073e9SAndroid Build Coastguard Worker return (x * x) @ y, x @ (x.sum(1) * y), y.sum() 925*da0073e9SAndroid Build Coastguard Worker 926*da0073e9SAndroid Build Coastguard Worker x = ctors.randn(5, 3) 927*da0073e9SAndroid Build Coastguard Worker y = ctors.randn(3, 5) 928*da0073e9SAndroid Build Coastguard Worker self._check_jacobian_vectorize_correctness(f, (x, y)) 929*da0073e9SAndroid Build Coastguard Worker 930*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 931*da0073e9SAndroid Build Coastguard Worker def 
test_jacobian_vectorize_correctness_unrelated_outputs(self, ctors): 932*da0073e9SAndroid Build Coastguard Worker def f(x, y): 933*da0073e9SAndroid Build Coastguard Worker return x, y, x, y 934*da0073e9SAndroid Build Coastguard Worker 935*da0073e9SAndroid Build Coastguard Worker x = ctors.randn(2) 936*da0073e9SAndroid Build Coastguard Worker y = ctors.randn(3) 937*da0073e9SAndroid Build Coastguard Worker self._check_jacobian_vectorize_correctness(f, (x, y)) 938*da0073e9SAndroid Build Coastguard Worker 939*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 940*da0073e9SAndroid Build Coastguard Worker def test_jacobian_vectorize_correctness_zero_dim(self, ctors): 941*da0073e9SAndroid Build Coastguard Worker # zero-dim output 942*da0073e9SAndroid Build Coastguard Worker def f(x, y): 943*da0073e9SAndroid Build Coastguard Worker return x.sum(), y.sum(), x * y 944*da0073e9SAndroid Build Coastguard Worker 945*da0073e9SAndroid Build Coastguard Worker x = ctors.randn(3) 946*da0073e9SAndroid Build Coastguard Worker y = ctors.randn(3) 947*da0073e9SAndroid Build Coastguard Worker self._check_jacobian_vectorize_correctness(f, (x, y)) 948*da0073e9SAndroid Build Coastguard Worker 949*da0073e9SAndroid Build Coastguard Worker # zero-dim input 950*da0073e9SAndroid Build Coastguard Worker def g(x): 951*da0073e9SAndroid Build Coastguard Worker return torch.stack([x, x, x]) 952*da0073e9SAndroid Build Coastguard Worker 953*da0073e9SAndroid Build Coastguard Worker x = ctors.randn([]) 954*da0073e9SAndroid Build Coastguard Worker self._check_jacobian_vectorize_correctness(g, x) 955*da0073e9SAndroid Build Coastguard Worker 956*da0073e9SAndroid Build Coastguard Worker # Mixed zero-dim input / zero-dim output 957*da0073e9SAndroid Build Coastguard Worker def h(x, y): 958*da0073e9SAndroid Build Coastguard Worker return y.sum(), x * y 959*da0073e9SAndroid Build Coastguard Worker 960*da0073e9SAndroid Build Coastguard Worker x = ctors.randn([]) 961*da0073e9SAndroid Build Coastguard 
Worker y = ctors.randn(1) 962*da0073e9SAndroid Build Coastguard Worker self._check_jacobian_vectorize_correctness(h, (x, y)) 963*da0073e9SAndroid Build Coastguard Worker 964*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "test requires CUDA") 965*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 966*da0073e9SAndroid Build Coastguard Worker def test_jacobian_vectorize_correctness_different_devices(self, ctors): 967*da0073e9SAndroid Build Coastguard Worker def f(x, y): 968*da0073e9SAndroid Build Coastguard Worker return x * y, (x * y).cuda() 969*da0073e9SAndroid Build Coastguard Worker 970*da0073e9SAndroid Build Coastguard Worker x = ctors.randn(3) 971*da0073e9SAndroid Build Coastguard Worker y = ctors.randn(3) 972*da0073e9SAndroid Build Coastguard Worker self._check_jacobian_vectorize_correctness(f, (x, y)) 973*da0073e9SAndroid Build Coastguard Worker 974*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 975*da0073e9SAndroid Build Coastguard Worker def test_jacobian_vectorize_correctness_different_dtype(self, ctors): 976*da0073e9SAndroid Build Coastguard Worker def f(x, y): 977*da0073e9SAndroid Build Coastguard Worker return (x * y).float(), (x * y).double() 978*da0073e9SAndroid Build Coastguard Worker 979*da0073e9SAndroid Build Coastguard Worker x = ctors.randn(3) 980*da0073e9SAndroid Build Coastguard Worker y = ctors.randn(3) 981*da0073e9SAndroid Build Coastguard Worker # The Jacobian computed using forward AD has the dtype of the output 982*da0073e9SAndroid Build Coastguard Worker # but the Jacobian computed with reverse AD has dtype of input 983*da0073e9SAndroid Build Coastguard Worker self._check_jacobian_vectorize_correctness(f, (x, y), test_forward_ad=False) 984*da0073e9SAndroid Build Coastguard Worker 985*da0073e9SAndroid Build Coastguard Worker def _check_hessian_vectorize_correctness(self, f, inputs): 986*da0073e9SAndroid Build Coastguard Worker expected = autogradF.hessian(f, inputs, 
vectorize=False) 987*da0073e9SAndroid Build Coastguard Worker result = autogradF.hessian(f, inputs, vectorize=True) 988*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expected) 989*da0073e9SAndroid Build Coastguard Worker 990*da0073e9SAndroid Build Coastguard Worker result_forward_mode = autogradF.hessian( 991*da0073e9SAndroid Build Coastguard Worker f, inputs, outer_jacobian_strategy="forward-mode", vectorize=True 992*da0073e9SAndroid Build Coastguard Worker ) 993*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result_forward_mode, expected) 994*da0073e9SAndroid Build Coastguard Worker 995*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 996*da0073e9SAndroid Build Coastguard Worker def test_hessian_vectorize_correctness_simple(self, ctors): 997*da0073e9SAndroid Build Coastguard Worker def f(x): 998*da0073e9SAndroid Build Coastguard Worker return (3 * x**2).sum() 999*da0073e9SAndroid Build Coastguard Worker 1000*da0073e9SAndroid Build Coastguard Worker x = ctors.randn(2, 3, 5) 1001*da0073e9SAndroid Build Coastguard Worker self._check_hessian_vectorize_correctness(f, x) 1002*da0073e9SAndroid Build Coastguard Worker 1003*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1004*da0073e9SAndroid Build Coastguard Worker def test_hessian_vectorize_correctness_multi_input(self, ctors): 1005*da0073e9SAndroid Build Coastguard Worker def f(x, y, z): 1006*da0073e9SAndroid Build Coastguard Worker return ((x.relu() * x) @ y.sin() @ z).sum() 1007*da0073e9SAndroid Build Coastguard Worker 1008*da0073e9SAndroid Build Coastguard Worker x = ctors.randn(2, 3) 1009*da0073e9SAndroid Build Coastguard Worker y = ctors.randn(3, 5) 1010*da0073e9SAndroid Build Coastguard Worker z = ctors.randn(5, 5) 1011*da0073e9SAndroid Build Coastguard Worker self._check_hessian_vectorize_correctness(f, (x, y, z)) 1012*da0073e9SAndroid Build Coastguard Worker 1013*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 
1014*da0073e9SAndroid Build Coastguard Worker def test_hessian_vectorize_correctness_unrelated_outputs(self, ctors): 1015*da0073e9SAndroid Build Coastguard Worker # output unrelated to one input 1016*da0073e9SAndroid Build Coastguard Worker def f(x, y): 1017*da0073e9SAndroid Build Coastguard Worker return (x**2).sum() 1018*da0073e9SAndroid Build Coastguard Worker 1019*da0073e9SAndroid Build Coastguard Worker x = ctors.randn(2) 1020*da0073e9SAndroid Build Coastguard Worker y = ctors.randn(3) 1021*da0073e9SAndroid Build Coastguard Worker self._check_hessian_vectorize_correctness(f, (x, y)) 1022*da0073e9SAndroid Build Coastguard Worker 1023*da0073e9SAndroid Build Coastguard Worker # output unrelated to all inputs 1024*da0073e9SAndroid Build Coastguard Worker def f(x, y): 1025*da0073e9SAndroid Build Coastguard Worker return ctors.ones([]) 1026*da0073e9SAndroid Build Coastguard Worker 1027*da0073e9SAndroid Build Coastguard Worker x = ctors.randn(2) 1028*da0073e9SAndroid Build Coastguard Worker y = ctors.randn(3) 1029*da0073e9SAndroid Build Coastguard Worker self._check_hessian_vectorize_correctness(f, (x, y)) 1030*da0073e9SAndroid Build Coastguard Worker 1031*da0073e9SAndroid Build Coastguard Worker @parametrize("vectorize", [True, False]) 1032*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1033*da0073e9SAndroid Build Coastguard Worker def test_hessian_err_check(self, vectorize, ctors): 1034*da0073e9SAndroid Build Coastguard Worker def foo(a): 1035*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3).exp().sum() 1036*da0073e9SAndroid Build Coastguard Worker 1037*da0073e9SAndroid Build Coastguard Worker def bar(a): 1038*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3), "bar" 1039*da0073e9SAndroid Build Coastguard Worker 1040*da0073e9SAndroid Build Coastguard Worker def bar2(a): 1041*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3) 1042*da0073e9SAndroid Build Coastguard Worker 
1043*da0073e9SAndroid Build Coastguard Worker def bar3(a): 1044*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3), 3 * a.narrow(0, 0, 3) 1045*da0073e9SAndroid Build Coastguard Worker 1046*da0073e9SAndroid Build Coastguard Worker inp = ctors.rand(4) 1047*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1048*da0073e9SAndroid Build Coastguard Worker TypeError, "The inputs given to hessian must be either a Tensor" 1049*da0073e9SAndroid Build Coastguard Worker ): 1050*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(foo, (inp, 2), vectorize=vectorize) 1051*da0073e9SAndroid Build Coastguard Worker 1052*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1053*da0073e9SAndroid Build Coastguard Worker TypeError, "The outputs of the user-provided function given to hessian must" 1054*da0073e9SAndroid Build Coastguard Worker ): 1055*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(bar, inp, vectorize=vectorize) 1056*da0073e9SAndroid Build Coastguard Worker 1057*da0073e9SAndroid Build Coastguard Worker err_msg_out = "The Tensor returned by the function given to hessian should contain a single element" 1058*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg_out): 1059*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(bar2, inp, vectorize=vectorize) 1060*da0073e9SAndroid Build Coastguard Worker 1061*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1062*da0073e9SAndroid Build Coastguard Worker RuntimeError, "The function given to hessian should return a single Tensor" 1063*da0073e9SAndroid Build Coastguard Worker ): 1064*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(bar3, inp, vectorize=vectorize) 1065*da0073e9SAndroid Build Coastguard Worker 1066*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(foo, inp, vectorize=vectorize) 1067*da0073e9SAndroid Build Coastguard Worker 
self._assert_interleaved_struct(res, inp, inp) 1068*da0073e9SAndroid Build Coastguard Worker 1069*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 1070*da0073e9SAndroid Build Coastguard Worker return (3 * b.narrow(0, 0, 3) * a.narrow(0, 0, 3)).sum() 1071*da0073e9SAndroid Build Coastguard Worker 1072*da0073e9SAndroid Build Coastguard Worker inp = (ctors.rand(4), ctors.rand(5)) 1073*da0073e9SAndroid Build Coastguard Worker 1074*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(foo, inp, vectorize=vectorize) 1075*da0073e9SAndroid Build Coastguard Worker self._assert_interleaved_struct(res, inp, inp) 1076*da0073e9SAndroid Build Coastguard Worker 1077*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1078*da0073e9SAndroid Build Coastguard Worker def test_hessian_err_check_strict(self, ctors): 1079*da0073e9SAndroid Build Coastguard Worker def foo(a): 1080*da0073e9SAndroid Build Coastguard Worker return a.detach().sum() 1081*da0073e9SAndroid Build Coastguard Worker 1082*da0073e9SAndroid Build Coastguard Worker def bar(a): 1083*da0073e9SAndroid Build Coastguard Worker # Make a non-leaf Tensor that requires_grad but that is not connected to the input 1084*da0073e9SAndroid Build Coastguard Worker return a.long().float().requires_grad_().clone().sum() 1085*da0073e9SAndroid Build Coastguard Worker 1086*da0073e9SAndroid Build Coastguard Worker def bar2(a): 1087*da0073e9SAndroid Build Coastguard Worker # A Linear function for which the jacobian is independent of the input 1088*da0073e9SAndroid Build Coastguard Worker return (3 * a).sum() 1089*da0073e9SAndroid Build Coastguard Worker 1090*da0073e9SAndroid Build Coastguard Worker inp = ctors.rand(4) 1091*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1092*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1093*da0073e9SAndroid Build Coastguard Worker "Output 0 of the user-provided function does not require gradients.", 1094*da0073e9SAndroid Build Coastguard Worker ): 
1095*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(foo, inp, strict=True) 1096*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(foo, inp, strict=False) 1097*da0073e9SAndroid Build Coastguard Worker self._assert_interleaved_struct(res, inp, inp) 1098*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.abs().sum(), 0.0) 1099*da0073e9SAndroid Build Coastguard Worker 1100*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1101*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1102*da0073e9SAndroid Build Coastguard Worker "jacobian of the user-provided function with respect to input 0", 1103*da0073e9SAndroid Build Coastguard Worker ): 1104*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(bar, inp, strict=True) 1105*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(bar, inp, strict=False) 1106*da0073e9SAndroid Build Coastguard Worker self._assert_interleaved_struct(res, inp, inp) 1107*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.abs().sum(), 0.0) 1108*da0073e9SAndroid Build Coastguard Worker 1109*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1110*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1111*da0073e9SAndroid Build Coastguard Worker "jacobian of the user-provided function with respect to input 0 is", 1112*da0073e9SAndroid Build Coastguard Worker ): 1113*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(bar2, inp, strict=True) 1114*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(bar2, inp, strict=False) 1115*da0073e9SAndroid Build Coastguard Worker self._assert_interleaved_struct(res, inp, inp) 1116*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.abs().sum(), 0.0) 1117*da0073e9SAndroid Build Coastguard Worker 1118*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1119*da0073e9SAndroid Build Coastguard Worker def test_hessian_err_check_strict_vectorize(self, 
ctors): 1120*da0073e9SAndroid Build Coastguard Worker def foo(x): 1121*da0073e9SAndroid Build Coastguard Worker return (x**3).sum() 1122*da0073e9SAndroid Build Coastguard Worker 1123*da0073e9SAndroid Build Coastguard Worker inp = ctors.rand(4) 1124*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "not supported together"): 1125*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(foo, inp, strict=True, vectorize=True) 1126*da0073e9SAndroid Build Coastguard Worker 1127*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1128*da0073e9SAndroid Build Coastguard Worker def test_hessian_no_grad(self, ctors): 1129*da0073e9SAndroid Build Coastguard Worker def pow_reducer(x): 1130*da0073e9SAndroid Build Coastguard Worker return x.pow(3).sum() 1131*da0073e9SAndroid Build Coastguard Worker 1132*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(2, 2) 1133*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 1134*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(pow_reducer, inputs) 1135*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[0][0].grad_fn) 1136*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[0][1].grad_fn) 1137*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[1][0].grad_fn) 1138*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[1][1].grad_fn) 1139*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(res, ctors.zeros(2, 2, 2)) 1140*da0073e9SAndroid Build Coastguard Worker 1141*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 1142*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(pow_reducer, inputs, create_graph=True) 1143*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[0][0].grad_fn) 1144*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[0][1].grad_fn) 1145*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[1][0].grad_fn) 
1146*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[1][1].grad_fn) 1147*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(res, ctors.zeros(2, 2, 2)) 1148*da0073e9SAndroid Build Coastguard Worker 1149*da0073e9SAndroid Build Coastguard Worker @vectorized_logging_tensor 1150*da0073e9SAndroid Build Coastguard Worker def test_hessian_output(self, vectorize, ctors): 1151*da0073e9SAndroid Build Coastguard Worker def pow_reducer(x): 1152*da0073e9SAndroid Build Coastguard Worker return x.pow(3).sum() 1153*da0073e9SAndroid Build Coastguard Worker 1154*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(2, 2) 1155*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(pow_reducer, inputs, vectorize=vectorize) 1156*da0073e9SAndroid Build Coastguard Worker self._assert_interleaved_struct(res, inputs, inputs) 1157*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res.grad_fn) 1158*da0073e9SAndroid Build Coastguard Worker 1159*da0073e9SAndroid Build Coastguard Worker def add_pow_reducer(x, y): 1160*da0073e9SAndroid Build Coastguard Worker return (x + y).pow(3).sum() 1161*da0073e9SAndroid Build Coastguard Worker 1162*da0073e9SAndroid Build Coastguard Worker inputs = (ctors.rand(2, 2), ctors.rand(2, 2)) 1163*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(add_pow_reducer, inputs, vectorize=vectorize) 1164*da0073e9SAndroid Build Coastguard Worker self._assert_interleaved_struct(res, inputs, inputs) 1165*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[0][0].grad_fn) 1166*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[0][1].grad_fn) 1167*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[1][0].grad_fn) 1168*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[1][1].grad_fn) 1169*da0073e9SAndroid Build Coastguard Worker 1170*da0073e9SAndroid Build Coastguard Worker @parametrize("vectorize", [True, False]) 1171*da0073e9SAndroid Build Coastguard Worker 
@base_and_logging_tensor 1172*da0073e9SAndroid Build Coastguard Worker def test_hessian_scalar(self, vectorize, ctors): 1173*da0073e9SAndroid Build Coastguard Worker def reducer(x): 1174*da0073e9SAndroid Build Coastguard Worker return x.sum() 1175*da0073e9SAndroid Build Coastguard Worker 1176*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4) 1177*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(reducer, inputs, vectorize=vectorize) 1178*da0073e9SAndroid Build Coastguard Worker self._assert_interleaved_struct(res, inputs, inputs) 1179*da0073e9SAndroid Build Coastguard Worker 1180*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand([]) 1181*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(reducer, inputs, vectorize=vectorize) 1182*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res, inputs) 1183*da0073e9SAndroid Build Coastguard Worker 1184*da0073e9SAndroid Build Coastguard Worker def bad_reducer(x): 1185*da0073e9SAndroid Build Coastguard Worker return x.sum().view(1, 1, 1) 1186*da0073e9SAndroid Build Coastguard Worker 1187*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4) 1188*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian(bad_reducer, inputs, vectorize=vectorize) 1189*da0073e9SAndroid Build Coastguard Worker self._assert_interleaved_struct(res, inputs, inputs) 1190*da0073e9SAndroid Build Coastguard Worker 1191*da0073e9SAndroid Build Coastguard Worker @parametrize("vectorize", [True, False]) 1192*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1193*da0073e9SAndroid Build Coastguard Worker def test_hessian_create_graph(self, vectorize, ctors): 1194*da0073e9SAndroid Build Coastguard Worker def pow_reducer(x): 1195*da0073e9SAndroid Build Coastguard Worker return x.pow(3).sum() 1196*da0073e9SAndroid Build Coastguard Worker 1197*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(2, 2, dtype=torch.double, requires_grad=True) 
1198*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian( 1199*da0073e9SAndroid Build Coastguard Worker pow_reducer, inputs, create_graph=True, vectorize=vectorize 1200*da0073e9SAndroid Build Coastguard Worker ) 1201*da0073e9SAndroid Build Coastguard Worker self._assert_interleaved_struct(res, inputs, inputs) 1202*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res.grad_fn) 1203*da0073e9SAndroid Build Coastguard Worker 1204*da0073e9SAndroid Build Coastguard Worker gradcheck( 1205*da0073e9SAndroid Build Coastguard Worker lambda inp: autogradF.hessian( 1206*da0073e9SAndroid Build Coastguard Worker pow_reducer, inp, create_graph=True, vectorize=vectorize 1207*da0073e9SAndroid Build Coastguard Worker ), 1208*da0073e9SAndroid Build Coastguard Worker inputs, 1209*da0073e9SAndroid Build Coastguard Worker ) 1210*da0073e9SAndroid Build Coastguard Worker gradgradcheck( 1211*da0073e9SAndroid Build Coastguard Worker lambda inp: autogradF.hessian( 1212*da0073e9SAndroid Build Coastguard Worker pow_reducer, inp, create_graph=True, vectorize=vectorize 1213*da0073e9SAndroid Build Coastguard Worker ), 1214*da0073e9SAndroid Build Coastguard Worker inputs, 1215*da0073e9SAndroid Build Coastguard Worker ) 1216*da0073e9SAndroid Build Coastguard Worker 1217*da0073e9SAndroid Build Coastguard Worker def add_pow_reducer(x, y): 1218*da0073e9SAndroid Build Coastguard Worker return (x + y).pow(3).sum() 1219*da0073e9SAndroid Build Coastguard Worker 1220*da0073e9SAndroid Build Coastguard Worker inputs = ( 1221*da0073e9SAndroid Build Coastguard Worker ctors.rand(2, 2, dtype=torch.double, requires_grad=True), 1222*da0073e9SAndroid Build Coastguard Worker ctors.rand(2, 2, dtype=torch.double, requires_grad=True), 1223*da0073e9SAndroid Build Coastguard Worker ) 1224*da0073e9SAndroid Build Coastguard Worker res = autogradF.hessian( 1225*da0073e9SAndroid Build Coastguard Worker add_pow_reducer, inputs, create_graph=True, vectorize=vectorize 1226*da0073e9SAndroid Build 
Coastguard Worker ) 1227*da0073e9SAndroid Build Coastguard Worker self._assert_interleaved_struct(res, inputs, inputs) 1228*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[0][0].grad_fn) 1229*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[0][1].grad_fn) 1230*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[1][0].grad_fn) 1231*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[1][1].grad_fn) 1232*da0073e9SAndroid Build Coastguard Worker 1233*da0073e9SAndroid Build Coastguard Worker def flatten(inp): 1234*da0073e9SAndroid Build Coastguard Worker return tuple(el_lvl2 for el_lvl1 in inp for el_lvl2 in el_lvl1) 1235*da0073e9SAndroid Build Coastguard Worker 1236*da0073e9SAndroid Build Coastguard Worker gradcheck( 1237*da0073e9SAndroid Build Coastguard Worker lambda *inp: flatten( 1238*da0073e9SAndroid Build Coastguard Worker autogradF.hessian( 1239*da0073e9SAndroid Build Coastguard Worker add_pow_reducer, inp, create_graph=True, vectorize=vectorize 1240*da0073e9SAndroid Build Coastguard Worker ) 1241*da0073e9SAndroid Build Coastguard Worker ), 1242*da0073e9SAndroid Build Coastguard Worker inputs, 1243*da0073e9SAndroid Build Coastguard Worker ) 1244*da0073e9SAndroid Build Coastguard Worker gradgradcheck( 1245*da0073e9SAndroid Build Coastguard Worker lambda *inp: flatten( 1246*da0073e9SAndroid Build Coastguard Worker autogradF.hessian( 1247*da0073e9SAndroid Build Coastguard Worker add_pow_reducer, inp, create_graph=True, vectorize=vectorize 1248*da0073e9SAndroid Build Coastguard Worker ) 1249*da0073e9SAndroid Build Coastguard Worker ), 1250*da0073e9SAndroid Build Coastguard Worker inputs, 1251*da0073e9SAndroid Build Coastguard Worker ) 1252*da0073e9SAndroid Build Coastguard Worker 1253*da0073e9SAndroid Build Coastguard Worker def foo(x, y): 1254*da0073e9SAndroid Build Coastguard Worker x = x.cos() 1255*da0073e9SAndroid Build Coastguard Worker val, hess = autogradF.hessian( 1256*da0073e9SAndroid Build 
Coastguard Worker add_pow_reducer, (x, y), create_graph=True, vectorize=vectorize 1257*da0073e9SAndroid Build Coastguard Worker ) 1258*da0073e9SAndroid Build Coastguard Worker 1259*da0073e9SAndroid Build Coastguard Worker res = val[0].cos().sum() + val[1].cos().sum() + hess[0].cos().sum() 1260*da0073e9SAndroid Build Coastguard Worker res = res + hess[1].cos().sum() + x.cos().sum() + y.cos().sum() 1261*da0073e9SAndroid Build Coastguard Worker return res 1262*da0073e9SAndroid Build Coastguard Worker 1263*da0073e9SAndroid Build Coastguard Worker gradcheck(foo, inputs) 1264*da0073e9SAndroid Build Coastguard Worker gradgradcheck(foo, inputs) 1265*da0073e9SAndroid Build Coastguard Worker 1266*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1267*da0073e9SAndroid Build Coastguard Worker def test_vhp_err_check(self, ctors): 1268*da0073e9SAndroid Build Coastguard Worker def foo(a): 1269*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3).exp().sum() 1270*da0073e9SAndroid Build Coastguard Worker 1271*da0073e9SAndroid Build Coastguard Worker def bar(a): 1272*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3), "bar" 1273*da0073e9SAndroid Build Coastguard Worker 1274*da0073e9SAndroid Build Coastguard Worker def bar2(a): 1275*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3) 1276*da0073e9SAndroid Build Coastguard Worker 1277*da0073e9SAndroid Build Coastguard Worker inp = ctors.rand(4) 1278*da0073e9SAndroid Build Coastguard Worker v = ctors.rand(4) 1279*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1280*da0073e9SAndroid Build Coastguard Worker TypeError, "The inputs given to vhp must be either a Tensor" 1281*da0073e9SAndroid Build Coastguard Worker ): 1282*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(foo, (inp, 2), v) 1283*da0073e9SAndroid Build Coastguard Worker 1284*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1285*da0073e9SAndroid Build 
Coastguard Worker TypeError, "The outputs of the user-provided function given to vhp must" 1286*da0073e9SAndroid Build Coastguard Worker ): 1287*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(bar, inp, v) 1288*da0073e9SAndroid Build Coastguard Worker 1289*da0073e9SAndroid Build Coastguard Worker err_msg_out = "The Tensor returned by the function given to vhp should contain a single element" 1290*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg_out): 1291*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(bar2, inp, v) 1292*da0073e9SAndroid Build Coastguard Worker 1293*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "v has invalid size:"): 1294*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(foo, inp, ctors.rand(5)) 1295*da0073e9SAndroid Build Coastguard Worker 1296*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1297*da0073e9SAndroid Build Coastguard Worker TypeError, 1298*da0073e9SAndroid Build Coastguard Worker "The v given to vhp must be either a Tensor or a tuple of Tensors", 1299*da0073e9SAndroid Build Coastguard Worker ): 1300*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(foo, inp, (v, 2)) 1301*da0073e9SAndroid Build Coastguard Worker 1302*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(foo, inp, v) 1303*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inp) 1304*da0073e9SAndroid Build Coastguard Worker 1305*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 1306*da0073e9SAndroid Build Coastguard Worker return (3 * b.narrow(0, 0, 3) * a.narrow(0, 0, 3)).sum() 1307*da0073e9SAndroid Build Coastguard Worker 1308*da0073e9SAndroid Build Coastguard Worker inp = (ctors.rand(4), ctors.rand(5)) 1309*da0073e9SAndroid Build Coastguard Worker v = (ctors.rand(4), ctors.rand(5)) 1310*da0073e9SAndroid Build Coastguard Worker 1311*da0073e9SAndroid Build Coastguard Worker res = 
autogradF.vhp(foo, inp, v) 1312*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inp) 1313*da0073e9SAndroid Build Coastguard Worker 1314*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1315*da0073e9SAndroid Build Coastguard Worker def test_vhp_err_check_strict(self, ctors): 1316*da0073e9SAndroid Build Coastguard Worker def foo(a): 1317*da0073e9SAndroid Build Coastguard Worker return a.detach().sum() 1318*da0073e9SAndroid Build Coastguard Worker 1319*da0073e9SAndroid Build Coastguard Worker def bar(a): 1320*da0073e9SAndroid Build Coastguard Worker # Make a non-leaf Tensor that requires_grad but that is not connected to the input 1321*da0073e9SAndroid Build Coastguard Worker return a.long().float().requires_grad_().clone().sum() 1322*da0073e9SAndroid Build Coastguard Worker 1323*da0073e9SAndroid Build Coastguard Worker def bar2(a): 1324*da0073e9SAndroid Build Coastguard Worker # A Linear function for which the jacobian is independent of the input 1325*da0073e9SAndroid Build Coastguard Worker return (3 * a).sum() 1326*da0073e9SAndroid Build Coastguard Worker 1327*da0073e9SAndroid Build Coastguard Worker inp = ctors.rand(4) 1328*da0073e9SAndroid Build Coastguard Worker v = ctors.rand(4) 1329*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1330*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1331*da0073e9SAndroid Build Coastguard Worker "Output 0 of the user-provided function does not require gradients.", 1332*da0073e9SAndroid Build Coastguard Worker ): 1333*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(foo, inp, v, strict=True) 1334*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(foo, inp, v, strict=False) 1335*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inp) 1336*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res[1].abs().sum(), 0.0) 1337*da0073e9SAndroid Build Coastguard Worker 1338*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex( 1339*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1340*da0073e9SAndroid Build Coastguard Worker "The output of the user-provided function is independent of input 0", 1341*da0073e9SAndroid Build Coastguard Worker ): 1342*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(bar, inp, v, strict=True) 1343*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(bar, inp, v, strict=False) 1344*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inp) 1345*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res[1].abs().sum(), 0.0) 1346*da0073e9SAndroid Build Coastguard Worker 1347*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1348*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1349*da0073e9SAndroid Build Coastguard Worker "jacobian of the user-provided function with respect to input 0 is", 1350*da0073e9SAndroid Build Coastguard Worker ): 1351*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(bar2, inp, v, strict=True) 1352*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(bar2, inp, v, strict=False) 1353*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inp) 1354*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res[1].abs().sum(), 0.0) 1355*da0073e9SAndroid Build Coastguard Worker 1356*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1357*da0073e9SAndroid Build Coastguard Worker def test_vhp_no_grad(self, ctors): 1358*da0073e9SAndroid Build Coastguard Worker def reducer(x): 1359*da0073e9SAndroid Build Coastguard Worker return x.exp().sum() 1360*da0073e9SAndroid Build Coastguard Worker 1361*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4) 1362*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(4, 4) 1363*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 1364*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(reducer, inputs, v) 1365*da0073e9SAndroid 
Build Coastguard Worker self.assertIsNone(res[0].grad_fn) 1366*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[1].grad_fn) 1367*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(res[1], ctors.zeros(4, 4)) 1368*da0073e9SAndroid Build Coastguard Worker 1369*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 1370*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(reducer, inputs, v, create_graph=True) 1371*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[0].grad_fn) 1372*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[1].grad_fn) 1373*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(res[1], ctors.zeros(4, 4)) 1374*da0073e9SAndroid Build Coastguard Worker 1375*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1376*da0073e9SAndroid Build Coastguard Worker def test_vhp_output(self, ctors): 1377*da0073e9SAndroid Build Coastguard Worker def foo(a): 1378*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3).exp().sum() 1379*da0073e9SAndroid Build Coastguard Worker 1380*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4) 1381*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(4, 4) 1382*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(foo, inputs, v) 1383*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inputs) 1384*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[0].grad_fn) 1385*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[1].grad_fn) 1386*da0073e9SAndroid Build Coastguard Worker 1387*da0073e9SAndroid Build Coastguard Worker def bar(a, b): 1388*da0073e9SAndroid Build Coastguard Worker return (a + 3 * b.narrow(0, 0, 3)).exp().sum() 1389*da0073e9SAndroid Build Coastguard Worker 1390*da0073e9SAndroid Build Coastguard Worker inputs = (ctors.rand(3), ctors.rand(4)) 1391*da0073e9SAndroid Build Coastguard Worker v = (ctors.ones(3), ctors.ones(4)) 1392*da0073e9SAndroid 
Build Coastguard Worker out, vhp_val = autogradF.vhp(bar, inputs, v) 1393*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(vhp_val, inputs) 1394*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(out.grad_fn) 1395*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(vhp_val[0].grad_fn) 1396*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(vhp_val[1].grad_fn) 1397*da0073e9SAndroid Build Coastguard Worker 1398*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1399*da0073e9SAndroid Build Coastguard Worker def test_vhp_scalar(self, ctors): 1400*da0073e9SAndroid Build Coastguard Worker def reducer(x): 1401*da0073e9SAndroid Build Coastguard Worker return x.sum() 1402*da0073e9SAndroid Build Coastguard Worker 1403*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4) 1404*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(4, 4) 1405*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(reducer, inputs, v) 1406*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inputs) 1407*da0073e9SAndroid Build Coastguard Worker 1408*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand([]) 1409*da0073e9SAndroid Build Coastguard Worker v = ctors.rand([]) 1410*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(reducer, inputs, v) 1411*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inputs) 1412*da0073e9SAndroid Build Coastguard Worker 1413*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(reducer, inputs) 1414*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inputs) 1415*da0073e9SAndroid Build Coastguard Worker 1416*da0073e9SAndroid Build Coastguard Worker def bad_reducer(x): 1417*da0073e9SAndroid Build Coastguard Worker return x.sum().view(1, 1, 1) 1418*da0073e9SAndroid Build Coastguard Worker 1419*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4) 1420*da0073e9SAndroid Build Coastguard 
Worker v = ctors.rand(4, 4) 1421*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(bad_reducer, inputs, v) 1422*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inputs) 1423*da0073e9SAndroid Build Coastguard Worker 1424*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1425*da0073e9SAndroid Build Coastguard Worker def test_vhp_create_graph(self, ctors): 1426*da0073e9SAndroid Build Coastguard Worker def foo(a): 1427*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3).exp().sum() 1428*da0073e9SAndroid Build Coastguard Worker 1429*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4, dtype=torch.double, requires_grad=True) 1430*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(4, 4, dtype=torch.double, requires_grad=True) 1431*da0073e9SAndroid Build Coastguard Worker res = autogradF.vhp(foo, inputs, v, create_graph=True) 1432*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inputs) 1433*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[0].grad_fn) 1434*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[1].grad_fn) 1435*da0073e9SAndroid Build Coastguard Worker 1436*da0073e9SAndroid Build Coastguard Worker gradcheck( 1437*da0073e9SAndroid Build Coastguard Worker lambda inp, v: autogradF.vhp(foo, inp, v, create_graph=True), (inputs, v) 1438*da0073e9SAndroid Build Coastguard Worker ) 1439*da0073e9SAndroid Build Coastguard Worker gradgradcheck( 1440*da0073e9SAndroid Build Coastguard Worker lambda inp, v: autogradF.vhp(foo, inp, v, create_graph=True), (inputs, v) 1441*da0073e9SAndroid Build Coastguard Worker ) 1442*da0073e9SAndroid Build Coastguard Worker 1443*da0073e9SAndroid Build Coastguard Worker def bar(a, b): 1444*da0073e9SAndroid Build Coastguard Worker return (a + 3 * b.narrow(0, 0, 3)).exp().sum() 1445*da0073e9SAndroid Build Coastguard Worker 1446*da0073e9SAndroid Build Coastguard Worker inputs = ( 1447*da0073e9SAndroid 
Build Coastguard Worker ctors.rand(3, dtype=torch.double, requires_grad=True), 1448*da0073e9SAndroid Build Coastguard Worker ctors.rand(4, dtype=torch.double, requires_grad=True), 1449*da0073e9SAndroid Build Coastguard Worker ) 1450*da0073e9SAndroid Build Coastguard Worker v = ( 1451*da0073e9SAndroid Build Coastguard Worker ctors.ones(3, dtype=torch.double, requires_grad=True), 1452*da0073e9SAndroid Build Coastguard Worker ctors.ones(4, dtype=torch.double, requires_grad=True), 1453*da0073e9SAndroid Build Coastguard Worker ) 1454*da0073e9SAndroid Build Coastguard Worker out, vhp_val = autogradF.vhp(bar, inputs, v, create_graph=True) 1455*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(vhp_val, inputs) 1456*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(out.grad_fn) 1457*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(vhp_val[0].grad_fn) 1458*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(vhp_val[1].grad_fn) 1459*da0073e9SAndroid Build Coastguard Worker 1460*da0073e9SAndroid Build Coastguard Worker gradcheck( 1461*da0073e9SAndroid Build Coastguard Worker lambda *args: autogradF.vhp(bar, args[:2], args[2:], create_graph=True)[1], 1462*da0073e9SAndroid Build Coastguard Worker inputs + v, 1463*da0073e9SAndroid Build Coastguard Worker ) 1464*da0073e9SAndroid Build Coastguard Worker gradgradcheck( 1465*da0073e9SAndroid Build Coastguard Worker lambda *args: autogradF.vhp(bar, args[:2], args[2:], create_graph=True)[1], 1466*da0073e9SAndroid Build Coastguard Worker inputs + v, 1467*da0073e9SAndroid Build Coastguard Worker ) 1468*da0073e9SAndroid Build Coastguard Worker 1469*da0073e9SAndroid Build Coastguard Worker def foo(*args): 1470*da0073e9SAndroid Build Coastguard Worker x, y = args[:2] 1471*da0073e9SAndroid Build Coastguard Worker v = args[2:] 1472*da0073e9SAndroid Build Coastguard Worker 1473*da0073e9SAndroid Build Coastguard Worker x = x.cos() 1474*da0073e9SAndroid Build Coastguard Worker val, grad = 
autogradF.vhp(bar, (x, y), v, create_graph=True) 1475*da0073e9SAndroid Build Coastguard Worker 1476*da0073e9SAndroid Build Coastguard Worker return ( 1477*da0073e9SAndroid Build Coastguard Worker val.cos() 1478*da0073e9SAndroid Build Coastguard Worker + grad[0].cos().sum() 1479*da0073e9SAndroid Build Coastguard Worker + grad[1].cos() 1480*da0073e9SAndroid Build Coastguard Worker + x.cos().sum() 1481*da0073e9SAndroid Build Coastguard Worker + y.cos() 1482*da0073e9SAndroid Build Coastguard Worker ) 1483*da0073e9SAndroid Build Coastguard Worker 1484*da0073e9SAndroid Build Coastguard Worker gradcheck(foo, inputs + v) 1485*da0073e9SAndroid Build Coastguard Worker gradgradcheck(foo, inputs + v) 1486*da0073e9SAndroid Build Coastguard Worker 1487*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1488*da0073e9SAndroid Build Coastguard Worker def test_hvp_err_check(self, ctors): 1489*da0073e9SAndroid Build Coastguard Worker def foo(a): 1490*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3).exp().sum() 1491*da0073e9SAndroid Build Coastguard Worker 1492*da0073e9SAndroid Build Coastguard Worker def bar(a): 1493*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3), "bar" 1494*da0073e9SAndroid Build Coastguard Worker 1495*da0073e9SAndroid Build Coastguard Worker def bar2(a): 1496*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3) 1497*da0073e9SAndroid Build Coastguard Worker 1498*da0073e9SAndroid Build Coastguard Worker inp = ctors.rand(4) 1499*da0073e9SAndroid Build Coastguard Worker v = ctors.rand(4) 1500*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(foo, inp, v) 1501*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1502*da0073e9SAndroid Build Coastguard Worker TypeError, "The inputs given to hvp must be either a Tensor" 1503*da0073e9SAndroid Build Coastguard Worker ): 1504*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(foo, (inp, 2), v) 
1505*da0073e9SAndroid Build Coastguard Worker 1506*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1507*da0073e9SAndroid Build Coastguard Worker TypeError, "The outputs of the user-provided function given to hvp must" 1508*da0073e9SAndroid Build Coastguard Worker ): 1509*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(bar, inp, v) 1510*da0073e9SAndroid Build Coastguard Worker 1511*da0073e9SAndroid Build Coastguard Worker err_msg_out = "The Tensor returned by the function given to hvp should contain a single element" 1512*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg_out): 1513*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(bar2, inp, v) 1514*da0073e9SAndroid Build Coastguard Worker 1515*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "v has invalid size:"): 1516*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(foo, inp, ctors.rand(5)) 1517*da0073e9SAndroid Build Coastguard Worker 1518*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1519*da0073e9SAndroid Build Coastguard Worker TypeError, 1520*da0073e9SAndroid Build Coastguard Worker "The v given to hvp must be either a Tensor or a tuple of Tensors", 1521*da0073e9SAndroid Build Coastguard Worker ): 1522*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(foo, inp, (v, 2)) 1523*da0073e9SAndroid Build Coastguard Worker 1524*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(foo, inp, v) 1525*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inp) 1526*da0073e9SAndroid Build Coastguard Worker 1527*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 1528*da0073e9SAndroid Build Coastguard Worker return (3 * b.narrow(0, 0, 3) * a.narrow(0, 0, 3)).sum() 1529*da0073e9SAndroid Build Coastguard Worker 1530*da0073e9SAndroid Build Coastguard Worker inp = (ctors.rand(4), ctors.rand(5)) 1531*da0073e9SAndroid Build Coastguard 
Worker v = (ctors.rand(4), ctors.rand(5)) 1532*da0073e9SAndroid Build Coastguard Worker 1533*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(foo, inp, v) 1534*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inp) 1535*da0073e9SAndroid Build Coastguard Worker 1536*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1537*da0073e9SAndroid Build Coastguard Worker def test_hvp_err_check_strict(self, ctors): 1538*da0073e9SAndroid Build Coastguard Worker def foo(a): 1539*da0073e9SAndroid Build Coastguard Worker return a.detach().sum() 1540*da0073e9SAndroid Build Coastguard Worker 1541*da0073e9SAndroid Build Coastguard Worker def bar(a): 1542*da0073e9SAndroid Build Coastguard Worker # Make a non-leaf Tensor that requires_grad but that is not connected to the input 1543*da0073e9SAndroid Build Coastguard Worker return a.long().float().requires_grad_().clone().sum() 1544*da0073e9SAndroid Build Coastguard Worker 1545*da0073e9SAndroid Build Coastguard Worker def bar2(a): 1546*da0073e9SAndroid Build Coastguard Worker # A Linear function for which the jacobian is independent of the input 1547*da0073e9SAndroid Build Coastguard Worker return (3 * a).sum() 1548*da0073e9SAndroid Build Coastguard Worker 1549*da0073e9SAndroid Build Coastguard Worker inp = ctors.rand(4) 1550*da0073e9SAndroid Build Coastguard Worker v = ctors.rand(4) 1551*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1552*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1553*da0073e9SAndroid Build Coastguard Worker "Output 0 of the user-provided function does not require gradients.", 1554*da0073e9SAndroid Build Coastguard Worker ): 1555*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(foo, inp, v, strict=True) 1556*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(foo, inp, v, strict=False) 1557*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inp) 1558*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(res[1].abs().sum(), 0.0) 1559*da0073e9SAndroid Build Coastguard Worker 1560*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1561*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1562*da0073e9SAndroid Build Coastguard Worker "The output of the user-provided function is independent of input 0", 1563*da0073e9SAndroid Build Coastguard Worker ): 1564*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(bar, inp, v, strict=True) 1565*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(bar, inp, v, strict=False) 1566*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inp) 1567*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res[1].abs().sum(), 0.0) 1568*da0073e9SAndroid Build Coastguard Worker 1569*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1570*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1571*da0073e9SAndroid Build Coastguard Worker "jacobian of the user-provided function with respect to input 0 is", 1572*da0073e9SAndroid Build Coastguard Worker ): 1573*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(bar2, inp, v, strict=True) 1574*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(bar2, inp, v, strict=False) 1575*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inp) 1576*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res[1].abs().sum(), 0.0) 1577*da0073e9SAndroid Build Coastguard Worker 1578*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1579*da0073e9SAndroid Build Coastguard Worker def test_hvp_no_grad(self, ctors): 1580*da0073e9SAndroid Build Coastguard Worker def reducer(x): 1581*da0073e9SAndroid Build Coastguard Worker return x.exp().sum() 1582*da0073e9SAndroid Build Coastguard Worker 1583*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4) 1584*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(4, 4) 1585*da0073e9SAndroid Build Coastguard 
Worker with torch.no_grad(): 1586*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(reducer, inputs, v) 1587*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[0].grad_fn) 1588*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[1].grad_fn) 1589*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(res[1], ctors.zeros(4, 4)) 1590*da0073e9SAndroid Build Coastguard Worker 1591*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 1592*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(reducer, inputs, v, create_graph=True) 1593*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[0].grad_fn) 1594*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[1].grad_fn) 1595*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(res[1], ctors.zeros(4, 4)) 1596*da0073e9SAndroid Build Coastguard Worker 1597*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1598*da0073e9SAndroid Build Coastguard Worker def test_hvp_output(self, ctors): 1599*da0073e9SAndroid Build Coastguard Worker def foo(a): 1600*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3).exp().sum() 1601*da0073e9SAndroid Build Coastguard Worker 1602*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4) 1603*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(4, 4) 1604*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(foo, inputs, v) 1605*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inputs) 1606*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[0].grad_fn) 1607*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(res[1].grad_fn) 1608*da0073e9SAndroid Build Coastguard Worker 1609*da0073e9SAndroid Build Coastguard Worker def bar(a, b): 1610*da0073e9SAndroid Build Coastguard Worker return (a + 3 * b.narrow(0, 0, 3)).exp().sum() 1611*da0073e9SAndroid Build Coastguard Worker 1612*da0073e9SAndroid Build Coastguard Worker 
inputs = (ctors.rand(3), ctors.rand(4)) 1613*da0073e9SAndroid Build Coastguard Worker v = (ctors.ones(3), ctors.ones(4)) 1614*da0073e9SAndroid Build Coastguard Worker out, hvp_val = autogradF.hvp(bar, inputs, v) 1615*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(hvp_val, inputs) 1616*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(out.grad_fn) 1617*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(hvp_val[0].grad_fn) 1618*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(hvp_val[1].grad_fn) 1619*da0073e9SAndroid Build Coastguard Worker 1620*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1621*da0073e9SAndroid Build Coastguard Worker def test_hvp_scalar(self, ctors): 1622*da0073e9SAndroid Build Coastguard Worker def reducer(x): 1623*da0073e9SAndroid Build Coastguard Worker return x.exp().sum() 1624*da0073e9SAndroid Build Coastguard Worker 1625*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4) 1626*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(4, 4) 1627*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(reducer, inputs, v) 1628*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inputs) 1629*da0073e9SAndroid Build Coastguard Worker 1630*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand([]) 1631*da0073e9SAndroid Build Coastguard Worker v = ctors.rand([]) 1632*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(reducer, inputs, v) 1633*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inputs) 1634*da0073e9SAndroid Build Coastguard Worker 1635*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(reducer, inputs) 1636*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inputs) 1637*da0073e9SAndroid Build Coastguard Worker 1638*da0073e9SAndroid Build Coastguard Worker def bad_reducer(x): 1639*da0073e9SAndroid Build Coastguard Worker return x.exp().sum().view(1, 1, 1) 
1640*da0073e9SAndroid Build Coastguard Worker 1641*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4) 1642*da0073e9SAndroid Build Coastguard Worker v = ctors.rand(4, 4) 1643*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(bad_reducer, inputs, v) 1644*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inputs) 1645*da0073e9SAndroid Build Coastguard Worker 1646*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1647*da0073e9SAndroid Build Coastguard Worker def test_hvp_create_graph(self, ctors): 1648*da0073e9SAndroid Build Coastguard Worker def foo(a): 1649*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3).exp().sum() 1650*da0073e9SAndroid Build Coastguard Worker 1651*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4, 4, dtype=torch.double, requires_grad=True) 1652*da0073e9SAndroid Build Coastguard Worker v = ctors.ones(4, 4, dtype=torch.double, requires_grad=True) 1653*da0073e9SAndroid Build Coastguard Worker res = autogradF.hvp(foo, inputs, v, create_graph=True) 1654*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(res[1], inputs) 1655*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[0].grad_fn) 1656*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(res[1].grad_fn) 1657*da0073e9SAndroid Build Coastguard Worker 1658*da0073e9SAndroid Build Coastguard Worker gradcheck( 1659*da0073e9SAndroid Build Coastguard Worker lambda inp, v: autogradF.hvp(foo, inp, v, create_graph=True), (inputs, v) 1660*da0073e9SAndroid Build Coastguard Worker ) 1661*da0073e9SAndroid Build Coastguard Worker gradgradcheck( 1662*da0073e9SAndroid Build Coastguard Worker lambda inp, v: autogradF.hvp(foo, inp, v, create_graph=True), (inputs, v) 1663*da0073e9SAndroid Build Coastguard Worker ) 1664*da0073e9SAndroid Build Coastguard Worker 1665*da0073e9SAndroid Build Coastguard Worker def bar(a, b): 1666*da0073e9SAndroid Build Coastguard Worker return (a + 3 * 
b.narrow(0, 0, 3)).exp().sum() 1667*da0073e9SAndroid Build Coastguard Worker 1668*da0073e9SAndroid Build Coastguard Worker inputs = ( 1669*da0073e9SAndroid Build Coastguard Worker ctors.rand(3, dtype=torch.double, requires_grad=True), 1670*da0073e9SAndroid Build Coastguard Worker ctors.rand(4, dtype=torch.double, requires_grad=True), 1671*da0073e9SAndroid Build Coastguard Worker ) 1672*da0073e9SAndroid Build Coastguard Worker v = ( 1673*da0073e9SAndroid Build Coastguard Worker ctors.ones(3, dtype=torch.double, requires_grad=True), 1674*da0073e9SAndroid Build Coastguard Worker ctors.ones(4, dtype=torch.double, requires_grad=True), 1675*da0073e9SAndroid Build Coastguard Worker ) 1676*da0073e9SAndroid Build Coastguard Worker out, hvp_val = autogradF.hvp(bar, inputs, v, create_graph=True) 1677*da0073e9SAndroid Build Coastguard Worker self._assert_same_struct(hvp_val, inputs) 1678*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(out.grad_fn) 1679*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(hvp_val[0].grad_fn) 1680*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(hvp_val[1].grad_fn) 1681*da0073e9SAndroid Build Coastguard Worker 1682*da0073e9SAndroid Build Coastguard Worker gradcheck( 1683*da0073e9SAndroid Build Coastguard Worker lambda *args: autogradF.hvp(bar, args[:2], args[2:], create_graph=True)[1], 1684*da0073e9SAndroid Build Coastguard Worker inputs + v, 1685*da0073e9SAndroid Build Coastguard Worker ) 1686*da0073e9SAndroid Build Coastguard Worker gradgradcheck( 1687*da0073e9SAndroid Build Coastguard Worker lambda *args: autogradF.hvp(bar, args[:2], args[2:], create_graph=True)[1], 1688*da0073e9SAndroid Build Coastguard Worker inputs + v, 1689*da0073e9SAndroid Build Coastguard Worker ) 1690*da0073e9SAndroid Build Coastguard Worker 1691*da0073e9SAndroid Build Coastguard Worker def foo(*args): 1692*da0073e9SAndroid Build Coastguard Worker x, y = args[:2] 1693*da0073e9SAndroid Build Coastguard Worker v = args[2:] 
1694*da0073e9SAndroid Build Coastguard Worker 1695*da0073e9SAndroid Build Coastguard Worker x = x.cos() 1696*da0073e9SAndroid Build Coastguard Worker val, grad = autogradF.hvp(bar, (x, y), v, create_graph=True) 1697*da0073e9SAndroid Build Coastguard Worker 1698*da0073e9SAndroid Build Coastguard Worker return ( 1699*da0073e9SAndroid Build Coastguard Worker val.cos() 1700*da0073e9SAndroid Build Coastguard Worker + grad[0].cos().sum() 1701*da0073e9SAndroid Build Coastguard Worker + grad[1].cos() 1702*da0073e9SAndroid Build Coastguard Worker + x.cos().sum() 1703*da0073e9SAndroid Build Coastguard Worker + y.cos() 1704*da0073e9SAndroid Build Coastguard Worker ) 1705*da0073e9SAndroid Build Coastguard Worker 1706*da0073e9SAndroid Build Coastguard Worker gradcheck(foo, inputs + v) 1707*da0073e9SAndroid Build Coastguard Worker gradgradcheck(foo, inputs + v) 1708*da0073e9SAndroid Build Coastguard Worker 1709*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1710*da0073e9SAndroid Build Coastguard Worker def test_jacobian_match_vjp_jvp(self, ctors): 1711*da0073e9SAndroid Build Coastguard Worker def foo(x): 1712*da0073e9SAndroid Build Coastguard Worker return x**3 + x.sum() 1713*da0073e9SAndroid Build Coastguard Worker 1714*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4) 1715*da0073e9SAndroid Build Coastguard Worker v = ctors.rand(4) 1716*da0073e9SAndroid Build Coastguard Worker 1717*da0073e9SAndroid Build Coastguard Worker jac = autogradF.jacobian(foo, inputs) 1718*da0073e9SAndroid Build Coastguard Worker jvp = autogradF.jvp(foo, inputs, v)[1] 1719*da0073e9SAndroid Build Coastguard Worker vjp = autogradF.vjp(foo, inputs, v)[1] 1720*da0073e9SAndroid Build Coastguard Worker 1721*da0073e9SAndroid Build Coastguard Worker self.assertEqual(jvp, torch.mm(jac, v.unsqueeze(1)).squeeze(1)) 1722*da0073e9SAndroid Build Coastguard Worker self.assertEqual(vjp, torch.mm(v.unsqueeze(0), jac).squeeze(0)) 1723*da0073e9SAndroid Build Coastguard Worker 
1724*da0073e9SAndroid Build Coastguard Worker @base_and_logging_tensor 1725*da0073e9SAndroid Build Coastguard Worker def test_hessian_match_vhp_hvp(self, ctors): 1726*da0073e9SAndroid Build Coastguard Worker def foo(a): 1727*da0073e9SAndroid Build Coastguard Worker return 3 * a.narrow(0, 0, 3).exp().sum() 1728*da0073e9SAndroid Build Coastguard Worker 1729*da0073e9SAndroid Build Coastguard Worker inputs = ctors.rand(4) 1730*da0073e9SAndroid Build Coastguard Worker v = ctors.rand(4) 1731*da0073e9SAndroid Build Coastguard Worker 1732*da0073e9SAndroid Build Coastguard Worker hes = autogradF.hessian(foo, inputs) 1733*da0073e9SAndroid Build Coastguard Worker hvp = autogradF.hvp(foo, inputs, v)[1] 1734*da0073e9SAndroid Build Coastguard Worker vhp = autogradF.vhp(foo, inputs, v)[1] 1735*da0073e9SAndroid Build Coastguard Worker 1736*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hvp, torch.mm(hes, v.unsqueeze(1)).squeeze(1)) 1737*da0073e9SAndroid Build Coastguard Worker self.assertEqual(vhp, torch.mm(v.unsqueeze(0), hes).squeeze(0)) 1738*da0073e9SAndroid Build Coastguard Worker 1739*da0073e9SAndroid Build Coastguard Worker 1740*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestAutogradFunctional) 1741*da0073e9SAndroid Build Coastguard Worker 1742*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 1743*da0073e9SAndroid Build Coastguard Worker run_tests() 1744