# Owner(s): ["oncall: jit"]

import os
import sys
import unittest

import torch


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

# NOTE: FIXING FAILING TESTS
# If you are seeing a test failure from this file, congrats, you improved
# parity between JIT and Python API. Before you fix the test, you must also update
# the corresponding section in documentation that states the unsupported behavior.
# see: `jit_unsupported.rst`


class TestUnsupportedOps(JitTestCase):
    """Pin down Python APIs that run in eager mode but fail under ``torch.jit.script``."""

    def test_factory_ops_requires_grad_fail(self):
        """Scripting factory ops that pass ``requires_grad=`` must fail.

        "Keyword argument {name} unknown" is a JIT-only error message, so these
        functions succeed in eager and fail in JIT.

        The complete issue and set of ops is
        https://github.com/pytorch/pytorch/issues/30761. Only some are tested
        here because they should be fixed all at once.
        """

        def ones():
            return torch.ones([2], requires_grad=True)

        def randn():
            return torch.randn([2], requires_grad=True)

        def zeros():
            return torch.zeros([2], requires_grad=True)

        # Each entry: (function to script, source text the error should highlight).
        cases = (
            (ones, "torch.ones"),
            (randn, "torch.randn"),
            (zeros, "torch.zeros"),
        )
        for fn, highlight in cases:
            # subTest labels a failure with the op that unexpectedly scripted.
            with self.subTest(op=highlight):
                with self.assertRaisesRegexWithHighlight(
                    Exception, "Keyword argument requires_grad unknown", highlight
                ):
                    torch.jit.script(fn)

    @unittest.skipIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")
    def test_init_ops(self):
        """Each ``torch.nn.init`` helper below must run eagerly but fail to script."""

        def calculate_gain():
            return torch.nn.init.calculate_gain("leaky_relu", 0.2)

        def eye_():
            return torch.nn.init.eye_(torch.zeros([2, 2]))

        def dirac_():
            return torch.nn.init.dirac_(torch.empty(3, 16, 5, 5))

        # Named after the op it actually calls, so readers don't assume that
        # kaiming_uniform_ is covered here.
        def kaiming_normal_():
            return torch.nn.init.kaiming_normal_(torch.empty(3, 5))

        def orthogonal_():
            return torch.nn.init.orthogonal_(torch.empty(3, 5))

        def sparse():
            return torch.nn.init.sparse_(torch.empty(3, 5), sparsity=0.1)

        for func in (
            calculate_gain,
            eye_,
            dirac_,
            kaiming_normal_,
            orthogonal_,
            sparse,
        ):
            # subTest labels a failure with the helper that unexpectedly scripted.
            with self.subTest(func=func.__name__):
                # doesn't error in eager
                func()
                # Any exception counts: only the fact that scripting fails is
                # pinned here, not the error message.
                with self.assertRaises(Exception):
                    torch.jit.script(func)