1# Owner(s): ["oncall: mobile"] 2 3import io 4import tempfile 5import unittest 6 7import torch 8import torch.utils.show_pickle 9from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase 10 11 12class TestShowPickle(TestCase): 13 @unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows") 14 def test_scripted_model(self): 15 class MyCoolModule(torch.nn.Module): 16 def __init__(self, weight): 17 super().__init__() 18 self.weight = weight 19 20 def forward(self, x): 21 return x * self.weight 22 23 m = torch.jit.script(MyCoolModule(torch.tensor([2.0]))) 24 25 with tempfile.NamedTemporaryFile() as tmp: 26 torch.jit.save(m, tmp) 27 tmp.flush() 28 buf = io.StringIO() 29 torch.utils.show_pickle.main( 30 ["", tmp.name + "@*/data.pkl"], output_stream=buf 31 ) 32 output = buf.getvalue() 33 self.assertRegex(output, "MyCoolModule") 34 self.assertRegex(output, "weight") 35 36 37if __name__ == "__main__": 38 run_tests() 39