1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: mobile"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport io 4*da0073e9SAndroid Build Coastguard Workerimport tempfile 5*da0073e9SAndroid Build Coastguard Workerimport unittest 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport torch 8*da0073e9SAndroid Build Coastguard Workerimport torch.utils.show_pickle 9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Workerclass TestShowPickle(TestCase): 13*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows") 14*da0073e9SAndroid Build Coastguard Worker def test_scripted_model(self): 15*da0073e9SAndroid Build Coastguard Worker class MyCoolModule(torch.nn.Module): 16*da0073e9SAndroid Build Coastguard Worker def __init__(self, weight): 17*da0073e9SAndroid Build Coastguard Worker super().__init__() 18*da0073e9SAndroid Build Coastguard Worker self.weight = weight 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 21*da0073e9SAndroid Build Coastguard Worker return x * self.weight 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(MyCoolModule(torch.tensor([2.0]))) 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker with tempfile.NamedTemporaryFile() as tmp: 26*da0073e9SAndroid Build Coastguard Worker torch.jit.save(m, tmp) 27*da0073e9SAndroid Build Coastguard Worker tmp.flush() 28*da0073e9SAndroid Build Coastguard Worker buf = io.StringIO() 29*da0073e9SAndroid Build Coastguard Worker torch.utils.show_pickle.main( 30*da0073e9SAndroid Build Coastguard Worker ["", tmp.name + "@*/data.pkl"], output_stream=buf 31*da0073e9SAndroid Build Coastguard Worker ) 32*da0073e9SAndroid Build Coastguard Worker output = buf.getvalue() 33*da0073e9SAndroid Build Coastguard Worker self.assertRegex(output, "MyCoolModule") 34*da0073e9SAndroid Build Coastguard Worker self.assertRegex(output, "weight") 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 38*da0073e9SAndroid Build Coastguard Worker run_tests() 39