# Owner(s): ["oncall: jit"]

import io
import os
import sys
import unittest

import torch
import torch._C
from torch.jit.mobile import _load_for_lite_interpreter
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    find_library_location,
    IS_FBCODE,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    skipIfRocm,
    TEST_WITH_ROCM,
)
from torch.testing._internal.jit_utils import JitTestCase


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


def to_test_backend(module, method_compile_spec):
    """
    Lower only the ``forward`` method of ``module`` to the "test_backend" backend.

    Args:
        module: the ScriptModule to lower.
        method_compile_spec: compile spec used for ``forward``.

    Returns:
        The result of ``torch._C._jit_to_backend``.
    """
    return torch._C._jit_to_backend(
        "test_backend", module, {"forward": method_compile_spec}
    )


def to_test_backend_multi(module, method_compile_spec):
    """
    Lower several methods of ``module`` to the "test_backend" backend.

    Args:
        module: the ScriptModule to lower.
        method_compile_spec: dict mapping method name -> compile spec; every
            method named here is lowered.
    """
    return torch._C._jit_to_backend("test_backend", module, method_compile_spec)


def to_test_backend_selective(module, method_compile_spec, submodules):
    """
    Selectively lower the submodules of ``module`` named in ``submodules``.

    Each selected submodule has its ``forward`` lowered to "test_backend" with
    ``method_compile_spec`` (via ``to_test_backend``).

    Args:
        module: the root ScriptModule.
        method_compile_spec: compile spec for each lowered ``forward``.
        submodules: list of dotted submodule paths, e.g. ``["sub1.submodule"]``.
    """

    def _to_test_backend(module):
        return to_test_backend(module, method_compile_spec)

    return torch._C._jit_to_backend_selective(module, _to_test_backend, submodules)


class BasicModule(torch.nn.Module):
    """
    A simple Module used to test to_backend lowering machinery.
    """

    def forward(self, x, h):
        # Returns both the sum and the difference so two methods are exercised.
        return self.accum(x, h), self.sub_accum(x, h)

    def accum(self, x, h):
        return x + h

    def sub_accum(self, x, h):
        return x - h


# A class-level skipIf is only honored when unittest runs these cases itself.
# TestBackends (below) instantiates them and calls setUp()/test_*() directly,
# which bypasses this decorator, so TestBackends repeats the same condition.
@unittest.skipIf(
    TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
    "Non-portable load_library call used in test",
)
class JitBackendTestCase(JitTestCase):
    """
    A common base class for JIT backend tests that contains common utility
    functions for output comparison and serialization/deserialization.
    """

    def setUp(self):
        super().setUp()
        # The test backend is implemented in a separately built shared library.
        lib_file_path = find_library_location("libjitbackend_test.so")
        torch.ops.load_library(str(lib_file_path))
        # Subclasses are expected to set up three variables in their setUp methods:
        # module - a regular, Python version of the module being tested
        # scripted_module - a scripted version of module
        # lowered_module - a version of module lowered to a backend

    def check_function(self, function_name, input):
        """
        Check that the function named 'function_name' produces the same output using
        Python, regular JIT and the backend for the given 'input'.

        Args:
            function_name: name of the method looked up on all three modules.
            input: tuple of positional arguments passed to each method.
        """
        # Get handles for Python, JIT and backend methods.
        python_method = getattr(self.module, function_name)
        jit_method = getattr(self.scripted_module, function_name)
        backend_method = getattr(self.lowered_module, function_name)

        # Run methods.
        python_output = python_method(*input)
        jit_output = jit_method(*input)
        backend_output = backend_method(*input)

        # The answers returned by Python, JIT and to_backend should all match.
        self.assertEqual(python_output, backend_output)
        self.assertEqual(jit_output, backend_output)

    def save_load(self):
        """
        Save and load the lowered module, replacing self.lowered_module with the
        reloaded copy.
        """
        self.lowered_module = self.getExportImportCopy(self.lowered_module)

    def test_execution(self):
        """
        Stub for correctness tests.
        """

    def test_save_load(self):
        """
        Stub for serialization tests.
        """

    def test_errors(self):
        """
        Stub for testing error checking.
        """


class BasicModuleTest(JitBackendTestCase):
    """
    Tests for BasicModule.
    """

    def setUp(self):
        super().setUp()
        # Create Python, JIT and backend versions of BasicModule.
        self.module = BasicModule()
        self.scripted_module = torch.jit.script(BasicModule())
        # All three methods are lowered, so each can be called on the backend.
        self.lowered_module = to_test_backend_multi(
            self.scripted_module,
            {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
        )

    def test_execution(self):
        # Test execution with backend against Python and JIT.
        x = torch.randn(5)

        # Test all three module methods.
        self.check_function("accum", (x, x))
        self.check_function("sub_accum", (x, x))
        self.check_function("forward", (x, x))

    @skipIfRocm
    def test_save_load(self):
        # Lowered module should produce the same outputs.
        self.test_execution()

        # Save the compile spec to compare against the version retrieved after loading.
        # getattr with a string is used on purpose: writing `.__method_compile_spec`
        # inside a class body would be name-mangled to `_BasicModuleTest__method_compile_spec`.
        pre_compile_spec = getattr(
            getattr(self.lowered_module, "__loweredModule__"), "__method_compile_spec"
        )

        # Save and load the lowered module.
        self.save_load()

        # Get the compile spec after loading.
        post_compile_spec = getattr(
            getattr(self.lowered_module, "__loweredModule__"), "__method_compile_spec"
        )

        # Compile specs should match.
        self.assertEqual(pre_compile_spec, post_compile_spec)

        # Loaded module should produce the same outputs.
        self.test_execution()


class BasicModuleUnavailableTest(JitBackendTestCase):
    """
    Tests for BasicModule with a backend that is not available.
    Fundamentally:
      * _jit_to_backend is successful.
      * Execution fails with an exception.
      * Saving is successful.
      * Loading fails with an exception.
    """

    def setUp(self):
        super().setUp()
        # Create Python, JIT and backend versions of BasicModule.
        self.module = BasicModule()
        self.scripted_module = torch.jit.script(BasicModule())
        self.lowered_module = torch._C._jit_to_backend(
            "test_backend_unavailable",
            self.scripted_module,
            {"forward": {"": ""}},
        )

    def test_execution(self):
        # Test execution with backend fails because the backend that is not available.
        x = torch.randn(5)

        # Test exception is thrown.
        with self.assertRaisesRegexWithHighlight(
            Exception,
            r"Backend is not available.",
            'raise Exception("Backend is not available."',
        ):
            backend_method = getattr(self.lowered_module, "forward")
            backend_method(x, x)

    @skipIfRocm
    def test_save_load(self):
        # Test that saving the lowered module is OK but loading fails because the backend is not available.
        buffer = io.BytesIO()
        torch.jit.save(self.lowered_module, buffer)
        buffer.seek(0)
        with self.assertRaisesRegexWithHighlight(
            Exception,
            r"Backend is not available.",
            'raise Exception("Backend is not available."',
        ):
            torch.jit.load(buffer)


class NestedModuleTest(JitBackendTestCase):
    """
    Tests for NestedModule that check that a module lowered to a backend can be used
    as a submodule.
    """

    class NestedModule(torch.nn.Module):
        """
        A Module with one submodule that is used to test that lowered Modules
        can be used as submodules.
        """

        def __init__(self, submodule):
            super().__init__()
            self.submodule = submodule

        def forward(self, x, h):
            return self.submodule.forward(x, h)

    def setUp(self):
        super().setUp()
        # Create Python, JIT and backend versions of NestedModule.
        # Both modules in self.module are regular Python modules.
        self.module = NestedModuleTest.NestedModule(BasicModule())
        # Both modules in self.scripted_module are ScriptModules.
        self.scripted_module = torch.jit.script(
            NestedModuleTest.NestedModule(BasicModule())
        )

        # Script a separate BasicModule instance and lower all three of its
        # methods to the test backend; this instance is not shared with
        # self.scripted_module.
        lowered_module = to_test_backend_multi(
            torch.jit.script(BasicModule()),
            {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
        )
        # self.lowered_module is a ScriptModule, but its submodule is a lowered module.
        self.lowered_module = torch.jit.script(
            NestedModuleTest.NestedModule(lowered_module)
        )

    def test_execution(self):
        # Test execution with backend against Python and JIT.
        x = torch.randn(5)

        # Test forward.
        self.check_function("forward", (x, x))

    def test_save_load(self):
        # Lowered module should produce the same outputs.
        self.test_execution()

        # Save and load the lowered module.
        self.save_load()

        # Loaded module should produce the same outputs.
        self.test_execution()


class SelectiveLoweringTest(JitBackendTestCase):
    """
    Tests for the selective lowering API.
    """

    class OuterModule(torch.nn.Module):
        def __init__(self, sub1, sub2, other):
            super().__init__()
            self.sub1 = sub1
            self.sub2 = sub2
            self.other = other

        def forward(self, x, y):
            # Call the module that will be lowered directly to test
            # type remapping in modules that are not its parent.
            a, b = self.sub1.submodule.forward(x, y)
            c, d = self.sub2.forward(x, y)
            e, f = self.other.forward(x, y)
            return a + c + e, b + d + f

    class MiddleModule(torch.nn.Module):
        def __init__(self, submodule):
            super().__init__()
            self.submodule = submodule

        def forward(self, x, y):
            return self.submodule.forward(x, y)

    def setUp(self):
        super().setUp()
        OuterModule = SelectiveLoweringTest.OuterModule
        MiddleModule = SelectiveLoweringTest.MiddleModule

        def script_without_type_sharing(mod):
            # Selective lowering requires a hierarchy with unique types
            # (see test_errors), so type sharing is disabled.
            return torch.jit._recursive.create_script_module(
                mod, torch.jit._recursive.infer_methods_to_compile, share_types=False
            )

        def make_hierarchy():
            # Builds a fresh OuterModule tree; every call creates new instances.
            return OuterModule(
                MiddleModule(BasicModule()),
                MiddleModule(BasicModule()),
                MiddleModule(BasicModule()),
            )

        # Create Python, JIT and backend versions of a hierarchy that looks like this:
        #                 --------- OuterModule --------
        #                 |              |              |
        #           MiddleModule    MiddleModule    MiddleModule
        #                |               |               |
        #           BasicModule     BasicModule     BasicModule
        #
        # Two BasicModules will be lowered and the third will not.
        self.module = make_hierarchy()
        self.scripted_module = script_without_type_sharing(make_hierarchy())
        self.lowered_module = to_test_backend_selective(
            script_without_type_sharing(make_hierarchy()),
            {"forward": ""},
            ["sub1.submodule", "sub2.submodule"],
        )

    def test_execution(self):
        x = torch.randn(5)
        self.check_function("forward", (x, x))

        self.test_selective_lowering_type_remap()

    def test_save_load(self):
        self.test_execution()
        self.save_load()
        self.test_execution()

        self.test_selective_lowering_type_remap()

    def test_selective_lowering_type_remap(self):
        """
        Check that type remapping and replacement occurred during selective lowering.
        """
        # Check that self.lowered_module was not lowered, but that it does contain test_backendLoweredModule due to it
        # calling the lowered module directly.
        FileCheck().check("OuterModule").check("BasicModule").run(
            self.scripted_module.graph
        )
        FileCheck().check("OuterModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check("LoweredWrapper.test_backend").run(self.lowered_module.graph)

        # Check that self.lowered_module.sub1/sub2 were not lowered but that BasicModule has been replaced in their graphs.
        FileCheck().check("MiddleModule").check("BasicModule").check_not(
            "LoweredWrapper.test_backend"
        ).run(self.scripted_module.sub1.graph)
        FileCheck().check("MiddleModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check("LoweredWrapper.test_backend").run(self.lowered_module.sub1.graph)

        FileCheck().check("MiddleModule").check("BasicModule").check_not(
            "LoweredWrapper.test_backend"
        ).run(self.scripted_module.sub2.graph)
        FileCheck().check("MiddleModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check("LoweredWrapper.test_backend").run(self.lowered_module.sub2.graph)

        # Check that self.lowered_module.sub1/sub2.submodule were lowered. They should have a new attribute
        # __loweredModule__ whose graph should mention __torch__.torch.classes.__backends__.test_backend,
        # the TorchBind class for executing functions on the test JIT backend.
        FileCheck().check("LoweredModule.test_backend").check(
            "__torch__.torch.classes.__backends__.test_backend"
        ).run(self.lowered_module.sub1.submodule.__loweredModule__.graph)

        FileCheck().check("LoweredModule.test_backend").check(
            "__torch__.torch.classes.__backends__.test_backend"
        ).run(self.lowered_module.sub2.submodule.__loweredModule__.graph)

        # Check that self.other and self.other.submodule have been left untouched by the selective lowering process.
        # The scripted (never lowered) hierarchy serves as the reference ...
        FileCheck().check("MiddleModule").check("BasicModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check_not("LoweredWrapper.test_backend").run(self.scripted_module.other.graph)
        FileCheck().check("BasicModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check_not("LoweredModule.test_backend").run(
            self.scripted_module.other.submodule.graph
        )
        # ... and the selectively lowered hierarchy's `other` branch must look the same.
        FileCheck().check("MiddleModule").check("BasicModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check_not("LoweredWrapper.test_backend").run(self.lowered_module.other.graph)
        FileCheck().check("BasicModule").check_not(
            "__torch__.torch.classes.__backends__.test_backend"
        ).check_not("LoweredModule.test_backend").run(
            self.lowered_module.other.submodule.graph
        )

    def test_errors(self):
        """
        Check errors associated with selective lowering.
        """
        # Check error messages thrown when attempting to lower something that is not a ScriptModule.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Object .* is not a ScriptModule", ""
        ):
            to_test_backend_selective(torch.nn.ReLU(), {"forward": ""}, ["submodule"])

        MiddleModule = SelectiveLoweringTest.MiddleModule
        mod = MiddleModule(BasicModule())
        mod.new_attr = 3

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Attribute named new_attr is not a Module", ""
        ):
            to_test_backend_selective(
                torch.jit.script(mod), {"forward": ""}, ["new_attr"]
            )

        # Check error message thrown when module hierarchy doesn't have unique types.
        OuterModule = SelectiveLoweringTest.OuterModule
        mod = OuterModule(
            MiddleModule(BasicModule()),
            MiddleModule(BasicModule()),
            MiddleModule(BasicModule()),
        )

        # torch.jit.script shares types between identical submodules, which
        # selective lowering rejects.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            r"Selective lowering is only supported for module hierarchies with unique types",
            "",
        ):
            to_test_backend_selective(
                torch.jit.script(mod), {"forward": ""}, ["sub1.submodule"]
            )


# This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
@unittest.skipIf(
    TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
    "Non-portable load_library call used in test",
)
class TestBackends(JitTestCase):
    """
    This class wraps and invokes all subclasses of JitBackendTestCase so that each one
    does not have to be individually imported in test_jit.py.
    """

    def __init__(self, name):
        super().__init__(name)
        # Each wrapped case is constructed with the same test method name.
        self.basic_module_test = BasicModuleTest(name)
        self.basic_module_unavailable_test = BasicModuleUnavailableTest(name)
        self.nested_module_test = NestedModuleTest(name)
        self.selective_lowering_test = SelectiveLoweringTest(name)

    def setUp(self):
        super().setUp()
        # The wrapped cases' setUp loads the native test backend library.
        if not TEST_WITH_ROCM:
            self.basic_module_test.setUp()
            self.basic_module_unavailable_test.setUp()
            self.nested_module_test.setUp()
            self.selective_lowering_test.setUp()

    @skipIfRocm
    def test_execution(self):
        self.basic_module_test.test_execution()
        self.basic_module_unavailable_test.test_execution()
        self.nested_module_test.test_execution()
        self.selective_lowering_test.test_execution()

    @skipIfRocm
    def test_save_load(self):
        self.basic_module_test.test_save_load()
490*da0073e9SAndroid Build Coastguard Worker self.basic_module_unavailable_test.test_save_load() 491*da0073e9SAndroid Build Coastguard Worker self.nested_module_test.test_save_load() 492*da0073e9SAndroid Build Coastguard Worker self.selective_lowering_test.test_save_load() 493*da0073e9SAndroid Build Coastguard Worker 494*da0073e9SAndroid Build Coastguard Worker @skipIfRocm 495*da0073e9SAndroid Build Coastguard Worker def test_errors(self): 496*da0073e9SAndroid Build Coastguard Worker self.selective_lowering_test.test_errors() 497*da0073e9SAndroid Build Coastguard Worker 498*da0073e9SAndroid Build Coastguard Worker 499*da0073e9SAndroid Build Coastguard Worker""" 500*da0073e9SAndroid Build Coastguard WorkerUnit Tests for backend with compiler 501*da0073e9SAndroid Build Coastguard WorkerThis test case and the existing TestBackends are separate because they cover different aspects. 502*da0073e9SAndroid Build Coastguard WorkerThe actual backend implementation in this test is different. 503*da0073e9SAndroid Build Coastguard WorkerIt has a simple demo compiler to test the end-to-end flow in mobile. 504*da0073e9SAndroid Build Coastguard WorkerHowever, this test cannot cover the selective_lowering for now, which is covered in TestBackends. 505*da0073e9SAndroid Build Coastguard Worker""" 506*da0073e9SAndroid Build Coastguard Worker 507*da0073e9SAndroid Build Coastguard Worker 508*da0073e9SAndroid Build Coastguard Workerclass BasicModuleAdd(torch.nn.Module): 509*da0073e9SAndroid Build Coastguard Worker """ 510*da0073e9SAndroid Build Coastguard Worker A simple add Module used to test to_backend lowering machinery. 
511*da0073e9SAndroid Build Coastguard Worker """ 512*da0073e9SAndroid Build Coastguard Worker 513*da0073e9SAndroid Build Coastguard Worker def forward(self, x, h): 514*da0073e9SAndroid Build Coastguard Worker return x + h 515*da0073e9SAndroid Build Coastguard Worker 516*da0073e9SAndroid Build Coastguard Worker 517*da0073e9SAndroid Build Coastguard Worker# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends. 518*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf( 519*da0073e9SAndroid Build Coastguard Worker TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE, 520*da0073e9SAndroid Build Coastguard Worker "Non-portable load_library call used in test", 521*da0073e9SAndroid Build Coastguard Worker) 522*da0073e9SAndroid Build Coastguard Workerclass JitBackendTestCaseWithCompiler(JitTestCase): 523*da0073e9SAndroid Build Coastguard Worker """ 524*da0073e9SAndroid Build Coastguard Worker A common base class for JIT backend tests with compilers that contains common utility 525*da0073e9SAndroid Build Coastguard Worker functions for output comparison. 
526*da0073e9SAndroid Build Coastguard Worker """ 527*da0073e9SAndroid Build Coastguard Worker 528*da0073e9SAndroid Build Coastguard Worker def setUp(self): 529*da0073e9SAndroid Build Coastguard Worker super().setUp() 530*da0073e9SAndroid Build Coastguard Worker lib_file_path = find_library_location("libbackend_with_compiler.so") 531*da0073e9SAndroid Build Coastguard Worker torch.ops.load_library(str(lib_file_path)) 532*da0073e9SAndroid Build Coastguard Worker # Subclasses are expected to set up four variables in their setUp methods: 533*da0073e9SAndroid Build Coastguard Worker # module - a regular, Python version of the module being tested 534*da0073e9SAndroid Build Coastguard Worker # scripted_module - a scripted version of module 535*da0073e9SAndroid Build Coastguard Worker # lowered_module - a version of module lowered to a backend 536*da0073e9SAndroid Build Coastguard Worker # mobile_module - a module with a format that Pytorch Mobile can execute 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Worker def check_forward(self, input): 539*da0073e9SAndroid Build Coastguard Worker """ 540*da0073e9SAndroid Build Coastguard Worker Check that the forward function produces the same output using 541*da0073e9SAndroid Build Coastguard Worker Python, regular JIT, the backend, and mobile for the given 'input'. 542*da0073e9SAndroid Build Coastguard Worker """ 543*da0073e9SAndroid Build Coastguard Worker 544*da0073e9SAndroid Build Coastguard Worker # Get outputs from forward. 
545*da0073e9SAndroid Build Coastguard Worker python_output = self.module.forward(*input) 546*da0073e9SAndroid Build Coastguard Worker jit_output = self.scripted_module.forward(*input) 547*da0073e9SAndroid Build Coastguard Worker backend_output = self.lowered_module(*input) 548*da0073e9SAndroid Build Coastguard Worker mobile_output = self.mobile_module(*input) 549*da0073e9SAndroid Build Coastguard Worker 550*da0073e9SAndroid Build Coastguard Worker # The answers returned by Python, JIT, to_backend, and mobile should all match. 551*da0073e9SAndroid Build Coastguard Worker self.assertEqual(python_output, backend_output) 552*da0073e9SAndroid Build Coastguard Worker self.assertEqual(jit_output, backend_output) 553*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mobile_output, backend_output) 554*da0073e9SAndroid Build Coastguard Worker 555*da0073e9SAndroid Build Coastguard Worker def test_execution(self): 556*da0073e9SAndroid Build Coastguard Worker """ 557*da0073e9SAndroid Build Coastguard Worker Stub for correctness tests. 558*da0073e9SAndroid Build Coastguard Worker """ 559*da0073e9SAndroid Build Coastguard Worker 560*da0073e9SAndroid Build Coastguard Worker def test_errors(self): 561*da0073e9SAndroid Build Coastguard Worker """ 562*da0073e9SAndroid Build Coastguard Worker Stub for testing error checking. 563*da0073e9SAndroid Build Coastguard Worker """ 564*da0073e9SAndroid Build Coastguard Worker 565*da0073e9SAndroid Build Coastguard Worker 566*da0073e9SAndroid Build Coastguard Workerclass BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler): 567*da0073e9SAndroid Build Coastguard Worker """ 568*da0073e9SAndroid Build Coastguard Worker Tests for BasicModuleAdd. 
569*da0073e9SAndroid Build Coastguard Worker """ 570*da0073e9SAndroid Build Coastguard Worker 571*da0073e9SAndroid Build Coastguard Worker def setUp(self): 572*da0073e9SAndroid Build Coastguard Worker super().setUp() 573*da0073e9SAndroid Build Coastguard Worker # Create Python, JIT and backend versions of BasicModuleAdd. 574*da0073e9SAndroid Build Coastguard Worker self.module = BasicModuleAdd() 575*da0073e9SAndroid Build Coastguard Worker self.scripted_module = torch.jit.script(BasicModuleAdd()) 576*da0073e9SAndroid Build Coastguard Worker compile_spec = { 577*da0073e9SAndroid Build Coastguard Worker "forward": { 578*da0073e9SAndroid Build Coastguard Worker "input_shapes": "((1, 1, 320, 240), (1, 3))", 579*da0073e9SAndroid Build Coastguard Worker "some_other_option": "True", 580*da0073e9SAndroid Build Coastguard Worker }, 581*da0073e9SAndroid Build Coastguard Worker } 582*da0073e9SAndroid Build Coastguard Worker self.lowered_module = torch._C._jit_to_backend( 583*da0073e9SAndroid Build Coastguard Worker "backend_with_compiler_demo", self.scripted_module, compile_spec 584*da0073e9SAndroid Build Coastguard Worker ) 585*da0073e9SAndroid Build Coastguard Worker # Create mobile version of BasicModuleAdd 586*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO(self.lowered_module._save_to_buffer_for_lite_interpreter()) 587*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 588*da0073e9SAndroid Build Coastguard Worker self.mobile_module = _load_for_lite_interpreter(buffer) 589*da0073e9SAndroid Build Coastguard Worker 590*da0073e9SAndroid Build Coastguard Worker def test_execution(self): 591*da0073e9SAndroid Build Coastguard Worker # Test execution with backend against Python and JIT. 
592*da0073e9SAndroid Build Coastguard Worker input = torch.ones(1, dtype=torch.float) 593*da0073e9SAndroid Build Coastguard Worker self.check_forward((input, input)) 594*da0073e9SAndroid Build Coastguard Worker 595*da0073e9SAndroid Build Coastguard Worker 596*da0073e9SAndroid Build Coastguard Workerclass ErrorMessagesWithCompiler(JitBackendTestCase): 597*da0073e9SAndroid Build Coastguard Worker """ 598*da0073e9SAndroid Build Coastguard Worker Tests for errors that occur with compiler, specifically: 599*da0073e9SAndroid Build Coastguard Worker * an operator is not supported by the backend 600*da0073e9SAndroid Build Coastguard Worker """ 601*da0073e9SAndroid Build Coastguard Worker 602*da0073e9SAndroid Build Coastguard Worker class ModuleNotSupported(torch.nn.Module): 603*da0073e9SAndroid Build Coastguard Worker """ 604*da0073e9SAndroid Build Coastguard Worker A module with an operator that is not supported. 605*da0073e9SAndroid Build Coastguard Worker """ 606*da0073e9SAndroid Build Coastguard Worker 607*da0073e9SAndroid Build Coastguard Worker def forward(self, x, h): 608*da0073e9SAndroid Build Coastguard Worker return x * h 609*da0073e9SAndroid Build Coastguard Worker self._loweredmodule.forward() 610*da0073e9SAndroid Build Coastguard Worker 611*da0073e9SAndroid Build Coastguard Worker def test_errors(self): 612*da0073e9SAndroid Build Coastguard Worker scripted_module_n = torch.jit.script( 613*da0073e9SAndroid Build Coastguard Worker ErrorMessagesWithCompiler.ModuleNotSupported() 614*da0073e9SAndroid Build Coastguard Worker ) 615*da0073e9SAndroid Build Coastguard Worker # Test exception is thrown when lowering a module with an unsupported operator 616*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 617*da0073e9SAndroid Build Coastguard Worker RuntimeError, 618*da0073e9SAndroid Build Coastguard Worker # Special escape characters are replaced with '.' 
619*da0073e9SAndroid Build Coastguard Worker r"""The node of aten::mul is not supported in this compiler. .* 620*da0073e9SAndroid Build Coastguard Worker def forward.self, x, h.: 621*da0073e9SAndroid Build Coastguard Worker return x . h 622*da0073e9SAndroid Build Coastguard Worker ~~~~~ <--- HERE 623*da0073e9SAndroid Build Coastguard Worker self._loweredmodule.forward.. 624*da0073e9SAndroid Build Coastguard Worker""", 625*da0073e9SAndroid Build Coastguard Worker "", 626*da0073e9SAndroid Build Coastguard Worker ): 627*da0073e9SAndroid Build Coastguard Worker lowered_module_n = torch._C._jit_to_backend( 628*da0073e9SAndroid Build Coastguard Worker "backend_with_compiler_demo", scripted_module_n, {"forward": {"": ""}} 629*da0073e9SAndroid Build Coastguard Worker ) 630*da0073e9SAndroid Build Coastguard Worker 631*da0073e9SAndroid Build Coastguard Worker 632*da0073e9SAndroid Build Coastguard Workerclass CompModuleTestWithCompiler(JitBackendTestCase): 633*da0073e9SAndroid Build Coastguard Worker """ 634*da0073e9SAndroid Build Coastguard Worker Tests for CompModule, which is a module with two lowered submodules 635*da0073e9SAndroid Build Coastguard Worker """ 636*da0073e9SAndroid Build Coastguard Worker 637*da0073e9SAndroid Build Coastguard Worker class BasicModuleSub(torch.nn.Module): 638*da0073e9SAndroid Build Coastguard Worker """ 639*da0073e9SAndroid Build Coastguard Worker A simple subtraction Module to be used in CompModule. 640*da0073e9SAndroid Build Coastguard Worker """ 641*da0073e9SAndroid Build Coastguard Worker 642*da0073e9SAndroid Build Coastguard Worker def forward(self, x, h): 643*da0073e9SAndroid Build Coastguard Worker return x - h 644*da0073e9SAndroid Build Coastguard Worker 645*da0073e9SAndroid Build Coastguard Worker class CompModule(torch.nn.Module): 646*da0073e9SAndroid Build Coastguard Worker """ 647*da0073e9SAndroid Build Coastguard Worker A module with two lowered submodules. 
648*da0073e9SAndroid Build Coastguard Worker """ 649*da0073e9SAndroid Build Coastguard Worker 650*da0073e9SAndroid Build Coastguard Worker def __init__(self, addmodule, submodule): 651*da0073e9SAndroid Build Coastguard Worker super().__init__() 652*da0073e9SAndroid Build Coastguard Worker self.lowered_add = addmodule 653*da0073e9SAndroid Build Coastguard Worker self.lowered_sub = submodule 654*da0073e9SAndroid Build Coastguard Worker 655*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b, s): 656*da0073e9SAndroid Build Coastguard Worker c = self.lowered_add.forward(a, b) 657*da0073e9SAndroid Build Coastguard Worker d = self.lowered_sub.forward(a, b) 658*da0073e9SAndroid Build Coastguard Worker y = s * (c * d) 659*da0073e9SAndroid Build Coastguard Worker return y 660*da0073e9SAndroid Build Coastguard Worker 661*da0073e9SAndroid Build Coastguard Worker def setUp(self): 662*da0073e9SAndroid Build Coastguard Worker super().setUp() 663*da0073e9SAndroid Build Coastguard Worker # Create Python and JIT versions of CompModule with lowered submodules. 
664*da0073e9SAndroid Build Coastguard Worker compile_spec = { 665*da0073e9SAndroid Build Coastguard Worker "forward": { 666*da0073e9SAndroid Build Coastguard Worker "input_shapes": "((1, 1, 320, 240), (1, 3))", 667*da0073e9SAndroid Build Coastguard Worker "some_other_option": "True", 668*da0073e9SAndroid Build Coastguard Worker }, 669*da0073e9SAndroid Build Coastguard Worker } 670*da0073e9SAndroid Build Coastguard Worker lowered_add = torch._C._jit_to_backend( 671*da0073e9SAndroid Build Coastguard Worker "backend_with_compiler_demo", 672*da0073e9SAndroid Build Coastguard Worker torch.jit.script(BasicModuleAdd()), 673*da0073e9SAndroid Build Coastguard Worker compile_spec, 674*da0073e9SAndroid Build Coastguard Worker ) 675*da0073e9SAndroid Build Coastguard Worker lowered_sub = torch._C._jit_to_backend( 676*da0073e9SAndroid Build Coastguard Worker "backend_with_compiler_demo", 677*da0073e9SAndroid Build Coastguard Worker torch.jit.script(CompModuleTestWithCompiler.BasicModuleSub()), 678*da0073e9SAndroid Build Coastguard Worker {"forward": {"": ""}}, 679*da0073e9SAndroid Build Coastguard Worker ) 680*da0073e9SAndroid Build Coastguard Worker self.module = CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub) 681*da0073e9SAndroid Build Coastguard Worker self.scripted_module = torch.jit.script( 682*da0073e9SAndroid Build Coastguard Worker CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub) 683*da0073e9SAndroid Build Coastguard Worker ) 684*da0073e9SAndroid Build Coastguard Worker # No backend version of CompModule currently, so this is filler. 
685*da0073e9SAndroid Build Coastguard Worker self.lowered_module = self.scripted_module 686*da0073e9SAndroid Build Coastguard Worker # Create a mobile version of CompModule from JIT version 687*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO(self.scripted_module._save_to_buffer_for_lite_interpreter()) 688*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 689*da0073e9SAndroid Build Coastguard Worker self.mobile_module = _load_for_lite_interpreter(buffer) 690*da0073e9SAndroid Build Coastguard Worker 691*da0073e9SAndroid Build Coastguard Worker def test_execution(self): 692*da0073e9SAndroid Build Coastguard Worker # Test execution with backend against Python and JIT. 693*da0073e9SAndroid Build Coastguard Worker input1 = torch.ones(1, dtype=torch.float) 694*da0073e9SAndroid Build Coastguard Worker input2 = torch.ones(1, dtype=torch.float) 695*da0073e9SAndroid Build Coastguard Worker 696*da0073e9SAndroid Build Coastguard Worker # Test forward. 697*da0073e9SAndroid Build Coastguard Worker self.check_function("forward", (input1, input2, input2)) 698*da0073e9SAndroid Build Coastguard Worker 699*da0073e9SAndroid Build Coastguard Worker 700*da0073e9SAndroid Build Coastguard Worker# This is needed for IS_WINDOWS or IS_MACOS to skip the tests. 701*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf( 702*da0073e9SAndroid Build Coastguard Worker IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE, 703*da0073e9SAndroid Build Coastguard Worker "Non-portable load_library call used in test", 704*da0073e9SAndroid Build Coastguard Worker) 705*da0073e9SAndroid Build Coastguard Workerclass TestBackendsWithCompiler(JitTestCase): 706*da0073e9SAndroid Build Coastguard Worker """ 707*da0073e9SAndroid Build Coastguard Worker This class wraps and invokes all subclasses of JitBackendTestCaseWithCompiler 708*da0073e9SAndroid Build Coastguard Worker so that each one does not have to be individually imported in test_jit.py. 
709*da0073e9SAndroid Build Coastguard Worker """ 710*da0073e9SAndroid Build Coastguard Worker 711*da0073e9SAndroid Build Coastguard Worker def __init__(self, name): 712*da0073e9SAndroid Build Coastguard Worker super().__init__(name) 713*da0073e9SAndroid Build Coastguard Worker self.basic_module_compiler_test = BasicModuleTestWithCompiler(name) 714*da0073e9SAndroid Build Coastguard Worker self.error_module_compiler_test = ErrorMessagesWithCompiler(name) 715*da0073e9SAndroid Build Coastguard Worker self.comp_module_compiler_test = CompModuleTestWithCompiler(name) 716*da0073e9SAndroid Build Coastguard Worker 717*da0073e9SAndroid Build Coastguard Worker def setUp(self): 718*da0073e9SAndroid Build Coastguard Worker super().setUp() 719*da0073e9SAndroid Build Coastguard Worker self.basic_module_compiler_test.setUp() 720*da0073e9SAndroid Build Coastguard Worker self.error_module_compiler_test.setUp() 721*da0073e9SAndroid Build Coastguard Worker self.comp_module_compiler_test.setUp() 722*da0073e9SAndroid Build Coastguard Worker 723*da0073e9SAndroid Build Coastguard Worker def test_execution(self): 724*da0073e9SAndroid Build Coastguard Worker self.basic_module_compiler_test.test_execution() 725*da0073e9SAndroid Build Coastguard Worker self.comp_module_compiler_test.test_execution() 726*da0073e9SAndroid Build Coastguard Worker 727*da0073e9SAndroid Build Coastguard Worker def test_errors(self): 728*da0073e9SAndroid Build Coastguard Worker self.error_module_compiler_test.test_errors() 729*da0073e9SAndroid Build Coastguard Worker 730*da0073e9SAndroid Build Coastguard Worker 731*da0073e9SAndroid Build Coastguard Workerclass CompModuleTestSameNameWithCompiler(JitBackendTestCase): 732*da0073e9SAndroid Build Coastguard Worker """ 733*da0073e9SAndroid Build Coastguard Worker Tests for CompModule, which is a module with two lowered submodules with same module name 734*da0073e9SAndroid Build Coastguard Worker """ 735*da0073e9SAndroid Build Coastguard Worker 736*da0073e9SAndroid Build 
Coastguard Worker class ModuleAdd(torch.nn.Module): 737*da0073e9SAndroid Build Coastguard Worker """ 738*da0073e9SAndroid Build Coastguard Worker A simple Module used to test to_backend lowering machinery. 739*da0073e9SAndroid Build Coastguard Worker """ 740*da0073e9SAndroid Build Coastguard Worker 741*da0073e9SAndroid Build Coastguard Worker def forward(self, x, h): 742*da0073e9SAndroid Build Coastguard Worker return x + h 743*da0073e9SAndroid Build Coastguard Worker 744*da0073e9SAndroid Build Coastguard Worker class CompModule(torch.nn.Module): 745*da0073e9SAndroid Build Coastguard Worker """ 746*da0073e9SAndroid Build Coastguard Worker A module with two lowered submodules. 747*da0073e9SAndroid Build Coastguard Worker """ 748*da0073e9SAndroid Build Coastguard Worker 749*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 750*da0073e9SAndroid Build Coastguard Worker super().__init__() 751*da0073e9SAndroid Build Coastguard Worker compile_spec = { 752*da0073e9SAndroid Build Coastguard Worker "forward": { 753*da0073e9SAndroid Build Coastguard Worker "some_other_option": "True", 754*da0073e9SAndroid Build Coastguard Worker }, 755*da0073e9SAndroid Build Coastguard Worker } 756*da0073e9SAndroid Build Coastguard Worker self.add = torch._C._jit_to_backend( 757*da0073e9SAndroid Build Coastguard Worker "backend_with_compiler_demo", 758*da0073e9SAndroid Build Coastguard Worker torch.jit.script(ModuleAdd()), # noqa: F821 759*da0073e9SAndroid Build Coastguard Worker compile_spec, 760*da0073e9SAndroid Build Coastguard Worker ) 761*da0073e9SAndroid Build Coastguard Worker self.sub = torch._C._jit_to_backend( 762*da0073e9SAndroid Build Coastguard Worker "backend_with_compiler_demo", 763*da0073e9SAndroid Build Coastguard Worker torch.jit.script(ModuleAdd()), # noqa: F821 764*da0073e9SAndroid Build Coastguard Worker compile_spec, 765*da0073e9SAndroid Build Coastguard Worker ) 766*da0073e9SAndroid Build Coastguard Worker 767*da0073e9SAndroid Build Coastguard Worker 
def forward(self, a, b, s: int): 768*da0073e9SAndroid Build Coastguard Worker c = self.add.forward(a, b) 769*da0073e9SAndroid Build Coastguard Worker d = self.sub.forward(a, b) 770*da0073e9SAndroid Build Coastguard Worker y = s * (c * d) 771*da0073e9SAndroid Build Coastguard Worker return y 772*da0073e9SAndroid Build Coastguard Worker 773*da0073e9SAndroid Build Coastguard Worker def setUp(self): 774*da0073e9SAndroid Build Coastguard Worker super().setUp() 775*da0073e9SAndroid Build Coastguard Worker 776*da0073e9SAndroid Build Coastguard Worker self.module = CompModule() # noqa: F821 777*da0073e9SAndroid Build Coastguard Worker self.scripted_module = torch.jit.script(self.module) 778*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO(self.scripted_module._save_to_buffer_for_lite_interpreter()) 779*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 780*da0073e9SAndroid Build Coastguard Worker self.mobile_module = _load_for_lite_interpreter(buffer) 781*da0073e9SAndroid Build Coastguard Worker 782*da0073e9SAndroid Build Coastguard Worker def test_execution(self): 783*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1) 784*da0073e9SAndroid Build Coastguard Worker b = 3 * torch.ones(1) 785*da0073e9SAndroid Build Coastguard Worker s = 3 786*da0073e9SAndroid Build Coastguard Worker # Test forward. 787*da0073e9SAndroid Build Coastguard Worker self.check_function("forward", (a, b, s)) 788*da0073e9SAndroid Build Coastguard Worker 789*da0073e9SAndroid Build Coastguard Worker 790*da0073e9SAndroid Build Coastguard Workerclass AddedAttributesTest(JitBackendTestCase): 791*da0073e9SAndroid Build Coastguard Worker """ 792*da0073e9SAndroid Build Coastguard Worker Tests for adding attributes to a model after lowering. 
793*da0073e9SAndroid Build Coastguard Worker """ 794*da0073e9SAndroid Build Coastguard Worker 795*da0073e9SAndroid Build Coastguard Worker def setUp(self): 796*da0073e9SAndroid Build Coastguard Worker super().setUp() 797*da0073e9SAndroid Build Coastguard Worker # Create Python, JIT and backend versions of BasicModule. 798*da0073e9SAndroid Build Coastguard Worker self.module = BasicModule() 799*da0073e9SAndroid Build Coastguard Worker self.scripted_module = torch.jit.script(BasicModule()) 800*da0073e9SAndroid Build Coastguard Worker self.lowered_module = to_test_backend_multi( 801*da0073e9SAndroid Build Coastguard Worker self.scripted_module, 802*da0073e9SAndroid Build Coastguard Worker {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}}, 803*da0073e9SAndroid Build Coastguard Worker ) 804*da0073e9SAndroid Build Coastguard Worker 805*da0073e9SAndroid Build Coastguard Worker def test_attribute(self): 806*da0073e9SAndroid Build Coastguard Worker input = [(torch.ones(5),)] 807*da0073e9SAndroid Build Coastguard Worker pre_bundled = self.lowered_module(*input[0]) 808*da0073e9SAndroid Build Coastguard Worker # Attach bundled inputs which adds several attributes and functions to the model 809*da0073e9SAndroid Build Coastguard Worker self.lowered_module = ( 810*da0073e9SAndroid Build Coastguard Worker torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 811*da0073e9SAndroid Build Coastguard Worker lowered_module, input # noqa: F821 812*da0073e9SAndroid Build Coastguard Worker ) 813*da0073e9SAndroid Build Coastguard Worker ) 814*da0073e9SAndroid Build Coastguard Worker post_bundled = self.lowered_module( 815*da0073e9SAndroid Build Coastguard Worker *self.lowered_module.get_all_bundled_inputs()[0] 816*da0073e9SAndroid Build Coastguard Worker ) 817*da0073e9SAndroid Build Coastguard Worker # Save and load the lowered module. 
818*da0073e9SAndroid Build Coastguard Worker self.save_load() 819*da0073e9SAndroid Build Coastguard Worker # Use bundled after save and load to prove its preserved 820*da0073e9SAndroid Build Coastguard Worker post_load = self.lowered_module( 821*da0073e9SAndroid Build Coastguard Worker *self.lowered_module.get_all_bundled_inputs()[0] 822*da0073e9SAndroid Build Coastguard Worker ) 823*da0073e9SAndroid Build Coastguard Worker self.assertEqual(pre_bundled, post_bundled) 824*da0073e9SAndroid Build Coastguard Worker self.assertEqual(post_bundled, post_load) 825