# Owner(s): ["oncall: jit"]

import torch
from torch.testing._internal.jit_utils import JitTestCase


class TestFuserCommon(JitTestCase):
    """JIT fuser tests; currently covers the autodiff fallback path."""

    def test_autodiff_fallback(self):
        """A scripted function must honor ``requires_grad`` on new inputs.

        For each setting ``rq``, a freshly scripted ``fn`` is warmed up on an
        input whose ``requires_grad`` is ``not rq``. It is then called on an
        input with ``requires_grad=rq``. The output's ``requires_grad`` must
        equal ``rq``: the optimization created during warm-up does not apply
        to that input, so the fallback must produce the correct result.
        """
        for rq in [True, False]:
            # Define and script fn inside the loop so each case gets its own
            # scripted function object (presumably with its own executor
            # state) rather than reusing the previous iteration's.

            @torch.jit.script
            def fn(x):
                return torch.max(x**2.0, x**3.0)

            x = torch.randn(5, requires_grad=not rq)
            # Warm up: repeated calls cause the optimization to be created
            # for inputs with requires_grad == (not rq).
            for _ in range(5):
                fn(x)
            # Test the fallback when the optimization is not applicable: this
            # input's requires_grad differs from the warm-up input's.
            y = fn(torch.randn(5, requires_grad=rq))
            self.assertEqual(y.requires_grad, rq)