# Owner(s): ["oncall: package/deploy"]

from io import BytesIO

import torch
from torch.fx import Graph, GraphModule, symbolic_trace
from torch.package import (
    ObjMismatchError,
    PackageExporter,
    PackageImporter,
    sys_importer,
)
from torch.testing._internal.common_utils import run_tests


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase

torch.fx.wrap("len")
# Do it twice to make sure it doesn't affect anything
torch.fx.wrap("len")


class TestPackageFX(PackageTestCase):
    """Tests for compatibility with FX."""

    def test_package_fx_simple(self):
        """A symbolically traced module survives a save_pickle/load_pickle round trip."""

        class SimpleTest(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x + 3.0)

        st = SimpleTest()
        traced = symbolic_trace(st)

        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.save_pickle("model", "model.pkl", traced)

        f.seek(0)
        pi = PackageImporter(f)
        loaded_traced = pi.load_pickle("model", "model.pkl")
        sample = torch.rand(2, 3)
        self.assertEqual(loaded_traced(sample), traced(sample))

    def test_package_then_fx(self):
        """A module loaded from a package can be symbolically traced afterwards."""
        from package_a.test_module import SimpleTest

        model = SimpleTest()
        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", model)

        f.seek(0)
        pi = PackageImporter(f)
        loaded = pi.load_pickle("model", "model.pkl")
        traced = symbolic_trace(loaded)
        sample = torch.rand(2, 3)
        self.assertEqual(loaded(sample), traced(sample))

    def test_package_fx_package(self):
        """Package a module, load and trace it, then re-package the traced result.

        Re-exporting without the first importer is expected to raise
        ObjMismatchError; passing ``importer=(pi, sys_importer)`` is expected
        to succeed and produce a model with matching output.
        """
        from package_a.test_module import SimpleTest

        model = SimpleTest()
        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", model)

        f.seek(0)
        pi = PackageImporter(f)
        loaded = pi.load_pickle("model", "model.pkl")
        traced = symbolic_trace(loaded)

        # re-save the package exporter
        f2 = BytesIO()
        # This should fail, because we are referencing some globals that are
        # only in the package.
        with self.assertRaises(ObjMismatchError):
            with PackageExporter(f2) as pe:
                pe.intern("**")
                pe.save_pickle("model", "model.pkl", traced)

        f2.seek(0)
        with PackageExporter(f2, importer=(pi, sys_importer)) as pe:
            # Make the package available to the exporter's environment.
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", traced)
        f2.seek(0)
        pi2 = PackageImporter(f2)
        loaded2 = pi2.load_pickle("model", "model.pkl")

        sample = torch.rand(2, 3)
        self.assertEqual(loaded(sample), loaded2(sample))

    def test_package_fx_with_imports(self):
        """A GraphModule calling an interned leaf function round-trips through a package."""
        import package_a.subpackage

        # Manually construct a graph that invokes a leaf function
        graph = Graph()
        a = graph.placeholder("x")
        b = graph.placeholder("y")
        c = graph.call_function(package_a.subpackage.leaf_function, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)

        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", gm)
        f.seek(0)

        pi = PackageImporter(f)
        loaded_gm = pi.load_pickle("model", "model.pkl")
        input_x = torch.rand(2, 3)
        input_y = torch.rand(2, 3)

        self.assertTrue(
            torch.allclose(loaded_gm(input_x, input_y), gm(input_x, input_y))
        )

        # Check that the packaged version of the leaf_function dependency is
        # not the same as in the outer env.
        packaged_dependency = pi.import_module("package_a.subpackage")
        self.assertIsNot(packaged_dependency, package_a.subpackage)

    def test_package_fx_custom_tracer(self):
        """A GraphModule subclass built from a custom tracer round-trips through a package.

        Checks that ``_tracer_cls`` is recorded on the graph and the module
        before export, and that the extra ``info`` attribute and the forward
        output match after loading.
        """
        from package_a.test_all_leaf_modules_tracer import TestAllLeafModulesTracer
        from package_a.test_module import ModWithTwoSubmodsAndTensor, SimpleTest

        class SpecialGraphModule(torch.fx.GraphModule):
            def __init__(self, root, graph, info):
                super().__init__(root, graph)
                self.info = info

        sub_module = SimpleTest()
        module = ModWithTwoSubmodsAndTensor(
            torch.ones(3),
            sub_module,
            sub_module,
        )
        tracer = TestAllLeafModulesTracer()
        graph = tracer.trace(module)

        self.assertEqual(graph._tracer_cls, TestAllLeafModulesTracer)

        gm = SpecialGraphModule(module, graph, "secret")
        self.assertEqual(gm._tracer_cls, TestAllLeafModulesTracer)

        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", gm)
        f.seek(0)

        pi = PackageImporter(f)
        loaded_gm = pi.load_pickle("model", "model.pkl")
        # NOTE(review): both operands take `.__class__` of a *class*, i.e. its
        # metaclass, so this compares metaclass names (presumably "type" on
        # both sides) rather than the GraphModule class names. Left unchanged;
        # confirm the intended check before tightening it.
        self.assertEqual(
            type(loaded_gm).__class__.__name__, SpecialGraphModule.__class__.__name__
        )
        self.assertEqual(loaded_gm.info, "secret")

        input_x = torch.randn(3)
        self.assertEqual(loaded_gm(input_x), gm(input_x))

    def test_package_fx_wrap(self):
        """A traced module that calls the fx-wrapped ``len`` round-trips through a package."""

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a):
                return len(a)

        traced = torch.fx.symbolic_trace(TestModule())

        f = BytesIO()
        with torch.package.PackageExporter(f) as pe:
            pe.save_pickle("model", "model.pkl", traced)
        f.seek(0)

        pi = PackageImporter(f)
        loaded_traced = pi.load_pickle("model", "model.pkl")
        sample = torch.rand(2, 3)
        self.assertEqual(loaded_traced(sample), traced(sample))


if __name__ == "__main__":
    run_tests()