1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5 6import torch 7from torch import nn 8 9 10# Make the helper files in test/ importable 11pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 12sys.path.append(pytorch_test_dir) 13from torch.testing._internal.jit_utils import JitTestCase 14 15 16if __name__ == "__main__": 17 raise RuntimeError( 18 "This test file is not meant to be run directly, use:\n\n" 19 "\tpython test/test_jit.py TESTNAME\n\n" 20 "instead." 21 ) 22 23 24class Sequence(nn.Module): 25 def __init__(self) -> None: 26 super().__init__() 27 self.lstm1 = nn.LSTMCell(1, 51) 28 self.lstm2 = nn.LSTMCell(51, 51) 29 self.linear = nn.Linear(51, 1) 30 31 def forward(self, input): 32 outputs = [] 33 h_t = torch.zeros(input.size(0), 51) 34 c_t = torch.zeros(input.size(0), 51) 35 h_t2 = torch.zeros(input.size(0), 51) 36 c_t2 = torch.zeros(input.size(0), 51) 37 38 for input_t in input.split(1, dim=1): 39 h_t, c_t = self.lstm1(input_t, (h_t, c_t)) 40 h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2)) 41 output = self.linear(h_t2) 42 outputs += [output] 43 outputs = torch.cat(outputs, dim=1) 44 return outputs 45 46 47class TestScriptProfile(JitTestCase): 48 def test_basic(self): 49 seq = torch.jit.script(Sequence()) 50 p = torch.jit._ScriptProfile() 51 p.enable() 52 seq(torch.rand((10, 100))) 53 p.disable() 54 self.assertNotEqual(p.dump_string(), "") 55 56 def test_script(self): 57 seq = Sequence() 58 59 p = torch.jit._ScriptProfile() 60 p.enable() 61 62 @torch.jit.script 63 def fn(): 64 _ = seq(torch.rand((10, 100))) 65 66 fn() 67 p.disable() 68 69 self.assertNotEqual(p.dump_string(), "") 70 71 def test_multi(self): 72 seq = torch.jit.script(Sequence()) 73 profiles = [torch.jit._ScriptProfile() for _ in range(5)] 74 for p in profiles: 75 p.enable() 76 77 last = None 78 while len(profiles) > 0: 79 seq(torch.rand((10, 10))) 80 p = profiles.pop() 81 p.disable() 82 stats = p.dump_string() 83 self.assertNotEqual(stats, "") 84 if last: 85 self.assertNotEqual(stats, last) 86 last = stats 87 88 def test_section(self): 89 seq = Sequence() 90 91 @torch.jit.script 92 def fn(max: int): 93 _ = seq(torch.rand((10, max))) 94 95 p = torch.jit._ScriptProfile() 96 p.enable() 97 fn(100) 98 p.disable() 99 s0 = p.dump_string() 100 101 fn(10) 102 p.disable() 103 s1 = p.dump_string() 104 105 p.enable() 106 fn(10) 107 p.disable() 108 s2 = p.dump_string() 109 110 self.assertEqual(s0, s1) 111 self.assertNotEqual(s1, s2) 112 113 def test_empty(self): 114 p = torch.jit._ScriptProfile() 115 p.enable() 116 p.disable() 117 self.assertEqual(p.dump_string(), "") 118