1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport io 4*da0073e9SAndroid Build Coastguard Workerimport unittest 5*da0073e9SAndroid Build Coastguard Workerfrom itertools import product 6*da0073e9SAndroid Build Coastguard Workerfrom typing import Any 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerimport torch 9*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn 10*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F 11*da0073e9SAndroid Build Coastguard Workerfrom torch.jit._recursive import wrap_cpp_module 12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck 13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN 14*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_quantization import skipIfNoFBGEMM 15*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_quantized import override_quantized_engine 16*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 17*da0073e9SAndroid Build Coastguard Worker set_default_dtype, 18*da0073e9SAndroid Build Coastguard Worker skipCUDAMemoryLeakCheckIf, 19*da0073e9SAndroid Build Coastguard Worker skipIfTorchDynamo, 20*da0073e9SAndroid Build Coastguard Worker TEST_WITH_ROCM, 21*da0073e9SAndroid Build Coastguard Worker) 22*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase 23*da0073e9SAndroid Build Coastguard Workerfrom torch.utils import mkldnn as mkldnn_utils 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Workertry: 27*da0073e9SAndroid Build Coastguard Worker import torchvision 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker 
HAS_TORCHVISION = True 30*da0073e9SAndroid Build Coastguard Workerexcept ImportError: 31*da0073e9SAndroid Build Coastguard Worker HAS_TORCHVISION = False 32*da0073e9SAndroid Build Coastguard WorkerskipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 35*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 36*da0073e9SAndroid Build Coastguard Worker "This test file is not meant to be run directly, use:\n\n" 37*da0073e9SAndroid Build Coastguard Worker "\tpython test/test_jit.py TESTNAME\n\n" 38*da0073e9SAndroid Build Coastguard Worker "instead." 39*da0073e9SAndroid Build Coastguard Worker ) 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard WorkerTEST_ROCM = torch.cuda.is_available() and torch.version.hip is not None 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Workerdef removeExceptions(graph): 45*da0073e9SAndroid Build Coastguard Worker for n in graph.findAllNodes("prim::RaiseException"): 46*da0073e9SAndroid Build Coastguard Worker n.destroy() 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Workerclass TestFreezing(JitTestCase): 50*da0073e9SAndroid Build Coastguard Worker def test_freeze_module(self): 51*da0073e9SAndroid Build Coastguard Worker class M(nn.Module): 52*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 53*da0073e9SAndroid Build Coastguard Worker super().__init__() 54*da0073e9SAndroid Build Coastguard Worker self.a = 1 # folded 55*da0073e9SAndroid Build Coastguard Worker self.b = 1.2 # folded 56*da0073e9SAndroid Build Coastguard Worker self.c = "hello" # folded 57*da0073e9SAndroid Build Coastguard Worker self.c2 = "hi\xA1" # not folded 58*da0073e9SAndroid Build Coastguard Worker self.d = [1, 1] # folded 
59*da0073e9SAndroid Build Coastguard Worker self.e = [1.0, 1.1] # folded 60*da0073e9SAndroid Build Coastguard Worker self.f = ["hello", "world"] # folded 61*da0073e9SAndroid Build Coastguard Worker self.f2 = [(1, "Over \u0e55\u0e57 57")] 62*da0073e9SAndroid Build Coastguard Worker self.g = ( 63*da0073e9SAndroid Build Coastguard Worker [1, 2], 64*da0073e9SAndroid Build Coastguard Worker 3.2, 65*da0073e9SAndroid Build Coastguard Worker "4.4", 66*da0073e9SAndroid Build Coastguard Worker torch.tensor([5.5], requires_grad=True), 67*da0073e9SAndroid Build Coastguard Worker ) # folded 68*da0073e9SAndroid Build Coastguard Worker self.h = {"layer": [torch.tensor([7.7], requires_grad=True)]} 69*da0073e9SAndroid Build Coastguard Worker self.h2 = {"layer\xB1": [torch.tensor([8.8], requires_grad=True)]} 70*da0073e9SAndroid Build Coastguard Worker self.t = torch.tensor([1.2, 2.4], requires_grad=True) # folded 71*da0073e9SAndroid Build Coastguard Worker self.ts = [ 72*da0073e9SAndroid Build Coastguard Worker torch.tensor([1.0, 2.0], requires_grad=True), 73*da0073e9SAndroid Build Coastguard Worker torch.tensor([3.0, 4.0], requires_grad=True), 74*da0073e9SAndroid Build Coastguard Worker ] # folded 75*da0073e9SAndroid Build Coastguard Worker self.tt = [[torch.tensor([3.3, 2.3], requires_grad=True), None]] 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 78*da0073e9SAndroid Build Coastguard Worker return ( 79*da0073e9SAndroid Build Coastguard Worker str(self.a) 80*da0073e9SAndroid Build Coastguard Worker + str(self.b) 81*da0073e9SAndroid Build Coastguard Worker + self.c 82*da0073e9SAndroid Build Coastguard Worker + self.c2 83*da0073e9SAndroid Build Coastguard Worker + str(self.d) 84*da0073e9SAndroid Build Coastguard Worker + str(self.e) 85*da0073e9SAndroid Build Coastguard Worker + str(self.f) 86*da0073e9SAndroid Build Coastguard Worker + str(self.f2) 87*da0073e9SAndroid Build Coastguard Worker + str(self.g) 
88*da0073e9SAndroid Build Coastguard Worker + str(self.h) 89*da0073e9SAndroid Build Coastguard Worker + str(self.h2) 90*da0073e9SAndroid Build Coastguard Worker + str(self.t) 91*da0073e9SAndroid Build Coastguard Worker + str(self.ts) 92*da0073e9SAndroid Build Coastguard Worker + str(self.tt) 93*da0073e9SAndroid Build Coastguard Worker ) 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(M()) 96*da0073e9SAndroid Build Coastguard Worker m.eval() 97*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 2) 98*da0073e9SAndroid Build Coastguard Worker output_s = m.forward(input) 99*da0073e9SAndroid Build Coastguard Worker m._c = torch._C._freeze_module(m._c) 100*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 101*da0073e9SAndroid Build Coastguard Worker torch.jit.save(m._c, buffer) 102*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 103*da0073e9SAndroid Build Coastguard Worker m2 = torch.jit.load(buffer) 104*da0073e9SAndroid Build Coastguard Worker # Check if frozen module looks as below: 105*da0073e9SAndroid Build Coastguard Worker # module m { 106*da0073e9SAndroid Build Coastguard Worker # attributes { 107*da0073e9SAndroid Build Coastguard Worker # tt = ... 108*da0073e9SAndroid Build Coastguard Worker # } 109*da0073e9SAndroid Build Coastguard Worker # ... 
110*da0073e9SAndroid Build Coastguard Worker # } 111*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2._c.hasattr("a")) 112*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2._c.hasattr("b")) 113*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2._c.hasattr("c")) 114*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2._c.hasattr("c2")) 115*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2._c.hasattr("d")) 116*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2._c.hasattr("e")) 117*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2._c.hasattr("f")) 118*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2._c.hasattr("f2")) 119*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2._c.hasattr("g")) 120*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2._c.hasattr("h")) 121*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2._c.hasattr("h2")) 122*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2._c.hasattr("t")) 123*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2._c.hasattr("ts")) 124*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2._c.hasattr("tt")) 125*da0073e9SAndroid Build Coastguard Worker output_f = m2.forward(input) 126*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_s, output_f) 127*da0073e9SAndroid Build Coastguard Worker 128*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_with_submodule(self): 129*da0073e9SAndroid Build Coastguard Worker class SubModule(nn.Module): 130*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 131*da0073e9SAndroid Build Coastguard Worker super().__init__() 132*da0073e9SAndroid Build Coastguard Worker self.a = 11 133*da0073e9SAndroid Build Coastguard Worker self.b = 2 134*da0073e9SAndroid Build Coastguard Worker 135*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 136*da0073e9SAndroid Build Coastguard Worker return self.a + self.b 
137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker class SubModule2(nn.Module): 139*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 140*da0073e9SAndroid Build Coastguard Worker super().__init__() 141*da0073e9SAndroid Build Coastguard Worker self.a = 12 142*da0073e9SAndroid Build Coastguard Worker self.b = 2 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 145*da0073e9SAndroid Build Coastguard Worker self.b = 30 146*da0073e9SAndroid Build Coastguard Worker return self.a + self.b 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker class TestModule(nn.Module): 149*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 150*da0073e9SAndroid Build Coastguard Worker super().__init__() 151*da0073e9SAndroid Build Coastguard Worker self.sub1 = SubModule() 152*da0073e9SAndroid Build Coastguard Worker self.sub2 = SubModule2() 153*da0073e9SAndroid Build Coastguard Worker self.a = 3 154*da0073e9SAndroid Build Coastguard Worker self.b = 4 155*da0073e9SAndroid Build Coastguard Worker 156*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 157*da0073e9SAndroid Build Coastguard Worker self.b = 20 158*da0073e9SAndroid Build Coastguard Worker return self.sub1(x) + self.a + self.b + self.sub2(x) 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(TestModule()) 161*da0073e9SAndroid Build Coastguard Worker m.eval() 162*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 2) 163*da0073e9SAndroid Build Coastguard Worker output_s = m.forward(input) 164*da0073e9SAndroid Build Coastguard Worker mf = torch.jit.freeze(m) 165*da0073e9SAndroid Build Coastguard Worker 166*da0073e9SAndroid Build Coastguard Worker # Check if frozen module looks as below: 167*da0073e9SAndroid Build Coastguard Worker # module m { 168*da0073e9SAndroid Build Coastguard 
Worker # attributes { 169*da0073e9SAndroid Build Coastguard Worker # sub2 = ... 170*da0073e9SAndroid Build Coastguard Worker # b = 171*da0073e9SAndroid Build Coastguard Worker # } 172*da0073e9SAndroid Build Coastguard Worker # ... 173*da0073e9SAndroid Build Coastguard Worker # submodule { 174*da0073e9SAndroid Build Coastguard Worker # module m { 175*da0073e9SAndroid Build Coastguard Worker # attributes { 176*da0073e9SAndroid Build Coastguard Worker # sub2 = ... 177*da0073e9SAndroid Build Coastguard Worker # b = 178*da0073e9SAndroid Build Coastguard Worker # } 179*da0073e9SAndroid Build Coastguard Worker # ... 180*da0073e9SAndroid Build Coastguard Worker # } 181*da0073e9SAndroid Build Coastguard Worker # } 182*da0073e9SAndroid Build Coastguard Worker # } 183*da0073e9SAndroid Build Coastguard Worker mf = mf._c 184*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mf.hasattr("sub1")) 185*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mf.hasattr("a")) 186*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mf.hasattr("b")) 187*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mf.hasattr("sub2")) 188*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mf.sub2.hasattr("b")) # verify b is preserved in sub2 189*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mf.sub2.hasattr("a")) # verify a is removed in sub2 190*da0073e9SAndroid Build Coastguard Worker output_f = mf.forward(input) 191*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_s, output_f) 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_with_fork(self): 194*da0073e9SAndroid Build Coastguard Worker class SubModule(nn.Module): 195*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 196*da0073e9SAndroid Build Coastguard Worker super().__init__() 197*da0073e9SAndroid Build Coastguard Worker self.a = torch.ones(20, 20) 198*da0073e9SAndroid Build Coastguard Worker self.b = 
torch.ones(20, 20) 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 201*da0073e9SAndroid Build Coastguard Worker return self.a * self.b + x 202*da0073e9SAndroid Build Coastguard Worker 203*da0073e9SAndroid Build Coastguard Worker class TestModule(nn.Module): 204*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 205*da0073e9SAndroid Build Coastguard Worker super().__init__() 206*da0073e9SAndroid Build Coastguard Worker self.sub = SubModule() 207*da0073e9SAndroid Build Coastguard Worker 208*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 209*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(self.sub.forward, x) 210*da0073e9SAndroid Build Coastguard Worker y_hat = self.sub(x) 211*da0073e9SAndroid Build Coastguard Worker y = torch.jit._wait(fut) 212*da0073e9SAndroid Build Coastguard Worker return y_hat + y 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(TestModule()) 215*da0073e9SAndroid Build Coastguard Worker m.eval() 216*da0073e9SAndroid Build Coastguard Worker input = torch.randn(20, 20) 217*da0073e9SAndroid Build Coastguard Worker output_s = m.forward(input) 218*da0073e9SAndroid Build Coastguard Worker mf = torch._C._freeze_module(m._c) 219*da0073e9SAndroid Build Coastguard Worker 220*da0073e9SAndroid Build Coastguard Worker # Check if frozen module looks as below: 221*da0073e9SAndroid Build Coastguard Worker # module m { 222*da0073e9SAndroid Build Coastguard Worker # attributes { 223*da0073e9SAndroid Build Coastguard Worker # } 224*da0073e9SAndroid Build Coastguard Worker # ... 
225*da0073e9SAndroid Build Coastguard Worker # submodule { 226*da0073e9SAndroid Build Coastguard Worker # } 227*da0073e9SAndroid Build Coastguard Worker # } 228*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mf.hasattr("a")) 229*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mf.hasattr("b")) 230*da0073e9SAndroid Build Coastguard Worker output_f = mf.forward(input) 231*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_s, output_f) 232*da0073e9SAndroid Build Coastguard Worker 233*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_with_nested_fork(self): 234*da0073e9SAndroid Build Coastguard Worker class SubModule(nn.Module): 235*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 236*da0073e9SAndroid Build Coastguard Worker super().__init__() 237*da0073e9SAndroid Build Coastguard Worker self.a = torch.ones(20, 20) 238*da0073e9SAndroid Build Coastguard Worker self.b = torch.ones(20, 20) 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 241*da0073e9SAndroid Build Coastguard Worker return self.a * self.b + x 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker class SubModule2(nn.Module): 244*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 245*da0073e9SAndroid Build Coastguard Worker super().__init__() 246*da0073e9SAndroid Build Coastguard Worker self.sub = SubModule() 247*da0073e9SAndroid Build Coastguard Worker self.c = torch.ones(20, 20) 248*da0073e9SAndroid Build Coastguard Worker 249*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 250*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(self.sub.forward, x) 251*da0073e9SAndroid Build Coastguard Worker y_hat = self.sub(x) 252*da0073e9SAndroid Build Coastguard Worker y = torch.jit._wait(fut) 253*da0073e9SAndroid Build Coastguard Worker return y_hat + y + self.c 254*da0073e9SAndroid Build Coastguard Worker 
255*da0073e9SAndroid Build Coastguard Worker class TestModule(nn.Module): 256*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 257*da0073e9SAndroid Build Coastguard Worker super().__init__() 258*da0073e9SAndroid Build Coastguard Worker self.sub = SubModule2() 259*da0073e9SAndroid Build Coastguard Worker self.d = 1 260*da0073e9SAndroid Build Coastguard Worker 261*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 262*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(self.sub.forward, x) 263*da0073e9SAndroid Build Coastguard Worker y_hat = self.sub(x) 264*da0073e9SAndroid Build Coastguard Worker y = torch.jit._wait(fut) 265*da0073e9SAndroid Build Coastguard Worker self.d = 2 266*da0073e9SAndroid Build Coastguard Worker return y_hat * y + self.d 267*da0073e9SAndroid Build Coastguard Worker 268*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(TestModule()) 269*da0073e9SAndroid Build Coastguard Worker m.eval() 270*da0073e9SAndroid Build Coastguard Worker input = torch.randn(20, 20) 271*da0073e9SAndroid Build Coastguard Worker output_s = m.forward(input) 272*da0073e9SAndroid Build Coastguard Worker mf = torch._C._freeze_module(m._c) 273*da0073e9SAndroid Build Coastguard Worker # Check if frozen module looks as below: 274*da0073e9SAndroid Build Coastguard Worker # module m { 275*da0073e9SAndroid Build Coastguard Worker # attributes { 276*da0073e9SAndroid Build Coastguard Worker # } 277*da0073e9SAndroid Build Coastguard Worker # ... 
278*da0073e9SAndroid Build Coastguard Worker # submodule { 279*da0073e9SAndroid Build Coastguard Worker # } 280*da0073e9SAndroid Build Coastguard Worker # } 281*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mf.hasattr("a")) 282*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mf.hasattr("b")) 283*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mf.hasattr("c")) 284*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mf.hasattr("d")) 285*da0073e9SAndroid Build Coastguard Worker output_f = mf.forward(input) 286*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_s, output_f) 287*da0073e9SAndroid Build Coastguard Worker 288*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_with_fork2(self): 289*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 290*da0073e9SAndroid Build Coastguard Worker def foo(x): 291*da0073e9SAndroid Build Coastguard Worker return x * 2 292*da0073e9SAndroid Build Coastguard Worker 293*da0073e9SAndroid Build Coastguard Worker class TestModule(nn.Module): 294*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 295*da0073e9SAndroid Build Coastguard Worker super().__init__() 296*da0073e9SAndroid Build Coastguard Worker self.a = torch.ones(20, 20) 297*da0073e9SAndroid Build Coastguard Worker self.b = torch.ones(20, 20) 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 300*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(foo, self.a) 301*da0073e9SAndroid Build Coastguard Worker y_hat = foo(self.b) 302*da0073e9SAndroid Build Coastguard Worker y = torch.jit._wait(fut) 303*da0073e9SAndroid Build Coastguard Worker return y_hat + y 304*da0073e9SAndroid Build Coastguard Worker 305*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(TestModule()) 306*da0073e9SAndroid Build Coastguard Worker m.eval() 307*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 2) 308*da0073e9SAndroid 
Build Coastguard Worker output_s = m.forward(input) 309*da0073e9SAndroid Build Coastguard Worker mf = torch._C._freeze_module(m._c) 310*da0073e9SAndroid Build Coastguard Worker 311*da0073e9SAndroid Build Coastguard Worker # Check if frozen module looks as below: 312*da0073e9SAndroid Build Coastguard Worker # module m { 313*da0073e9SAndroid Build Coastguard Worker # attributes { 314*da0073e9SAndroid Build Coastguard Worker # self.a = ... 315*da0073e9SAndroid Build Coastguard Worker # self.b = .. 316*da0073e9SAndroid Build Coastguard Worker # } 317*da0073e9SAndroid Build Coastguard Worker # ... 318*da0073e9SAndroid Build Coastguard Worker # submodule { 319*da0073e9SAndroid Build Coastguard Worker # } 320*da0073e9SAndroid Build Coastguard Worker # } 321*da0073e9SAndroid Build Coastguard Worker # TODO: Although there are no mutation, the alias analysis 322*da0073e9SAndroid Build Coastguard Worker # conservatively assumes there is a mutation because attributes are 323*da0073e9SAndroid Build Coastguard Worker # passed to fork subgraph. both 'a' and 'b' are preserved. 
324*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mf.hasattr("a")) 325*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mf.hasattr("b")) 326*da0073e9SAndroid Build Coastguard Worker output_f = mf.forward(input) 327*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_s, output_f) 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_with_fork_calling_module_method(self): 330*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 331*da0073e9SAndroid Build Coastguard Worker def foo(x, y): 332*da0073e9SAndroid Build Coastguard Worker return x * y 333*da0073e9SAndroid Build Coastguard Worker 334*da0073e9SAndroid Build Coastguard Worker class TestModule(nn.Module): 335*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 336*da0073e9SAndroid Build Coastguard Worker super().__init__() 337*da0073e9SAndroid Build Coastguard Worker self.a = torch.ones(20, 20) 338*da0073e9SAndroid Build Coastguard Worker self.b = torch.ones(20, 20) 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 341*da0073e9SAndroid Build Coastguard Worker def foo(self, x): 342*da0073e9SAndroid Build Coastguard Worker return x * self.a 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 345*da0073e9SAndroid Build Coastguard Worker def bar(self, x): 346*da0073e9SAndroid Build Coastguard Worker return x * self.b 347*da0073e9SAndroid Build Coastguard Worker 348*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 349*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(self.foo, self.b) 350*da0073e9SAndroid Build Coastguard Worker y_hat = self.bar(self.a) 351*da0073e9SAndroid Build Coastguard Worker y = torch.jit._wait(fut) 352*da0073e9SAndroid Build Coastguard Worker return y_hat + y 353*da0073e9SAndroid Build Coastguard Worker 354*da0073e9SAndroid Build Coastguard 
Worker m = torch.jit.script(TestModule()) 355*da0073e9SAndroid Build Coastguard Worker m.eval() 356*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 2) 357*da0073e9SAndroid Build Coastguard Worker output_s = m.forward(input) 358*da0073e9SAndroid Build Coastguard Worker mf = torch._C._freeze_module(m._c) 359*da0073e9SAndroid Build Coastguard Worker # Check if frozen module looks as below: 360*da0073e9SAndroid Build Coastguard Worker # module m { 361*da0073e9SAndroid Build Coastguard Worker # attributes { 362*da0073e9SAndroid Build Coastguard Worker # self.b = .. 363*da0073e9SAndroid Build Coastguard Worker # } 364*da0073e9SAndroid Build Coastguard Worker # ... 365*da0073e9SAndroid Build Coastguard Worker # TODO: Although there are no mutation, the alias analysis 366*da0073e9SAndroid Build Coastguard Worker # conservatively assumes there is a mutation because attributes are 367*da0073e9SAndroid Build Coastguard Worker # passed to fork subgraph. 'b' is preserved. 368*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mf.hasattr("a")) 369*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mf.hasattr("b")) 370*da0073e9SAndroid Build Coastguard Worker output_f = mf.forward(input) 371*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_s, output_f) 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_with_sharedclasstype(self): 374*da0073e9SAndroid Build Coastguard Worker class SubModule(nn.Module): 375*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 376*da0073e9SAndroid Build Coastguard Worker super().__init__() 377*da0073e9SAndroid Build Coastguard Worker self.a = torch.tensor([1.1]) 378*da0073e9SAndroid Build Coastguard Worker self.b = torch.tensor([2.2]) 379*da0073e9SAndroid Build Coastguard Worker 380*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 381*da0073e9SAndroid Build Coastguard Worker return self.a + self.b 
382*da0073e9SAndroid Build Coastguard Worker 383*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 384*da0073e9SAndroid Build Coastguard Worker def modify_a(self, x): 385*da0073e9SAndroid Build Coastguard Worker self.a[0] += 10 386*da0073e9SAndroid Build Coastguard Worker return self.b 387*da0073e9SAndroid Build Coastguard Worker 388*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 389*da0073e9SAndroid Build Coastguard Worker def modify_b(self, x): 390*da0073e9SAndroid Build Coastguard Worker self.b[0] += 20 391*da0073e9SAndroid Build Coastguard Worker return self.a 392*da0073e9SAndroid Build Coastguard Worker 393*da0073e9SAndroid Build Coastguard Worker class SubModule2(nn.Module): 394*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 395*da0073e9SAndroid Build Coastguard Worker super().__init__() 396*da0073e9SAndroid Build Coastguard Worker self.sub = SubModule() 397*da0073e9SAndroid Build Coastguard Worker self.b = torch.tensor([3.3]) 398*da0073e9SAndroid Build Coastguard Worker 399*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 400*da0073e9SAndroid Build Coastguard Worker y = self.sub.modify_b(x) 401*da0073e9SAndroid Build Coastguard Worker return y + self.b 402*da0073e9SAndroid Build Coastguard Worker 403*da0073e9SAndroid Build Coastguard Worker class TestModule(nn.Module): 404*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 405*da0073e9SAndroid Build Coastguard Worker super().__init__() 406*da0073e9SAndroid Build Coastguard Worker self.sub1 = SubModule() # sub1 and sub2.sub shared same class type. 
407*da0073e9SAndroid Build Coastguard Worker self.sub2 = SubModule2() 408*da0073e9SAndroid Build Coastguard Worker self.a = torch.tensor([4.4]) 409*da0073e9SAndroid Build Coastguard Worker 410*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 411*da0073e9SAndroid Build Coastguard Worker z = self.sub1.modify_a(x) 412*da0073e9SAndroid Build Coastguard Worker return self.sub2(x) + z + self.a 413*da0073e9SAndroid Build Coastguard Worker 414*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(TestModule()) 415*da0073e9SAndroid Build Coastguard Worker m.eval() 416*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 2) 417*da0073e9SAndroid Build Coastguard Worker output_s = m.forward(input) 418*da0073e9SAndroid Build Coastguard Worker mf = torch._C._freeze_module(m._c) 419*da0073e9SAndroid Build Coastguard Worker 420*da0073e9SAndroid Build Coastguard Worker # Checking if Frozen module looks as below 421*da0073e9SAndroid Build Coastguard Worker # module mf { 422*da0073e9SAndroid Build Coastguard Worker # attributes { 423*da0073e9SAndroid Build Coastguard Worker # sub1 = ... 424*da0073e9SAndroid Build Coastguard Worker # sub2 = ... 425*da0073e9SAndroid Build Coastguard Worker # } 426*da0073e9SAndroid Build Coastguard Worker # ... 427*da0073e9SAndroid Build Coastguard Worker # submodules { 428*da0073e9SAndroid Build Coastguard Worker # module sub1 { 429*da0073e9SAndroid Build Coastguard Worker # attributes { 430*da0073e9SAndroid Build Coastguard Worker # a = ... 431*da0073e9SAndroid Build Coastguard Worker # b = ... 432*da0073e9SAndroid Build Coastguard Worker # } 433*da0073e9SAndroid Build Coastguard Worker # ... 434*da0073e9SAndroid Build Coastguard Worker # } 435*da0073e9SAndroid Build Coastguard Worker # module sub2 { 436*da0073e9SAndroid Build Coastguard Worker # attributes { 437*da0073e9SAndroid Build Coastguard Worker # sub = ... 438*da0073e9SAndroid Build Coastguard Worker # } 439*da0073e9SAndroid Build Coastguard Worker # ... 
440*da0073e9SAndroid Build Coastguard Worker # submodule { 441*da0073e9SAndroid Build Coastguard Worker # module sub { 442*da0073e9SAndroid Build Coastguard Worker # attributes { 443*da0073e9SAndroid Build Coastguard Worker # a = ... 444*da0073e9SAndroid Build Coastguard Worker # b = ... 445*da0073e9SAndroid Build Coastguard Worker # } 446*da0073e9SAndroid Build Coastguard Worker # ... 447*da0073e9SAndroid Build Coastguard Worker # } 448*da0073e9SAndroid Build Coastguard Worker # } 449*da0073e9SAndroid Build Coastguard Worker # } 450*da0073e9SAndroid Build Coastguard Worker # } 451*da0073e9SAndroid Build Coastguard Worker # } 452*da0073e9SAndroid Build Coastguard Worker 453*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mf.hasattr("sub1")) 454*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mf.sub1.hasattr("a")) 455*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mf.sub1.hasattr("b")) 456*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mf.hasattr("a")) 457*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mf.hasattr("sub2")) 458*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mf.sub2.hasattr("sub")) 459*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mf.sub2.hasattr("b")) 460*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mf.sub2.sub.hasattr("a")) 461*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mf.sub2.sub.hasattr("b")) 462*da0073e9SAndroid Build Coastguard Worker output_f = mf.forward(input) 463*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_s, output_f) 464*da0073e9SAndroid Build Coastguard Worker 465*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_with_nestedaliasing(self): 466*da0073e9SAndroid Build Coastguard Worker class SubModule(nn.Module): 467*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 468*da0073e9SAndroid Build Coastguard Worker super().__init__() 469*da0073e9SAndroid Build Coastguard Worker self.a = 
torch.tensor([1.1]) 470*da0073e9SAndroid Build Coastguard Worker self.b = torch.tensor([2.2]) 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 473*da0073e9SAndroid Build Coastguard Worker return self.a + self.b 474*da0073e9SAndroid Build Coastguard Worker 475*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 476*da0073e9SAndroid Build Coastguard Worker def modify_a(self, x): 477*da0073e9SAndroid Build Coastguard Worker self.a[0] = 10 478*da0073e9SAndroid Build Coastguard Worker return self.b 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 481*da0073e9SAndroid Build Coastguard Worker def modify_b(self, x): 482*da0073e9SAndroid Build Coastguard Worker self.b[0] = 20 483*da0073e9SAndroid Build Coastguard Worker return self.a 484*da0073e9SAndroid Build Coastguard Worker 485*da0073e9SAndroid Build Coastguard Worker Sub = SubModule() 486*da0073e9SAndroid Build Coastguard Worker 487*da0073e9SAndroid Build Coastguard Worker class SubModule2(nn.Module): 488*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 489*da0073e9SAndroid Build Coastguard Worker super().__init__() 490*da0073e9SAndroid Build Coastguard Worker self.sub = Sub # aliasing 491*da0073e9SAndroid Build Coastguard Worker 492*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 493*da0073e9SAndroid Build Coastguard Worker return self.sub.a 494*da0073e9SAndroid Build Coastguard Worker 495*da0073e9SAndroid Build Coastguard Worker class TestModule(nn.Module): 496*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 497*da0073e9SAndroid Build Coastguard Worker super().__init__() 498*da0073e9SAndroid Build Coastguard Worker self.sub1 = Sub # aliasing 499*da0073e9SAndroid Build Coastguard Worker self.sub2 = SubModule2() 500*da0073e9SAndroid Build Coastguard Worker 501*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 
502*da0073e9SAndroid Build Coastguard Worker z = self.sub1.modify_a(x) 503*da0073e9SAndroid Build Coastguard Worker return self.sub2(x) + z 504*da0073e9SAndroid Build Coastguard Worker 505*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(TestModule()) 506*da0073e9SAndroid Build Coastguard Worker m.eval() 507*da0073e9SAndroid Build Coastguard Worker mf = torch._C._freeze_module(m._c) 508*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mf.hasattr("sub1")) 509*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mf.sub1.hasattr("a")) 510*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mf.sub1.hasattr("b")) 511*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mf.hasattr("sub2")) 512*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mf.sub2.hasattr("sub")) 513*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 514*da0073e9SAndroid Build Coastguard Worker mf.sub2.sub.hasattr("a") 515*da0073e9SAndroid Build Coastguard Worker ) # Freezing detects that self.sub2.sub.a and self.sub1.a are alias 516*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mf.sub2.sub.hasattr("b")) 517*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 2) 518*da0073e9SAndroid Build Coastguard Worker output_s = m.forward(input) 519*da0073e9SAndroid Build Coastguard Worker output_f = mf.forward(input) 520*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_s, output_f) 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Worker # FIXME: JIT is not honoring aliasing. 'Sub' module is copied. As a result 523*da0073e9SAndroid Build Coastguard Worker # Eager and Script modules produce different output. 
def test_freeze_module_with_nestedaliasingscalar(self):
    """Freeze a module whose scalar-attribute submodule is shared in eager mode.

    The same ``SubModule`` instance is assigned to ``self.sub1`` and to
    ``self.sub2.sub``. Scripting does not honor that aliasing (the submodule
    is copied), so eager and scripted outputs differ. The frozen module must
    still agree with the scripted one.
    """

    class SubModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = 1.1
            self.b = 2.2

        def forward(self, x):
            return self.a + self.b

        @torch.jit.export
        def modify_a(self, x):
            self.a = 10.0
            return self.b

        @torch.jit.export
        def modify_b(self, x):
            self.b = 20.0
            return self.a

    Sub = SubModule()

    class SubModule2(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.sub = Sub  # aliasing

        def forward(self, x):
            return self.sub.a

    class TestModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.sub1 = Sub  # aliasing
            self.sub2 = SubModule2()

        def forward(self, x):
            z = self.sub1.modify_a(x)
            return self.sub2(x) + z

    m = TestModule()
    ms = torch.jit.script(m)
    ms.eval()
    mf = torch._C._freeze_module(ms._c)
    self.assertTrue(mf.hasattr("sub1"))
    self.assertTrue(mf.sub1.hasattr("a"))
    self.assertFalse(mf.sub1.hasattr("b"))
    # sub2 is fully folded because self.sub1 and self.sub2.sub are not
    # aliases after scripting (scripting bug).
    self.assertFalse(mf.hasattr("sub2"))
    inp = torch.randn(2, 2)
    output = m.forward(inp)
    output_s = ms.forward(inp)
    output_f = mf.forward(inp)
    # Eager and scripted outputs should be equal; they differ only because
    # of the scripting aliasing bug described above.
    self.assertNotEqual(output, output_s)
    self.assertEqual(output_s, output_f)

def test_freeze_module_with_preserve_sub_module(self):
    """Freeze with 'sub1' preserved: it keeps every attribute, 'sub2' is folded away."""

    class SubModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = torch.tensor([1.1])
            self.b = 2.2

        def forward(self, x):
            return self.a

    class TestModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Two independent instances; sub1 is NOT aliased with sub2.
            self.sub1 = SubModule()
            self.sub2 = SubModule()

        def forward(self, x):
            return self.sub2(x) + self.sub1(x)

    m = TestModule()
    ms = torch.jit.script(m)
    ms.eval()
    mf = torch._C._freeze_module(ms._c, ["sub1"])

    # Test that 'sub1' is preserved entirely and 'sub2' is completely folded.
    self.assertTrue(mf.hasattr("sub1"))
    self.assertTrue(mf.sub1.hasattr("a"))
    self.assertTrue(mf.sub1.hasattr("b"))
    self.assertFalse(mf.hasattr("sub2"))
    inp = torch.randn(2, 2)
    output_s = ms.forward(inp)
    output_f = mf.forward(inp)
    self.assertEqual(output_s, output_f)

def test_freeze_module_with_preserve_sub_module_and_mutation(self):
    """Preserve 'sub1' when both submodules mutate their tensor attribute."""

    class SubModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = torch.tensor([1.1])
            self.b = 2.2

        def forward(self, x):
            self.a[0] = 3.3
            return self.a

    class TestModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Two independent instances; sub1 is NOT aliased with sub2.
            self.sub1 = SubModule()
            self.sub2 = SubModule()

        def forward(self, x):
            return self.sub2(x) + self.sub1(x)

    m = TestModule()
    ms = torch.jit.script(m)
    ms.eval()
    mf = torch._C._freeze_module(ms._c, ["sub1"])

    # Test that both sub1 and sub2 are preserved, and that 'b' is preserved
    # even though it is not used, to fulfill the user request to preserve
    # 'sub1'.
    self.assertTrue(mf.hasattr("sub1"))
    self.assertTrue(mf.sub1.hasattr("a"))
    self.assertTrue(mf.sub1.hasattr("b"))
    self.assertTrue(mf.hasattr("sub2"))
    self.assertTrue(mf.sub2.hasattr("a"))
    self.assertTrue(mf.sub2.hasattr("b"))
    inp = torch.randn(2, 2)
    output_s = ms.forward(inp)
    output_f = mf.forward(inp)
    self.assertEqual(output_s, output_f)

def test_freeze_module_with_helperfunction(self):
    """Freezing folds read-only state and drops the private helper method.

    'a' (only read) and 'sub' are folded. 'b' is kept because ``forward``
    assigns to it. The helper ``_forward`` is not kept on the frozen module.
    """

    class SubModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = 11
            self.b = 2

        def forward(self, x):
            return self.a + self.b

    class TestModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.sub = SubModule()
            self.a = 3
            self.b = 4

        def forward(self, x):
            self.b = 20
            return self._forward(x) + self.a + self.b

        def _forward(self, x):
            return self.sub(x)

    m = torch.jit.script(TestModule())
    m.eval()
    inp = torch.randn(2, 2)
    mf = torch._C._freeze_module(m._c)
    self.assertFalse(mf.hasattr("sub"))
    self.assertFalse(mf.hasattr("a"))
    self.assertTrue(mf.hasattr("b"))
    with self.assertRaisesRegex(
        AttributeError, "TestModule (.*) does not have a field with name '_forward'"
    ):
        # Pass a real argument: the lookup of '_forward' is what must fail,
        # not an undefined name in the argument list.
        mf._forward(inp)

def test_freeze_module_with_inplace_mutable(self):
    """A list attribute appended to in ``forward`` stays a module attribute."""

    class FreezeMe(torch.jit.ScriptModule):
        def __init__(self) -> None:
            super().__init__()
            self.a = [11, 22]

        @torch.jit.script_method
        def forward(self, x):
            for i in range(3):
                self.a.append(i)
            return self.a

    m = FreezeMe()
    m.eval()
    m_f = torch._C._freeze_module(m._c)
    self.assertTrue(m_f.hasattr("a"))
    m.forward(torch.tensor([3]))
    out = m_f.forward(torch.tensor([5]))
    # m_f observes the appends made by m.forward, i.e. the list is shared.
    expected = [11, 22, 0, 1, 2, 0, 1, 2]
    self.assertEqual(out, expected)

# Mutable attributes
def test_freeze_module_with_mutable_list(self):
    """A read-only list attribute is folded into the frozen graph."""

    class FreezeMe(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = [1, 2]

        def forward(self, x):
            return self.a

    m = FreezeMe()
    m.eval()
    m.a.append(3)
    m_s = torch.jit.script(m)
    v = m_s.a
    v.append(4)
    m_s.a = v
    m_s.eval()
    m_f = torch._C._freeze_module(m_s._c)
    # Post-freezing mutating m_s.a does not affect m_f (m_f has its own copy).
    v = m_s.a
    v.append(5)
    m_s.a = v
    self.assertFalse(m_f.hasattr("a"))
    out = m_f.forward(torch.tensor([5]))
    expected = [1, 2, 3, 4]
    self.assertEqual(out, expected)

def test_freeze_module_with_mutable_dict(self):
    """A dict attribute read by ``forward`` is folded into the frozen graph."""

    class FreezeMe(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = {"layer": "4"}

        def forward(self, x):
            return self.a

        @torch.jit.export
        def modify_a(self, x):
            self.a["layer"] = self.a["layer"] + "1"
            return self.a

    m = FreezeMe()
    m.eval()
    m.a["layer2"] = "3"
    m_s = torch.jit.script(m)
    t = torch.tensor(5)
    m_s.modify_a(t)
    m_s.eval()
    m_f = torch._C._freeze_module(m_s._c)
    m.a["layer2"] += "2"
    m_s.modify_a(t)
    self.assertFalse(m_f.hasattr("a"))
    out = m_f.forward(t)
    # The post-freezing in-place update through m_s.modify_a is visible
    # ("411"); the mutation of the eager module m is not ("3").
    expected = {"layer": "411", "layer2": "3"}
    self.assertEqual(out, expected)

def test_freeze_module_with_mutable_tensor(self):
    """A read-only tensor attribute is folded into the frozen graph."""

    class FreezeMe(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = torch.tensor([1.0, 2.0, 3.0])

        def forward(self, x):
            return self.a

    m = FreezeMe()
    m_s = torch.jit.script(m)
    m_s.a[1] += 3.0
    m_s.eval()
    m_f = torch._C._freeze_module(m_s._c)
    # Post-freezing tensor attribute mutations affect m_f.
    # FIXME: deep copy all folded attributes so that m_f has full ownership.
    m_s.a[0] += 5.0
    self.assertFalse(m_f.hasattr("a"))
    out = m_f.forward(torch.tensor([5]))
    expected = [6.0, 5.0, 3.0]
    self.assertEqual(out, expected)

def test_freeze_module_with_tuple(self):
    """A tuple attribute whose tensor element is mutated in ``forward`` is folded."""

    class FreezeMe(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = (torch.tensor([1, 2, 3, 4, 5, 6]), "hi")

        def forward(self, x):
            if x[0] == 2.0:
                self.a[0][0] = 10
            return self.a[0].sum()

    m = FreezeMe()
    m_s = torch.jit.script(m)
    m_s.eval()
    inp = torch.tensor([2.0])
    expected = m_s.forward(inp)
    # Undo the mutation so the frozen module starts from the original value.
    m_s.a[0][0] = 1
    m_f = torch._C._freeze_module(m_s._c)
    self.assertFalse(m_f.hasattr("a"))
    out = m_f.forward(inp)
    self.assertEqual(out, expected)

def test_freeze_module_with_tensor(self):
    """A tensor mutated through a view in ``forward`` stays a module attribute."""

    class FreezeMe(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = torch.tensor([1, 2, 3, 4, 5, 6])

        def forward(self, x):
            x = self.a.view(2, 3)
            x[0][0] += 10
            return self.a.sum()

    m = FreezeMe()
    m_s = torch.jit.script(m)
    m_s.eval()
    inp = torch.tensor([5])
    expected = m_s.forward(inp)
    m_f = torch._C._freeze_module(m_s._c)
    self.assertTrue(m_f.hasattr("a"))
    # Undo the first forward's mutation before running the frozen module.
    m_f.a[0] -= 10
    out = m_f.forward(inp)
    self.assertEqual(out, expected)

def test_freeze_module_with_list(self):
    """A list of tensors mutated element-wise in ``forward`` is folded."""

    class FreezeMe(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = [torch.tensor([1, 2, 3, 4, 5, 6])]

        def forward(self, x):
            self.a[0][1] += 10
            return self.a[0].sum()

    m = FreezeMe()
    m_s = torch.jit.script(m)
    m_s.eval()
    inp = torch.tensor([5])
    expected = m_s.forward(inp)
    # Undo the first forward's mutation before freezing.
    m_s.a[0][1] -= 10
    m_f = torch._C._freeze_module(m_s._c)
    self.assertFalse(m_f.hasattr("a"))
    out = m_f.forward(inp)
    self.assertEqual(out, expected)

def test_freeze_module_with_aliased_tensor_attr(self):
    """Mutating a view attribute must be reflected in its base attribute."""

    class FreezeMe(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = torch.tensor([1, 2, 3, 4, 5, 6])
            self.b = self.a.view(2, 3)

        def forward(self, x):
            self.b[1] += 10
            return self.a.sum()

    m = FreezeMe()
    m_s = torch.jit.script(m)
    m_s.eval()
    m_f = torch._C._freeze_module(m_s._c)
    self.assertTrue(m_f.hasattr("a"))
    inp = torch.tensor([5])
    out = m_f.forward(inp)
    expected = torch.tensor(51)  # 1+2+3+14+15+16
    self.assertEqual(out, expected)

def test_freeze_module_with_aliased_tensor_attr2(self):
    """Freezing rejects attributes whose tensor values overlap via views."""

    class FreezeMe(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = torch.tensor([1, 2, 3, 4, 5, 6])
            self.b = {"layer": ([self.a.view(2, 3), torch.tensor([10])], 20)}
            self.c = ([self.a.view(2, 3), torch.tensor([10])], 20)
            self.d = (self.a.view(2, 3), 20)

        def forward(self, x):
            self.d[0][0] += 10
            return self.a.sum()

    m = FreezeMe()
    m_s = torch.jit.script(m)
    m_s.eval()
    inp = torch.tensor([5])
    m_s.forward(inp)
    with self.assertRaisesRegex(
        RuntimeError, "module contains attributes values that overlaps"
    ):
        torch._C._freeze_module(m_s._c)

def test_freeze_module_with_aliased_tensor_attr3(self):
    """A tensor shared between an attribute and a list attribute keeps both."""

    class FreezeMe(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = torch.tensor([1, 2, 3, 4, 5, 6])
            self.b = [self.a, torch.tensor([10])]

        def forward(self, x):
            self.a[1] += 10
            return self.b[0].sum()

    m = FreezeMe()
    m_s = torch.jit.script(m)
    m_s.eval()
    inp = torch.tensor([5])
    expected = m_s.forward(inp)
    m_f = torch._C._freeze_module(m_s._c)
    self.assertTrue(m_f.hasattr("a"))
    self.assertTrue(m_f.hasattr("b"))
    out = m_f.forward(inp)
    expected += 10  # account for self.a += 10.
    self.assertEqual(out, expected)

def test_freeze_module_with_aliased_tensor_attr4(self):
    """Freezing rejects an attribute aliased inside a mutated list attribute."""

    class FreezeMe(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = torch.tensor([1, 2, 3, 4, 5, 6])
            self.b = [self.a, torch.tensor([10])]

        def forward(self, x):
            self.b[0][0] += 10
            return self.a.sum()

    m = FreezeMe()
    m_s = torch.jit.script(m)
    m_s.eval()
    inp = torch.tensor([5])
    m_s.forward(inp)
    m_s.a[0] -= 10
    with self.assertRaisesRegex(
        RuntimeError, "module contains attributes values that overlaps"
    ):
        torch._C._freeze_module(m_s._c)

def test_freeze_module_with_overlapping_attrs(self):
    """Freezing rejects two attributes holding different views of one tensor."""
    a = torch.tensor([1, 2, 3, 4, 5, 6])

    class FreezeMe(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.b = [a.view(3, 2), torch.tensor([10])]
            self.c = (20, a.view(2, 3))

        def forward(self, x):
            self.b[0][0] += 10
            return self.c[1].sum()

    m = FreezeMe()
    m_s = torch.jit.script(m)
    m_s.eval()
    inp = torch.tensor([5])
    m_s.forward(inp)
    a[0] -= 10
    with self.assertRaisesRegex(
        RuntimeError, "module contains attributes values that overlaps"
    ):
        torch._C._freeze_module(m_s._c)

def test_freeze_module_with_aliased_attr(self):
    """Freezing a module whose list attribute is aliased by other attributes."""

    class FreezeMe(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = [1, 2, 3, 4, 5, 6]
            self.b = self.a
            self.c = (self.a, 10)

        def forward(self, x):
            self.b[1] += 10
            return str(self.a) + str(self.c)

    m = FreezeMe()
    m_s = torch.jit.script(m)
    m_s.eval()
    m_f = torch._C._freeze_module(m_s._c)
    # FIXME: It should be assertTrue. Currently scripting is making a copy for setting self.b (see #33034)
    self.assertFalse(m_f.hasattr("a"))
    self.assertFalse(m_f.hasattr("c"))
    inp = torch.tensor([5])
    out = m_f.forward(inp)
    expected = m_s.forward(inp)
    self.assertEqual(out, expected)

# Check attribute a is preserved. Alias analysis detects that 'a' has output writers.
# In this example, 'a' is not mutated. However, we do not track which sub
# values of a composite ivalue are mutated.
994*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_with_aliased_attr2(self): 995*da0073e9SAndroid Build Coastguard Worker class FreezeMe(nn.Module): 996*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 997*da0073e9SAndroid Build Coastguard Worker super().__init__() 998*da0073e9SAndroid Build Coastguard Worker self.a = [1, 2, 3, 4, 5, 6] 999*da0073e9SAndroid Build Coastguard Worker self.b = ([11], [10]) 1000*da0073e9SAndroid Build Coastguard Worker 1001*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1002*da0073e9SAndroid Build Coastguard Worker v = self.a 1003*da0073e9SAndroid Build Coastguard Worker self.b = (v, [12]) 1004*da0073e9SAndroid Build Coastguard Worker v2 = self.b[1] 1005*da0073e9SAndroid Build Coastguard Worker v2.append(7) 1006*da0073e9SAndroid Build Coastguard Worker return str(v) + str(v2) 1007*da0073e9SAndroid Build Coastguard Worker 1008*da0073e9SAndroid Build Coastguard Worker m = FreezeMe() 1009*da0073e9SAndroid Build Coastguard Worker m_s = torch.jit.script(m) 1010*da0073e9SAndroid Build Coastguard Worker m_s.eval() 1011*da0073e9SAndroid Build Coastguard Worker m_f = torch._C._freeze_module(m_s._c) 1012*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m_f.hasattr("a")) 1013*da0073e9SAndroid Build Coastguard Worker inp = torch.tensor([5]) 1014*da0073e9SAndroid Build Coastguard Worker out = m_f.forward(inp) 1015*da0073e9SAndroid Build Coastguard Worker expected = m.forward(inp) 1016*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, expected) 1017*da0073e9SAndroid Build Coastguard Worker 1018*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_with_aliased_attr3(self): 1019*da0073e9SAndroid Build Coastguard Worker class FreezeMe(nn.Module): 1020*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1021*da0073e9SAndroid Build Coastguard Worker super().__init__() 1022*da0073e9SAndroid Build Coastguard Worker self.a = [1, 2, 3, 4, 5, 6] 
1023*da0073e9SAndroid Build Coastguard Worker self.b = ([11], [10]) 1024*da0073e9SAndroid Build Coastguard Worker 1025*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1026*da0073e9SAndroid Build Coastguard Worker v = self.a 1027*da0073e9SAndroid Build Coastguard Worker v2 = (v, [12]) 1028*da0073e9SAndroid Build Coastguard Worker v3 = v2[0] 1029*da0073e9SAndroid Build Coastguard Worker v3.append(7) 1030*da0073e9SAndroid Build Coastguard Worker return str(self.a) 1031*da0073e9SAndroid Build Coastguard Worker 1032*da0073e9SAndroid Build Coastguard Worker m = FreezeMe() 1033*da0073e9SAndroid Build Coastguard Worker m_s = torch.jit.script(m) 1034*da0073e9SAndroid Build Coastguard Worker m_s.eval() 1035*da0073e9SAndroid Build Coastguard Worker m_f = torch._C._freeze_module(m_s._c) 1036*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m_f.hasattr("a")) 1037*da0073e9SAndroid Build Coastguard Worker inp = torch.tensor([5]) 1038*da0073e9SAndroid Build Coastguard Worker out = m_f.forward(inp) 1039*da0073e9SAndroid Build Coastguard Worker expected = m.forward(inp) 1040*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, expected) 1041*da0073e9SAndroid Build Coastguard Worker 1042*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_return_self(self): 1043*da0073e9SAndroid Build Coastguard Worker class FreezeMe(nn.Module): 1044*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1045*da0073e9SAndroid Build Coastguard Worker super().__init__() 1046*da0073e9SAndroid Build Coastguard Worker self.a = torch.tensor([1.0, 2.0, 3.0]) 1047*da0073e9SAndroid Build Coastguard Worker 1048*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1049*da0073e9SAndroid Build Coastguard Worker return self 1050*da0073e9SAndroid Build Coastguard Worker 1051*da0073e9SAndroid Build Coastguard Worker m = FreezeMe() 1052*da0073e9SAndroid Build Coastguard Worker m_s = torch.jit.script(m) 1053*da0073e9SAndroid Build Coastguard Worker 
m_s.eval() 1054*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1055*da0073e9SAndroid Build Coastguard Worker RuntimeError, "attempted to freeze a module that return itself" 1056*da0073e9SAndroid Build Coastguard Worker ): 1057*da0073e9SAndroid Build Coastguard Worker m_f = torch._C._freeze_module(m_s._c) 1058*da0073e9SAndroid Build Coastguard Worker 1059*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_inlining(self): 1060*da0073e9SAndroid Build Coastguard Worker @torch.jit.script # noqa: B903 1061*da0073e9SAndroid Build Coastguard Worker class Obj: # noqa: B903 1062*da0073e9SAndroid Build Coastguard Worker def __init__(self, x: int, y: int): 1063*da0073e9SAndroid Build Coastguard Worker self.x = x 1064*da0073e9SAndroid Build Coastguard Worker self.y = y 1065*da0073e9SAndroid Build Coastguard Worker 1066*da0073e9SAndroid Build Coastguard Worker class Mod(nn.Module): 1067*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1068*da0073e9SAndroid Build Coastguard Worker super().__init__() 1069*da0073e9SAndroid Build Coastguard Worker self.obj = Obj(2, 3) 1070*da0073e9SAndroid Build Coastguard Worker 1071*da0073e9SAndroid Build Coastguard Worker def forward(self, i: int): 1072*da0073e9SAndroid Build Coastguard Worker print(self.obj) 1073*da0073e9SAndroid Build Coastguard Worker return i 1074*da0073e9SAndroid Build Coastguard Worker 1075*da0073e9SAndroid Build Coastguard Worker mod = torch.jit.freeze(torch.jit.script(Mod().eval())) 1076*da0073e9SAndroid Build Coastguard Worker obj = mod.graph.findNode("prim::Constant") 1077*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._C._jit_object_is_non_holding(obj)) 1078*da0073e9SAndroid Build Coastguard Worker 1079*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 1080*da0073e9SAndroid Build Coastguard Worker torch.jit.save(mod, buffer) 1081*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 1082*da0073e9SAndroid Build Coastguard Worker 
1083*da0073e9SAndroid Build Coastguard Worker loaded = torch.jit.load(buffer) 1084*da0073e9SAndroid Build Coastguard Worker obj = mod.graph.findNode("prim::Constant") 1085*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._C._jit_object_is_non_holding(obj)) 1086*da0073e9SAndroid Build Coastguard Worker 1087*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_return_sub_module(self): 1088*da0073e9SAndroid Build Coastguard Worker class FreezeMe(nn.Module): 1089*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1090*da0073e9SAndroid Build Coastguard Worker super().__init__() 1091*da0073e9SAndroid Build Coastguard Worker self.conv1 = nn.Conv2d(1, 32, 3, 1) 1092*da0073e9SAndroid Build Coastguard Worker 1093*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1094*da0073e9SAndroid Build Coastguard Worker return self.conv1 1095*da0073e9SAndroid Build Coastguard Worker 1096*da0073e9SAndroid Build Coastguard Worker m = FreezeMe() 1097*da0073e9SAndroid Build Coastguard Worker m_s = torch.jit.script(m) 1098*da0073e9SAndroid Build Coastguard Worker m_s.eval() 1099*da0073e9SAndroid Build Coastguard Worker m_f = torch._C._freeze_module(m_s._c) 1100*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m_f.hasattr("conv1")) 1101*da0073e9SAndroid Build Coastguard Worker 1102*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_no_forward(self): 1103*da0073e9SAndroid Build Coastguard Worker class FreezeMe(nn.Module): 1104*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1105*da0073e9SAndroid Build Coastguard Worker super().__init__() 1106*da0073e9SAndroid Build Coastguard Worker self.lin = nn.Linear(10, 1) 1107*da0073e9SAndroid Build Coastguard Worker 1108*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 1109*da0073e9SAndroid Build Coastguard Worker def foo(self, x): 1110*da0073e9SAndroid Build Coastguard Worker return self.lin(x) 1111*da0073e9SAndroid Build Coastguard Worker 
1112*da0073e9SAndroid Build Coastguard Worker m = FreezeMe() 1113*da0073e9SAndroid Build Coastguard Worker m_s = torch.jit.script(m) 1114*da0073e9SAndroid Build Coastguard Worker m_s.eval() 1115*da0073e9SAndroid Build Coastguard Worker m_f = torch._C._freeze_module(m_s._c, preservedAttrs=["foo"]) 1116*da0073e9SAndroid Build Coastguard Worker input = torch.ones(10) 1117*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m_s.foo(input), m_f.foo(input)) 1118*da0073e9SAndroid Build Coastguard Worker 1119*da0073e9SAndroid Build Coastguard Worker def test_freeze_no_forward(self): 1120*da0073e9SAndroid Build Coastguard Worker class FreezeMe(nn.Module): 1121*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1122*da0073e9SAndroid Build Coastguard Worker super().__init__() 1123*da0073e9SAndroid Build Coastguard Worker self.lin = nn.Linear(10, 1) 1124*da0073e9SAndroid Build Coastguard Worker 1125*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 1126*da0073e9SAndroid Build Coastguard Worker def foo(self, x): 1127*da0073e9SAndroid Build Coastguard Worker return self.lin(x) 1128*da0073e9SAndroid Build Coastguard Worker 1129*da0073e9SAndroid Build Coastguard Worker m = FreezeMe() 1130*da0073e9SAndroid Build Coastguard Worker m_s = torch.jit.script(m) 1131*da0073e9SAndroid Build Coastguard Worker m_s.eval() 1132*da0073e9SAndroid Build Coastguard Worker m_f = torch.jit.freeze(m_s, preserved_attrs=["foo"]) 1133*da0073e9SAndroid Build Coastguard Worker input = torch.ones(10) 1134*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m_s.foo(input), m_f.foo(input)) 1135*da0073e9SAndroid Build Coastguard Worker 1136*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_in_training_mode(self): 1137*da0073e9SAndroid Build Coastguard Worker class Net(nn.Module): 1138*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1139*da0073e9SAndroid Build Coastguard Worker super().__init__() 1140*da0073e9SAndroid Build Coastguard 
Worker self.conv1 = nn.Conv2d(1, 32, 3, 1) 1141*da0073e9SAndroid Build Coastguard Worker self.conv2 = nn.Conv2d(32, 64, 3, 1) 1142*da0073e9SAndroid Build Coastguard Worker self.dropout1 = nn.Dropout2d(0.25) 1143*da0073e9SAndroid Build Coastguard Worker self.dropout2 = nn.Dropout2d(0.5) 1144*da0073e9SAndroid Build Coastguard Worker self.fc1 = nn.Linear(9216, 128) 1145*da0073e9SAndroid Build Coastguard Worker self.fc2 = nn.Linear(128, 10) 1146*da0073e9SAndroid Build Coastguard Worker 1147*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1148*da0073e9SAndroid Build Coastguard Worker x = self.conv1(x) 1149*da0073e9SAndroid Build Coastguard Worker x = nn.functional.relu(x) 1150*da0073e9SAndroid Build Coastguard Worker x = self.conv2(x) 1151*da0073e9SAndroid Build Coastguard Worker x = nn.functional.max_pool2d(x, 2) 1152*da0073e9SAndroid Build Coastguard Worker x = self.dropout1(x) 1153*da0073e9SAndroid Build Coastguard Worker x = torch.flatten(x, 1) 1154*da0073e9SAndroid Build Coastguard Worker x = self.fc1(x) 1155*da0073e9SAndroid Build Coastguard Worker x = nn.functional.relu(x) 1156*da0073e9SAndroid Build Coastguard Worker x = self.dropout2(x) 1157*da0073e9SAndroid Build Coastguard Worker x = self.fc2(x) 1158*da0073e9SAndroid Build Coastguard Worker output = nn.functional.log_softmax(x, dim=1) 1159*da0073e9SAndroid Build Coastguard Worker return output 1160*da0073e9SAndroid Build Coastguard Worker 1161*da0073e9SAndroid Build Coastguard Worker model = torch.jit.script(Net()) 1162*da0073e9SAndroid Build Coastguard Worker model.train() 1163*da0073e9SAndroid Build Coastguard Worker mTrain_freezed = torch._C._freeze_module(model._c) 1164*da0073e9SAndroid Build Coastguard Worker # verify mTrain_freezed looks exactly as: 1165*da0073e9SAndroid Build Coastguard Worker # module { 1166*da0073e9SAndroid Build Coastguard Worker # attributes { 1167*da0073e9SAndroid Build Coastguard Worker # conv1 = ... 1168*da0073e9SAndroid Build Coastguard Worker # conv2 = ... 
1169*da0073e9SAndroid Build Coastguard Worker # dropout1 = ... 1170*da0073e9SAndroid Build Coastguard Worker # dropout2 = ... 1171*da0073e9SAndroid Build Coastguard Worker # fc1 = ... 1172*da0073e9SAndroid Build Coastguard Worker # fc2 = ... 1173*da0073e9SAndroid Build Coastguard Worker # } 1174*da0073e9SAndroid Build Coastguard Worker # ... 1175*da0073e9SAndroid Build Coastguard Worker # submodules { 1176*da0073e9SAndroid Build Coastguard Worker # module conv1 { 1177*da0073e9SAndroid Build Coastguard Worker # attributes { 1178*da0073e9SAndroid Build Coastguard Worker # weight = ... 1179*da0073e9SAndroid Build Coastguard Worker # bias = ... 1180*da0073e9SAndroid Build Coastguard Worker # } 1181*da0073e9SAndroid Build Coastguard Worker # ... 1182*da0073e9SAndroid Build Coastguard Worker # } 1183*da0073e9SAndroid Build Coastguard Worker # module conv2 { 1184*da0073e9SAndroid Build Coastguard Worker # attributes { 1185*da0073e9SAndroid Build Coastguard Worker # weight = ... 1186*da0073e9SAndroid Build Coastguard Worker # bias = ... 1187*da0073e9SAndroid Build Coastguard Worker # } 1188*da0073e9SAndroid Build Coastguard Worker # ... 1189*da0073e9SAndroid Build Coastguard Worker # } 1190*da0073e9SAndroid Build Coastguard Worker # module dropout1 { 1191*da0073e9SAndroid Build Coastguard Worker # attributes { 1192*da0073e9SAndroid Build Coastguard Worker # training = ... 1193*da0073e9SAndroid Build Coastguard Worker # } 1194*da0073e9SAndroid Build Coastguard Worker # ... 1195*da0073e9SAndroid Build Coastguard Worker # } 1196*da0073e9SAndroid Build Coastguard Worker # module dropout2 { 1197*da0073e9SAndroid Build Coastguard Worker # attributes { 1198*da0073e9SAndroid Build Coastguard Worker # training = ... 1199*da0073e9SAndroid Build Coastguard Worker # } 1200*da0073e9SAndroid Build Coastguard Worker # ... 
1201*da0073e9SAndroid Build Coastguard Worker # } 1202*da0073e9SAndroid Build Coastguard Worker # module fc1 { 1203*da0073e9SAndroid Build Coastguard Worker # attributes { 1204*da0073e9SAndroid Build Coastguard Worker # weight = ... 1205*da0073e9SAndroid Build Coastguard Worker # bias = ... 1206*da0073e9SAndroid Build Coastguard Worker # } 1207*da0073e9SAndroid Build Coastguard Worker # ... 1208*da0073e9SAndroid Build Coastguard Worker # } 1209*da0073e9SAndroid Build Coastguard Worker # module fc2 { 1210*da0073e9SAndroid Build Coastguard Worker # attributes { 1211*da0073e9SAndroid Build Coastguard Worker # weight = ... 1212*da0073e9SAndroid Build Coastguard Worker # bias = ... 1213*da0073e9SAndroid Build Coastguard Worker # } 1214*da0073e9SAndroid Build Coastguard Worker # ... 1215*da0073e9SAndroid Build Coastguard Worker # } 1216*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mTrain_freezed.hasattr("training")) 1217*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mTrain_freezed.hasattr("conv1")) 1218*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mTrain_freezed.conv1.hasattr("training")) 1219*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mTrain_freezed.conv1.hasattr("weight")) 1220*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mTrain_freezed.conv1.hasattr("bias")) 1221*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mTrain_freezed.hasattr("conv2")) 1222*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mTrain_freezed.conv2.hasattr("training")) 1223*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mTrain_freezed.conv2.hasattr("weight")) 1224*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mTrain_freezed.conv2.hasattr("bias")) 1225*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mTrain_freezed.hasattr("dropout1")) 1226*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mTrain_freezed.dropout1.hasattr("training")) 1227*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(mTrain_freezed.hasattr("dropout2")) 1228*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mTrain_freezed.dropout2.hasattr("training")) 1229*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mTrain_freezed.hasattr("fc1")) 1230*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mTrain_freezed.fc1.hasattr("weight")) 1231*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mTrain_freezed.fc1.hasattr("bias")) 1232*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mTrain_freezed.hasattr("fc2")) 1233*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mTrain_freezed.fc2.hasattr("weight")) 1234*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mTrain_freezed.fc2.hasattr("bias")) 1235*da0073e9SAndroid Build Coastguard Worker model.eval() 1236*da0073e9SAndroid Build Coastguard Worker mEval_freezed = torch._C._freeze_module(model._c) 1237*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mEval_freezed.hasattr("conv1")) 1238*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mEval_freezed.hasattr("conv2")) 1239*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mEval_freezed.hasattr("dropout1")) 1240*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mEval_freezed.hasattr("training")) 1241*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mEval_freezed.hasattr("fc1")) 1242*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mEval_freezed.hasattr("dropout2")) 1243*da0073e9SAndroid Build Coastguard Worker self.assertFalse(mEval_freezed.hasattr("fc2")) 1244*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1245*da0073e9SAndroid Build Coastguard Worker AttributeError, "does not have a field with name 'state_dict'" 1246*da0073e9SAndroid Build Coastguard Worker ): 1247*da0073e9SAndroid Build Coastguard Worker print(mEval_freezed.state_dict()) 1248*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 1249*da0073e9SAndroid Build Coastguard Worker 
torch.jit.save(mEval_freezed, buffer) 1250*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 1251*da0073e9SAndroid Build Coastguard Worker m = torch.jit.load(buffer) 1252*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("GetAttr[name=").run(m._c._get_method("forward").graph) 1253*da0073e9SAndroid Build Coastguard Worker m2 = torch._C._freeze_module(model._c, preserveParameters=True) 1254*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m2.hasattr("conv1")) 1255*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m2.hasattr("conv2")) 1256*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2.hasattr("dropout1")) 1257*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2.hasattr("training")) 1258*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m2.hasattr("fc1")) 1259*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2.hasattr("dropout2")) 1260*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m2.hasattr("fc2")) 1261*da0073e9SAndroid Build Coastguard Worker 1262*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_detach_gradient(self): 1263*da0073e9SAndroid Build Coastguard Worker mod = nn.Conv2d(8, 3, 4, 2, 1) 1264*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mod.weight.requires_grad) 1265*da0073e9SAndroid Build Coastguard Worker smod = torch.jit.script(mod) 1266*da0073e9SAndroid Build Coastguard Worker smod.eval() 1267*da0073e9SAndroid Build Coastguard Worker fmod = torch._C._freeze_module(smod._c) 1268*da0073e9SAndroid Build Coastguard Worker self.assertTrue(mod.weight.requires_grad) 1269*da0073e9SAndroid Build Coastguard Worker self.assertTrue(smod.weight.requires_grad) 1270*da0073e9SAndroid Build Coastguard Worker self.assertFalse(fmod.hasattr("weight")) 1271*da0073e9SAndroid Build Coastguard Worker inp = torch.ones(1, 8, 32, 32) 1272*da0073e9SAndroid Build Coastguard Worker out1 = fmod.forward(inp) 1273*da0073e9SAndroid Build Coastguard Worker # FIXME: frozen module 
mutated from outside (original module). 1274*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 1275*da0073e9SAndroid Build Coastguard Worker smod.weight[0, 0, 0, 0] += 100.0 1276*da0073e9SAndroid Build Coastguard Worker out2 = fmod.forward(inp) 1277*da0073e9SAndroid Build Coastguard Worker out3 = smod(inp) 1278*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(out1, out2) 1279*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out2, out3) 1280*da0073e9SAndroid Build Coastguard Worker 1281*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_with_user_preserved_attr(self): 1282*da0073e9SAndroid Build Coastguard Worker class Module(nn.Module): 1283*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1284*da0073e9SAndroid Build Coastguard Worker super().__init__() 1285*da0073e9SAndroid Build Coastguard Worker self.a = torch.tensor([1.1]) 1286*da0073e9SAndroid Build Coastguard Worker self.b = torch.tensor([2.2]) 1287*da0073e9SAndroid Build Coastguard Worker 1288*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1289*da0073e9SAndroid Build Coastguard Worker return self.a + self.b 1290*da0073e9SAndroid Build Coastguard Worker 1291*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(Module()) 1292*da0073e9SAndroid Build Coastguard Worker m.eval() 1293*da0073e9SAndroid Build Coastguard Worker fm = torch._C._freeze_module(m._c, ["a"]) 1294*da0073e9SAndroid Build Coastguard Worker # Attribute "a" is preserved 1295*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fm.hasattr("a")) 1296*da0073e9SAndroid Build Coastguard Worker self.assertFalse(fm.hasattr("b")) 1297*da0073e9SAndroid Build Coastguard Worker 1298*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_with_user_preserved_method(self): 1299*da0073e9SAndroid Build Coastguard Worker class Module(nn.Module): 1300*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1301*da0073e9SAndroid Build Coastguard 
Worker super().__init__() 1302*da0073e9SAndroid Build Coastguard Worker self.a = torch.tensor([1.1]) 1303*da0073e9SAndroid Build Coastguard Worker self.b = torch.tensor([2.2]) 1304*da0073e9SAndroid Build Coastguard Worker 1305*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1306*da0073e9SAndroid Build Coastguard Worker return self.a + self.b 1307*da0073e9SAndroid Build Coastguard Worker 1308*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 1309*da0073e9SAndroid Build Coastguard Worker def modify_a(self, x): 1310*da0073e9SAndroid Build Coastguard Worker self.a[0] += 10 1311*da0073e9SAndroid Build Coastguard Worker return self.b 1312*da0073e9SAndroid Build Coastguard Worker 1313*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 1314*da0073e9SAndroid Build Coastguard Worker def modify_b(self, x): 1315*da0073e9SAndroid Build Coastguard Worker self.b[0] += 20 1316*da0073e9SAndroid Build Coastguard Worker return self.a 1317*da0073e9SAndroid Build Coastguard Worker 1318*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(Module()) 1319*da0073e9SAndroid Build Coastguard Worker m.eval() 1320*da0073e9SAndroid Build Coastguard Worker fm = torch._C._freeze_module(m._c, ["modify_a"]) 1321*da0073e9SAndroid Build Coastguard Worker # Both attribute "a" and method "modify_a" are preserved 1322*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fm.hasattr("a")) 1323*da0073e9SAndroid Build Coastguard Worker self.assertFalse(fm.hasattr("b")) 1324*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 2) 1325*da0073e9SAndroid Build Coastguard Worker expected = m.forward(input) 1326*da0073e9SAndroid Build Coastguard Worker out = fm.forward(input) 1327*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, expected) 1328*da0073e9SAndroid Build Coastguard Worker 1329*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_with_user_preserved_method2(self): 1330*da0073e9SAndroid Build Coastguard Worker class 
Module(nn.Module): 1331*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1332*da0073e9SAndroid Build Coastguard Worker super().__init__() 1333*da0073e9SAndroid Build Coastguard Worker self.a = torch.tensor([1.1]) 1334*da0073e9SAndroid Build Coastguard Worker self.b = torch.tensor([2.2]) 1335*da0073e9SAndroid Build Coastguard Worker 1336*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1337*da0073e9SAndroid Build Coastguard Worker self.b += 10 1338*da0073e9SAndroid Build Coastguard Worker return self.a + self.b 1339*da0073e9SAndroid Build Coastguard Worker 1340*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 1341*da0073e9SAndroid Build Coastguard Worker def modify_a(self, x): 1342*da0073e9SAndroid Build Coastguard Worker self.a[0] += 10 1343*da0073e9SAndroid Build Coastguard Worker return self.b + self.a 1344*da0073e9SAndroid Build Coastguard Worker 1345*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(Module()) 1346*da0073e9SAndroid Build Coastguard Worker m.eval() 1347*da0073e9SAndroid Build Coastguard Worker fm = torch._C._freeze_module(m._c, ["modify_a"]) 1348*da0073e9SAndroid Build Coastguard Worker FileCheck().check('prim::GetAttr[name="a"]').run(fm.forward.graph) 1349*da0073e9SAndroid Build Coastguard Worker FileCheck().check('prim::GetAttr[name="b"]').run(fm.modify_a.graph) 1350*da0073e9SAndroid Build Coastguard Worker 1351*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_with_user_preserved_attribute_on_submodule(self): 1352*da0073e9SAndroid Build Coastguard Worker class SubModule(nn.Module): 1353*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1354*da0073e9SAndroid Build Coastguard Worker super().__init__() 1355*da0073e9SAndroid Build Coastguard Worker self.a = 1 1356*da0073e9SAndroid Build Coastguard Worker self.b = 2 1357*da0073e9SAndroid Build Coastguard Worker 1358*da0073e9SAndroid Build Coastguard Worker def forward(self): 1359*da0073e9SAndroid Build 
Coastguard Worker return self.a + self.b 1360*da0073e9SAndroid Build Coastguard Worker 1361*da0073e9SAndroid Build Coastguard Worker class Module(nn.Module): 1362*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1363*da0073e9SAndroid Build Coastguard Worker super().__init__() 1364*da0073e9SAndroid Build Coastguard Worker self.sub1 = SubModule() 1365*da0073e9SAndroid Build Coastguard Worker self.sub2 = SubModule() 1366*da0073e9SAndroid Build Coastguard Worker 1367*da0073e9SAndroid Build Coastguard Worker def forward(self): 1368*da0073e9SAndroid Build Coastguard Worker return self.sub1() + self.sub2() 1369*da0073e9SAndroid Build Coastguard Worker 1370*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(Module()) 1371*da0073e9SAndroid Build Coastguard Worker m.eval() 1372*da0073e9SAndroid Build Coastguard Worker m = torch.jit.freeze(m, preserved_attrs=["sub1.a", "sub2.a"]) 1373*da0073e9SAndroid Build Coastguard Worker fm = m._c 1374*da0073e9SAndroid Build Coastguard Worker 1375*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fm.hasattr("sub1")) 1376*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fm.sub1.hasattr("a")) 1377*da0073e9SAndroid Build Coastguard Worker self.assertFalse(fm.sub1.hasattr("b")) 1378*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fm.hasattr("sub2")) 1379*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fm.sub2.hasattr("a")) 1380*da0073e9SAndroid Build Coastguard Worker self.assertFalse(fm.sub2.hasattr("b")) 1381*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m(), 6) 1382*da0073e9SAndroid Build Coastguard Worker m.sub1.a += 1 1383*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m(), 7) 1384*da0073e9SAndroid Build Coastguard Worker 1385*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_with_user_preserved_attribute_on_unused_submodule(self): 1386*da0073e9SAndroid Build Coastguard Worker class SubModule(nn.Module): 1387*da0073e9SAndroid Build 
Coastguard Worker def __init__(self) -> None: 1388*da0073e9SAndroid Build Coastguard Worker super().__init__() 1389*da0073e9SAndroid Build Coastguard Worker self.a = 1 1390*da0073e9SAndroid Build Coastguard Worker self.b = 2 1391*da0073e9SAndroid Build Coastguard Worker 1392*da0073e9SAndroid Build Coastguard Worker def forward(self): 1393*da0073e9SAndroid Build Coastguard Worker return self.a + self.b 1394*da0073e9SAndroid Build Coastguard Worker 1395*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 1396*da0073e9SAndroid Build Coastguard Worker def method_a(self): 1397*da0073e9SAndroid Build Coastguard Worker return 42 1398*da0073e9SAndroid Build Coastguard Worker 1399*da0073e9SAndroid Build Coastguard Worker class Module(nn.Module): 1400*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1401*da0073e9SAndroid Build Coastguard Worker super().__init__() 1402*da0073e9SAndroid Build Coastguard Worker self.sub = SubModule() 1403*da0073e9SAndroid Build Coastguard Worker 1404*da0073e9SAndroid Build Coastguard Worker def forward(self): 1405*da0073e9SAndroid Build Coastguard Worker return 1 1406*da0073e9SAndroid Build Coastguard Worker 1407*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(Module()) 1408*da0073e9SAndroid Build Coastguard Worker m.eval() 1409*da0073e9SAndroid Build Coastguard Worker fm = torch.jit.freeze(m, preserved_attrs=["sub.a", "sub.method_a"])._c 1410*da0073e9SAndroid Build Coastguard Worker 1411*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fm.hasattr("sub")) 1412*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fm.sub.hasattr("a")) 1413*da0073e9SAndroid Build Coastguard Worker self.assertFalse(fm.sub.hasattr("b")) 1414*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fm.sub._has_method("method_a")) 1415*da0073e9SAndroid Build Coastguard Worker 1416*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_with_user_preserved_method_on_submodule(self): 1417*da0073e9SAndroid 
Build Coastguard Worker class SubModule(nn.Module): 1418*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1419*da0073e9SAndroid Build Coastguard Worker return self.method_a(x) + self.method_b(x) 1420*da0073e9SAndroid Build Coastguard Worker 1421*da0073e9SAndroid Build Coastguard Worker def method_a(self, x): 1422*da0073e9SAndroid Build Coastguard Worker return x * x 1423*da0073e9SAndroid Build Coastguard Worker 1424*da0073e9SAndroid Build Coastguard Worker def method_b(self, x): 1425*da0073e9SAndroid Build Coastguard Worker return x + x 1426*da0073e9SAndroid Build Coastguard Worker 1427*da0073e9SAndroid Build Coastguard Worker class Module(nn.Module): 1428*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1429*da0073e9SAndroid Build Coastguard Worker super().__init__() 1430*da0073e9SAndroid Build Coastguard Worker self.sub = SubModule() 1431*da0073e9SAndroid Build Coastguard Worker 1432*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1433*da0073e9SAndroid Build Coastguard Worker return self.sub(x) 1434*da0073e9SAndroid Build Coastguard Worker 1435*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(Module()) 1436*da0073e9SAndroid Build Coastguard Worker m.eval() 1437*da0073e9SAndroid Build Coastguard Worker fm = torch.jit.freeze(m, preserved_attrs=["sub.method_a"])._c 1438*da0073e9SAndroid Build Coastguard Worker 1439*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fm.hasattr("sub")) 1440*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fm.sub._has_method("method_a")) 1441*da0073e9SAndroid Build Coastguard Worker self.assertFalse(fm.sub._has_method("method_b")) 1442*da0073e9SAndroid Build Coastguard Worker 1443*da0073e9SAndroid Build Coastguard Worker @skipIfNoFBGEMM 1444*da0073e9SAndroid Build Coastguard Worker def test_module_with_shared_type_instances(self): 1445*da0073e9SAndroid Build Coastguard Worker class Child(nn.Module): 1446*da0073e9SAndroid Build Coastguard Worker def 
__init__(self) -> None: 1447*da0073e9SAndroid Build Coastguard Worker super().__init__() 1448*da0073e9SAndroid Build Coastguard Worker self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32) 1449*da0073e9SAndroid Build Coastguard Worker 1450*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1451*da0073e9SAndroid Build Coastguard Worker x = self.conv1(x) 1452*da0073e9SAndroid Build Coastguard Worker return x 1453*da0073e9SAndroid Build Coastguard Worker 1454*da0073e9SAndroid Build Coastguard Worker class Parent(nn.Module): 1455*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1456*da0073e9SAndroid Build Coastguard Worker super().__init__() 1457*da0073e9SAndroid Build Coastguard Worker self.quant = torch.ao.quantization.QuantStub() 1458*da0073e9SAndroid Build Coastguard Worker self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32) 1459*da0073e9SAndroid Build Coastguard Worker self.child = Child() 1460*da0073e9SAndroid Build Coastguard Worker self.child2 = Child() 1461*da0073e9SAndroid Build Coastguard Worker self.dequant = torch.ao.quantization.DeQuantStub() 1462*da0073e9SAndroid Build Coastguard Worker 1463*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1464*da0073e9SAndroid Build Coastguard Worker x = self.quant(x) 1465*da0073e9SAndroid Build Coastguard Worker x = self.conv1(x) 1466*da0073e9SAndroid Build Coastguard Worker x = self.child(x) 1467*da0073e9SAndroid Build Coastguard Worker x = self.child2(x) 1468*da0073e9SAndroid Build Coastguard Worker x = self.dequant(x) 1469*da0073e9SAndroid Build Coastguard Worker return x 1470*da0073e9SAndroid Build Coastguard Worker 1471*da0073e9SAndroid Build Coastguard Worker def _static_quant(model): 1472*da0073e9SAndroid Build Coastguard Worker qModel = torch.ao.quantization.QuantWrapper(model) 1473*da0073e9SAndroid Build Coastguard Worker qModel.qconfig = torch.ao.quantization.default_qconfig 1474*da0073e9SAndroid Build Coastguard Worker torch.ao.quantization.prepare(qModel, 
inplace=True) 1475*da0073e9SAndroid Build Coastguard Worker qModel(torch.rand(4, 1, 4, 4, dtype=torch.float32)) 1476*da0073e9SAndroid Build Coastguard Worker torch.ao.quantization.convert(qModel, inplace=True) 1477*da0073e9SAndroid Build Coastguard Worker return model 1478*da0073e9SAndroid Build Coastguard Worker 1479*da0073e9SAndroid Build Coastguard Worker with override_quantized_engine("fbgemm"): 1480*da0073e9SAndroid Build Coastguard Worker data = torch.randn(4, 1, 4, 4, dtype=torch.float32) 1481*da0073e9SAndroid Build Coastguard Worker m = Parent().to(torch.float32) 1482*da0073e9SAndroid Build Coastguard Worker m = _static_quant(m) 1483*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(m) 1484*da0073e9SAndroid Build Coastguard Worker m.eval() 1485*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_inline(m.graph) 1486*da0073e9SAndroid Build Coastguard Worker m_frozen = wrap_cpp_module(torch._C._freeze_module(m._c)) 1487*da0073e9SAndroid Build Coastguard Worker # Earlier bug resulted in _packed_params set to false. 1488*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("_packed_params = False").run( 1489*da0073e9SAndroid Build Coastguard Worker m_frozen._c.dump_to_str(True, True, False) 1490*da0073e9SAndroid Build Coastguard Worker ) 1491*da0073e9SAndroid Build Coastguard Worker 1492*da0073e9SAndroid Build Coastguard Worker m_res = m(data) 1493*da0073e9SAndroid Build Coastguard Worker # It used to segfault while running frozen module. 
1494*da0073e9SAndroid Build Coastguard Worker m_frozen_res = m_frozen(data) 1495*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m_res, m_frozen_res) 1496*da0073e9SAndroid Build Coastguard Worker 1497*da0073e9SAndroid Build Coastguard Worker def test_module_getattr_indirection(self): 1498*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 1499*da0073e9SAndroid Build Coastguard Worker class ValHolder: 1500*da0073e9SAndroid Build Coastguard Worker def __init__(self, val: int): 1501*da0073e9SAndroid Build Coastguard Worker self.val: int = val 1502*da0073e9SAndroid Build Coastguard Worker 1503*da0073e9SAndroid Build Coastguard Worker class Mod(nn.Module): 1504*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1505*da0073e9SAndroid Build Coastguard Worker super().__init__() 1506*da0073e9SAndroid Build Coastguard Worker self.mod1 = ValHolder(1) 1507*da0073e9SAndroid Build Coastguard Worker self.mod2 = ValHolder(2) 1508*da0073e9SAndroid Build Coastguard Worker 1509*da0073e9SAndroid Build Coastguard Worker def forward(self, cond: bool): 1510*da0073e9SAndroid Build Coastguard Worker if cond: 1511*da0073e9SAndroid Build Coastguard Worker mod = self.mod1 1512*da0073e9SAndroid Build Coastguard Worker else: 1513*da0073e9SAndroid Build Coastguard Worker mod = self.mod2 1514*da0073e9SAndroid Build Coastguard Worker return mod.val 1515*da0073e9SAndroid Build Coastguard Worker 1516*da0073e9SAndroid Build Coastguard Worker mod = Mod() 1517*da0073e9SAndroid Build Coastguard Worker mod.eval() 1518*da0073e9SAndroid Build Coastguard Worker frozen_mod = torch.jit.freeze(torch.jit.script(mod)) 1519*da0073e9SAndroid Build Coastguard Worker mod_eager = Mod() 1520*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod_eager(True), frozen_mod(True)) 1521*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod_eager(False), frozen_mod(False)) 1522*da0073e9SAndroid Build Coastguard Worker 1523*da0073e9SAndroid Build Coastguard Worker def 
test_freeze_module_with_non_static_module_container_index(self): 1524*da0073e9SAndroid Build Coastguard Worker """ 1525*da0073e9SAndroid Build Coastguard Worker Test that Modules containing non-static ModuleDict or ModuleList 1526*da0073e9SAndroid Build Coastguard Worker indexing cannot be frozen. 1527*da0073e9SAndroid Build Coastguard Worker """ 1528*da0073e9SAndroid Build Coastguard Worker 1529*da0073e9SAndroid Build Coastguard Worker @torch.jit.interface 1530*da0073e9SAndroid Build Coastguard Worker class ModuleInterface(torch.nn.Module): 1531*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: Any) -> Any: 1532*da0073e9SAndroid Build Coastguard Worker pass 1533*da0073e9SAndroid Build Coastguard Worker 1534*da0073e9SAndroid Build Coastguard Worker class ImplementsInterface(torch.nn.Module): 1535*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: Any) -> Any: 1536*da0073e9SAndroid Build Coastguard Worker if isinstance(inp, torch.Tensor): 1537*da0073e9SAndroid Build Coastguard Worker return torch.max(inp, dim=0) 1538*da0073e9SAndroid Build Coastguard Worker 1539*da0073e9SAndroid Build Coastguard Worker return inp 1540*da0073e9SAndroid Build Coastguard Worker 1541*da0073e9SAndroid Build Coastguard Worker class ModWithDict(torch.nn.Module): 1542*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1543*da0073e9SAndroid Build Coastguard Worker super().__init__() 1544*da0073e9SAndroid Build Coastguard Worker self.d = torch.nn.ModuleDict({"module": ImplementsInterface()}) 1545*da0073e9SAndroid Build Coastguard Worker 1546*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor, key: str) -> Any: 1547*da0073e9SAndroid Build Coastguard Worker value: ModuleInterface = self.d[key] 1548*da0073e9SAndroid Build Coastguard Worker return value.forward(x) 1549*da0073e9SAndroid Build Coastguard Worker 1550*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(ModWithDict()) 1551*da0073e9SAndroid Build Coastguard 
Worker m.eval() 1552*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1553*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1554*da0073e9SAndroid Build Coastguard Worker "Freezing modules containing prim::ModuleContainerIndex is not supported", 1555*da0073e9SAndroid Build Coastguard Worker ): 1556*da0073e9SAndroid Build Coastguard Worker mf = torch._C._freeze_module(m._c) 1557*da0073e9SAndroid Build Coastguard Worker 1558*da0073e9SAndroid Build Coastguard Worker class ModWithList(torch.nn.Module): 1559*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1560*da0073e9SAndroid Build Coastguard Worker super().__init__() 1561*da0073e9SAndroid Build Coastguard Worker self.l = torch.nn.ModuleList([ImplementsInterface()]) 1562*da0073e9SAndroid Build Coastguard Worker 1563*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor, idx: int) -> Any: 1564*da0073e9SAndroid Build Coastguard Worker value: ModuleInterface = self.l[idx] 1565*da0073e9SAndroid Build Coastguard Worker return value.forward(x) 1566*da0073e9SAndroid Build Coastguard Worker 1567*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(ModWithList()) 1568*da0073e9SAndroid Build Coastguard Worker m.eval() 1569*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1570*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1571*da0073e9SAndroid Build Coastguard Worker "Freezing modules containing prim::ModuleContainerIndex is not supported", 1572*da0073e9SAndroid Build Coastguard Worker ): 1573*da0073e9SAndroid Build Coastguard Worker mf = torch._C._freeze_module(m._c) 1574*da0073e9SAndroid Build Coastguard Worker 1575*da0073e9SAndroid Build Coastguard Worker def test_freeze_with_interface_mutable(self): 1576*da0073e9SAndroid Build Coastguard Worker @torch.jit.interface 1577*da0073e9SAndroid Build Coastguard Worker class ModuleInterface(torch.nn.Module): 1578*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: 
torch.Tensor) -> torch.Tensor: 1579*da0073e9SAndroid Build Coastguard Worker pass 1580*da0073e9SAndroid Build Coastguard Worker 1581*da0073e9SAndroid Build Coastguard Worker class ImplementsInterface(torch.nn.Module): 1582*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1583*da0073e9SAndroid Build Coastguard Worker super().__init__() 1584*da0073e9SAndroid Build Coastguard Worker self.sum = torch.zeros((2, 2)) 1585*da0073e9SAndroid Build Coastguard Worker 1586*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: torch.Tensor) -> torch.Tensor: 1587*da0073e9SAndroid Build Coastguard Worker self.sum += inp.relu() 1588*da0073e9SAndroid Build Coastguard Worker return self.sum 1589*da0073e9SAndroid Build Coastguard Worker 1590*da0073e9SAndroid Build Coastguard Worker class WrapperModule(torch.nn.Module): 1591*da0073e9SAndroid Build Coastguard Worker impl: ModuleInterface 1592*da0073e9SAndroid Build Coastguard Worker 1593*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1594*da0073e9SAndroid Build Coastguard Worker super().__init__() 1595*da0073e9SAndroid Build Coastguard Worker self.impl = ImplementsInterface() 1596*da0073e9SAndroid Build Coastguard Worker 1597*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 1598*da0073e9SAndroid Build Coastguard Worker return self.impl.forward(x) 1599*da0073e9SAndroid Build Coastguard Worker 1600*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(WrapperModule()) 1601*da0073e9SAndroid Build Coastguard Worker m.eval() 1602*da0073e9SAndroid Build Coastguard Worker m_frozen = torch.jit.freeze(m) 1603*da0073e9SAndroid Build Coastguard Worker 1604*da0073e9SAndroid Build Coastguard Worker x = torch.rand((2, 2)) 1605*da0073e9SAndroid Build Coastguard Worker 1606*da0073e9SAndroid Build Coastguard Worker m_frozen(x) 1607*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m_frozen.impl.sum, x.relu()) 1608*da0073e9SAndroid Build 
Coastguard Worker 1609*da0073e9SAndroid Build Coastguard Worker def test_freeze_with_swapping_interfaces(self): 1610*da0073e9SAndroid Build Coastguard Worker @torch.jit.interface 1611*da0073e9SAndroid Build Coastguard Worker class ModuleInterface(torch.nn.Module): 1612*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: torch.Tensor) -> torch.Tensor: 1613*da0073e9SAndroid Build Coastguard Worker pass 1614*da0073e9SAndroid Build Coastguard Worker 1615*da0073e9SAndroid Build Coastguard Worker class Implementation1(torch.nn.Module): 1616*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: torch.Tensor) -> torch.Tensor: 1617*da0073e9SAndroid Build Coastguard Worker return inp.relu() 1618*da0073e9SAndroid Build Coastguard Worker 1619*da0073e9SAndroid Build Coastguard Worker class Implementation2(torch.nn.Module): 1620*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: torch.Tensor) -> torch.Tensor: 1621*da0073e9SAndroid Build Coastguard Worker return inp.sin() 1622*da0073e9SAndroid Build Coastguard Worker 1623*da0073e9SAndroid Build Coastguard Worker class WrapperModule(torch.nn.Module): 1624*da0073e9SAndroid Build Coastguard Worker impl: ModuleInterface 1625*da0073e9SAndroid Build Coastguard Worker 1626*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1627*da0073e9SAndroid Build Coastguard Worker super().__init__() 1628*da0073e9SAndroid Build Coastguard Worker self.option1 = Implementation1() 1629*da0073e9SAndroid Build Coastguard Worker self.option2 = Implementation2() 1630*da0073e9SAndroid Build Coastguard Worker self.impl = self.option1 1631*da0073e9SAndroid Build Coastguard Worker self.idx = 0 1632*da0073e9SAndroid Build Coastguard Worker 1633*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 1634*da0073e9SAndroid Build Coastguard Worker self.idx += 1 1635*da0073e9SAndroid Build Coastguard Worker if self.idx % 2 == 1: 1636*da0073e9SAndroid Build Coastguard Worker 
self.impl = self.option1 1637*da0073e9SAndroid Build Coastguard Worker else: 1638*da0073e9SAndroid Build Coastguard Worker self.impl = self.option2 1639*da0073e9SAndroid Build Coastguard Worker return self.impl(x) 1640*da0073e9SAndroid Build Coastguard Worker 1641*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(WrapperModule()) 1642*da0073e9SAndroid Build Coastguard Worker m.eval() 1643*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1644*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Freezing does not support SetAttr on an interface type" 1645*da0073e9SAndroid Build Coastguard Worker ): 1646*da0073e9SAndroid Build Coastguard Worker m_frozen = torch.jit.freeze(m) 1647*da0073e9SAndroid Build Coastguard Worker 1648*da0073e9SAndroid Build Coastguard Worker def test_freeze_recursive_interfaces(self): 1649*da0073e9SAndroid Build Coastguard Worker @torch.jit.interface 1650*da0073e9SAndroid Build Coastguard Worker class InnerInterface(torch.nn.Module): 1651*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: torch.Tensor) -> torch.Tensor: 1652*da0073e9SAndroid Build Coastguard Worker pass 1653*da0073e9SAndroid Build Coastguard Worker 1654*da0073e9SAndroid Build Coastguard Worker @torch.jit.interface 1655*da0073e9SAndroid Build Coastguard Worker class OuterInterface(torch.nn.Module): 1656*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: torch.Tensor) -> torch.Tensor: 1657*da0073e9SAndroid Build Coastguard Worker pass 1658*da0073e9SAndroid Build Coastguard Worker 1659*da0073e9SAndroid Build Coastguard Worker class InnerImpl(torch.nn.Module): 1660*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1661*da0073e9SAndroid Build Coastguard Worker super().__init__() 1662*da0073e9SAndroid Build Coastguard Worker self.x = torch.ones((2, 2)) 1663*da0073e9SAndroid Build Coastguard Worker 1664*da0073e9SAndroid Build Coastguard Worker def forward(self, inp): 1665*da0073e9SAndroid Build Coastguard 
Worker return inp.cos() * self.x 1666*da0073e9SAndroid Build Coastguard Worker 1667*da0073e9SAndroid Build Coastguard Worker class OuterImpl(torch.nn.Module): 1668*da0073e9SAndroid Build Coastguard Worker inner_impl: InnerInterface 1669*da0073e9SAndroid Build Coastguard Worker 1670*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1671*da0073e9SAndroid Build Coastguard Worker super().__init__() 1672*da0073e9SAndroid Build Coastguard Worker self.inner_impl = InnerImpl() 1673*da0073e9SAndroid Build Coastguard Worker 1674*da0073e9SAndroid Build Coastguard Worker def forward(self, inp): 1675*da0073e9SAndroid Build Coastguard Worker return inp.relu() + self.inner_impl(inp.sin()) 1676*da0073e9SAndroid Build Coastguard Worker 1677*da0073e9SAndroid Build Coastguard Worker class WrapperModule(torch.nn.Module): 1678*da0073e9SAndroid Build Coastguard Worker outer_impl: OuterInterface 1679*da0073e9SAndroid Build Coastguard Worker 1680*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1681*da0073e9SAndroid Build Coastguard Worker super().__init__() 1682*da0073e9SAndroid Build Coastguard Worker self.outer_impl = OuterImpl() 1683*da0073e9SAndroid Build Coastguard Worker 1684*da0073e9SAndroid Build Coastguard Worker def forward(self, inp): 1685*da0073e9SAndroid Build Coastguard Worker return self.outer_impl(inp) + inp 1686*da0073e9SAndroid Build Coastguard Worker 1687*da0073e9SAndroid Build Coastguard Worker m = WrapperModule() 1688*da0073e9SAndroid Build Coastguard Worker x = torch.rand((2, 2)) 1689*da0073e9SAndroid Build Coastguard Worker expected = m(x) 1690*da0073e9SAndroid Build Coastguard Worker 1691*da0073e9SAndroid Build Coastguard Worker m_s = torch.jit.script(m) 1692*da0073e9SAndroid Build Coastguard Worker m_s.eval() 1693*da0073e9SAndroid Build Coastguard Worker m_s = torch.jit.freeze(m_s) 1694*da0073e9SAndroid Build Coastguard Worker actual = m_s(x) 1695*da0073e9SAndroid Build Coastguard Worker 1696*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(expected, actual) 1697*da0073e9SAndroid Build Coastguard Worker 1698*da0073e9SAndroid Build Coastguard Worker def test_freeze_recursive_interfaces_with_reassignment(self): 1699*da0073e9SAndroid Build Coastguard Worker @torch.jit.interface 1700*da0073e9SAndroid Build Coastguard Worker class InnerInterface(torch.nn.Module): 1701*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: torch.Tensor) -> torch.Tensor: 1702*da0073e9SAndroid Build Coastguard Worker pass 1703*da0073e9SAndroid Build Coastguard Worker 1704*da0073e9SAndroid Build Coastguard Worker @torch.jit.interface 1705*da0073e9SAndroid Build Coastguard Worker class OuterInterface(torch.nn.Module): 1706*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: torch.Tensor) -> torch.Tensor: 1707*da0073e9SAndroid Build Coastguard Worker pass 1708*da0073e9SAndroid Build Coastguard Worker 1709*da0073e9SAndroid Build Coastguard Worker class InnerImpl1(torch.nn.Module): 1710*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1711*da0073e9SAndroid Build Coastguard Worker super().__init__() 1712*da0073e9SAndroid Build Coastguard Worker self.x = torch.ones((2, 2)) 1713*da0073e9SAndroid Build Coastguard Worker 1714*da0073e9SAndroid Build Coastguard Worker def forward(self, inp): 1715*da0073e9SAndroid Build Coastguard Worker return inp.cos() * self.x 1716*da0073e9SAndroid Build Coastguard Worker 1717*da0073e9SAndroid Build Coastguard Worker class InnerImpl2(torch.nn.Module): 1718*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1719*da0073e9SAndroid Build Coastguard Worker super().__init__() 1720*da0073e9SAndroid Build Coastguard Worker self.x = torch.ones((2, 2)) * 2 1721*da0073e9SAndroid Build Coastguard Worker 1722*da0073e9SAndroid Build Coastguard Worker def forward(self, inp): 1723*da0073e9SAndroid Build Coastguard Worker return inp.sin() / self.x 1724*da0073e9SAndroid Build Coastguard Worker 1725*da0073e9SAndroid Build 
Coastguard Worker class OuterImpl(torch.nn.Module): 1726*da0073e9SAndroid Build Coastguard Worker inner_impl: InnerInterface 1727*da0073e9SAndroid Build Coastguard Worker 1728*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1729*da0073e9SAndroid Build Coastguard Worker super().__init__() 1730*da0073e9SAndroid Build Coastguard Worker self.inner_impl = InnerImpl1() 1731*da0073e9SAndroid Build Coastguard Worker self.impl1 = InnerImpl1() 1732*da0073e9SAndroid Build Coastguard Worker self.impl2 = InnerImpl1() 1733*da0073e9SAndroid Build Coastguard Worker self.idx = 0 1734*da0073e9SAndroid Build Coastguard Worker 1735*da0073e9SAndroid Build Coastguard Worker def forward(self, inp): 1736*da0073e9SAndroid Build Coastguard Worker self.idx += 1 1737*da0073e9SAndroid Build Coastguard Worker if self.idx % 2 == 0: 1738*da0073e9SAndroid Build Coastguard Worker self.inner_impl = self.impl1 1739*da0073e9SAndroid Build Coastguard Worker else: 1740*da0073e9SAndroid Build Coastguard Worker self.inner_impl = self.impl2 1741*da0073e9SAndroid Build Coastguard Worker return inp.relu() + self.inner_impl(inp.sin()) 1742*da0073e9SAndroid Build Coastguard Worker 1743*da0073e9SAndroid Build Coastguard Worker class WrapperModule(torch.nn.Module): 1744*da0073e9SAndroid Build Coastguard Worker outer_impl: OuterInterface 1745*da0073e9SAndroid Build Coastguard Worker 1746*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1747*da0073e9SAndroid Build Coastguard Worker super().__init__() 1748*da0073e9SAndroid Build Coastguard Worker self.outer_impl = OuterImpl() 1749*da0073e9SAndroid Build Coastguard Worker 1750*da0073e9SAndroid Build Coastguard Worker def forward(self, inp): 1751*da0073e9SAndroid Build Coastguard Worker return self.outer_impl(inp) + inp 1752*da0073e9SAndroid Build Coastguard Worker 1753*da0073e9SAndroid Build Coastguard Worker m = WrapperModule() 1754*da0073e9SAndroid Build Coastguard Worker 1755*da0073e9SAndroid Build Coastguard Worker m_s = 
torch.jit.script(m) 1756*da0073e9SAndroid Build Coastguard Worker m_s.eval() 1757*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1758*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Freezing does not support SetAttr on an interface type" 1759*da0073e9SAndroid Build Coastguard Worker ): 1760*da0073e9SAndroid Build Coastguard Worker m_s = torch.jit.freeze(m_s) 1761*da0073e9SAndroid Build Coastguard Worker 1762*da0073e9SAndroid Build Coastguard Worker def test_freeze_interface_swapping_two_methods(self): 1763*da0073e9SAndroid Build Coastguard Worker @torch.jit.interface 1764*da0073e9SAndroid Build Coastguard Worker class MyInterface(torch.nn.Module): 1765*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: torch.Tensor) -> torch.Tensor: 1766*da0073e9SAndroid Build Coastguard Worker pass 1767*da0073e9SAndroid Build Coastguard Worker 1768*da0073e9SAndroid Build Coastguard Worker class Impl1(torch.nn.Module): 1769*da0073e9SAndroid Build Coastguard Worker def forward(self, inp): 1770*da0073e9SAndroid Build Coastguard Worker return inp.cos() 1771*da0073e9SAndroid Build Coastguard Worker 1772*da0073e9SAndroid Build Coastguard Worker class Impl2(torch.nn.Module): 1773*da0073e9SAndroid Build Coastguard Worker def forward(self, inp): 1774*da0073e9SAndroid Build Coastguard Worker return inp.sin() 1775*da0073e9SAndroid Build Coastguard Worker 1776*da0073e9SAndroid Build Coastguard Worker class WrapperModule1(torch.nn.Module): 1777*da0073e9SAndroid Build Coastguard Worker interface_impl: MyInterface 1778*da0073e9SAndroid Build Coastguard Worker 1779*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1780*da0073e9SAndroid Build Coastguard Worker super().__init__() 1781*da0073e9SAndroid Build Coastguard Worker self.interface_impl = Impl1() 1782*da0073e9SAndroid Build Coastguard Worker self.impl1 = Impl1() 1783*da0073e9SAndroid Build Coastguard Worker self.impl2 = Impl2() 1784*da0073e9SAndroid Build Coastguard Worker 
self.idx = 0 1785*da0073e9SAndroid Build Coastguard Worker 1786*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1787*da0073e9SAndroid Build Coastguard Worker return self.interface_impl(x) 1788*da0073e9SAndroid Build Coastguard Worker 1789*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 1790*da0073e9SAndroid Build Coastguard Worker def other_method(self, x): 1791*da0073e9SAndroid Build Coastguard Worker self.idx += 1 1792*da0073e9SAndroid Build Coastguard Worker if self.idx % 2 == 0: 1793*da0073e9SAndroid Build Coastguard Worker self.interface_impl = self.impl1 1794*da0073e9SAndroid Build Coastguard Worker else: 1795*da0073e9SAndroid Build Coastguard Worker self.interface_impl = self.impl2 1796*da0073e9SAndroid Build Coastguard Worker return self.interface_impl(x) 1797*da0073e9SAndroid Build Coastguard Worker 1798*da0073e9SAndroid Build Coastguard Worker class WrapperModule2(torch.nn.Module): 1799*da0073e9SAndroid Build Coastguard Worker interface_impl: MyInterface 1800*da0073e9SAndroid Build Coastguard Worker 1801*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1802*da0073e9SAndroid Build Coastguard Worker super().__init__() 1803*da0073e9SAndroid Build Coastguard Worker self.interface_impl = Impl1() 1804*da0073e9SAndroid Build Coastguard Worker self.impl1 = Impl1() 1805*da0073e9SAndroid Build Coastguard Worker self.impl2 = Impl2() 1806*da0073e9SAndroid Build Coastguard Worker self.idx = 0 1807*da0073e9SAndroid Build Coastguard Worker 1808*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1809*da0073e9SAndroid Build Coastguard Worker self.idx += 1 1810*da0073e9SAndroid Build Coastguard Worker if self.idx % 2 == 0: 1811*da0073e9SAndroid Build Coastguard Worker self.interface_impl = self.impl1 1812*da0073e9SAndroid Build Coastguard Worker else: 1813*da0073e9SAndroid Build Coastguard Worker self.interface_impl = self.impl2 1814*da0073e9SAndroid Build Coastguard Worker return self.interface_impl(x) 
1815*da0073e9SAndroid Build Coastguard Worker 1816*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 1817*da0073e9SAndroid Build Coastguard Worker def other_method(self, x): 1818*da0073e9SAndroid Build Coastguard Worker return self.interface_impl(x) 1819*da0073e9SAndroid Build Coastguard Worker 1820*da0073e9SAndroid Build Coastguard Worker m1 = torch.jit.script(WrapperModule1()) 1821*da0073e9SAndroid Build Coastguard Worker m2 = torch.jit.script(WrapperModule2()) 1822*da0073e9SAndroid Build Coastguard Worker 1823*da0073e9SAndroid Build Coastguard Worker m1.eval() 1824*da0073e9SAndroid Build Coastguard Worker m2.eval() 1825*da0073e9SAndroid Build Coastguard Worker 1826*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1827*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Freezing does not support SetAttr on an interface type" 1828*da0073e9SAndroid Build Coastguard Worker ): 1829*da0073e9SAndroid Build Coastguard Worker torch.jit.freeze(m1, preserved_attrs=["other_method"]) 1830*da0073e9SAndroid Build Coastguard Worker 1831*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1832*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Freezing does not support SetAttr on an interface type" 1833*da0073e9SAndroid Build Coastguard Worker ): 1834*da0073e9SAndroid Build Coastguard Worker torch.jit.freeze(m2, preserved_attrs=["other_method"]) 1835*da0073e9SAndroid Build Coastguard Worker 1836*da0073e9SAndroid Build Coastguard Worker def test_freeze_recursive_interfaces_same_name(self): 1837*da0073e9SAndroid Build Coastguard Worker @torch.jit.interface 1838*da0073e9SAndroid Build Coastguard Worker class InnerInterface(torch.nn.Module): 1839*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: torch.Tensor) -> torch.Tensor: 1840*da0073e9SAndroid Build Coastguard Worker pass 1841*da0073e9SAndroid Build Coastguard Worker 1842*da0073e9SAndroid Build Coastguard Worker @torch.jit.interface 1843*da0073e9SAndroid 
Build Coastguard Worker class OuterInterface(torch.nn.Module): 1844*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: torch.Tensor) -> torch.Tensor: 1845*da0073e9SAndroid Build Coastguard Worker pass 1846*da0073e9SAndroid Build Coastguard Worker 1847*da0073e9SAndroid Build Coastguard Worker class InnerImpl(torch.nn.Module): 1848*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1849*da0073e9SAndroid Build Coastguard Worker super().__init__() 1850*da0073e9SAndroid Build Coastguard Worker self.x = torch.ones((2, 2)) 1851*da0073e9SAndroid Build Coastguard Worker 1852*da0073e9SAndroid Build Coastguard Worker def forward(self, inp): 1853*da0073e9SAndroid Build Coastguard Worker return inp.cos() * self.x 1854*da0073e9SAndroid Build Coastguard Worker 1855*da0073e9SAndroid Build Coastguard Worker class OuterImpl(torch.nn.Module): 1856*da0073e9SAndroid Build Coastguard Worker impl: InnerInterface 1857*da0073e9SAndroid Build Coastguard Worker 1858*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1859*da0073e9SAndroid Build Coastguard Worker super().__init__() 1860*da0073e9SAndroid Build Coastguard Worker self.impl = InnerImpl() 1861*da0073e9SAndroid Build Coastguard Worker self.x = torch.ones((2, 2)) * 5 1862*da0073e9SAndroid Build Coastguard Worker 1863*da0073e9SAndroid Build Coastguard Worker def forward(self, inp): 1864*da0073e9SAndroid Build Coastguard Worker return self.other_method(inp) 1865*da0073e9SAndroid Build Coastguard Worker 1866*da0073e9SAndroid Build Coastguard Worker def other_method(self, inp): 1867*da0073e9SAndroid Build Coastguard Worker return inp.relu() + self.impl(inp.sin()) + self.x 1868*da0073e9SAndroid Build Coastguard Worker 1869*da0073e9SAndroid Build Coastguard Worker class WrapperModule(torch.nn.Module): 1870*da0073e9SAndroid Build Coastguard Worker impl: OuterInterface 1871*da0073e9SAndroid Build Coastguard Worker 1872*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 
1873*da0073e9SAndroid Build Coastguard Worker super().__init__() 1874*da0073e9SAndroid Build Coastguard Worker self.impl = OuterImpl() 1875*da0073e9SAndroid Build Coastguard Worker 1876*da0073e9SAndroid Build Coastguard Worker def forward(self, inp): 1877*da0073e9SAndroid Build Coastguard Worker return self.impl(inp) + inp 1878*da0073e9SAndroid Build Coastguard Worker 1879*da0073e9SAndroid Build Coastguard Worker m = WrapperModule() 1880*da0073e9SAndroid Build Coastguard Worker x = torch.rand((2, 2)) 1881*da0073e9SAndroid Build Coastguard Worker expected = m(x) 1882*da0073e9SAndroid Build Coastguard Worker 1883*da0073e9SAndroid Build Coastguard Worker m_s = torch.jit.script(m) 1884*da0073e9SAndroid Build Coastguard Worker m_s.eval() 1885*da0073e9SAndroid Build Coastguard Worker m_s = torch.jit.freeze(m_s) 1886*da0073e9SAndroid Build Coastguard Worker actual = m_s(x) 1887*da0073e9SAndroid Build Coastguard Worker 1888*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 1889*da0073e9SAndroid Build Coastguard Worker 1890*da0073e9SAndroid Build Coastguard Worker def test_freeze_non_interface_module_swap(self): 1891*da0073e9SAndroid Build Coastguard Worker class InnerModule(torch.nn.Module): 1892*da0073e9SAndroid Build Coastguard Worker def __init__(self, x): 1893*da0073e9SAndroid Build Coastguard Worker super().__init__() 1894*da0073e9SAndroid Build Coastguard Worker self.x = x 1895*da0073e9SAndroid Build Coastguard Worker 1896*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: torch.Tensor) -> torch.Tensor: 1897*da0073e9SAndroid Build Coastguard Worker return inp.relu() + self.x 1898*da0073e9SAndroid Build Coastguard Worker 1899*da0073e9SAndroid Build Coastguard Worker class WrapperModule(torch.nn.Module): 1900*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1901*da0073e9SAndroid Build Coastguard Worker super().__init__() 1902*da0073e9SAndroid Build Coastguard Worker self.option1 = InnerModule(torch.rand((2, 
2))) 1903*da0073e9SAndroid Build Coastguard Worker self.option2 = InnerModule(torch.rand((2, 2))) 1904*da0073e9SAndroid Build Coastguard Worker self.impl = self.option1 1905*da0073e9SAndroid Build Coastguard Worker self.idx = 0 1906*da0073e9SAndroid Build Coastguard Worker 1907*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 1908*da0073e9SAndroid Build Coastguard Worker self.idx += 1 1909*da0073e9SAndroid Build Coastguard Worker if self.idx % 2 == 1: 1910*da0073e9SAndroid Build Coastguard Worker self.impl = self.option1 1911*da0073e9SAndroid Build Coastguard Worker else: 1912*da0073e9SAndroid Build Coastguard Worker self.impl = self.option2 1913*da0073e9SAndroid Build Coastguard Worker return self.impl(x) 1914*da0073e9SAndroid Build Coastguard Worker 1915*da0073e9SAndroid Build Coastguard Worker unfrozen = WrapperModule() 1916*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(unfrozen) 1917*da0073e9SAndroid Build Coastguard Worker m.eval() 1918*da0073e9SAndroid Build Coastguard Worker m_frozen = torch.jit.freeze(m) 1919*da0073e9SAndroid Build Coastguard Worker 1920*da0073e9SAndroid Build Coastguard Worker x = torch.rand((2, 2)) 1921*da0073e9SAndroid Build Coastguard Worker expected = unfrozen(x) 1922*da0073e9SAndroid Build Coastguard Worker actual = m_frozen(x) 1923*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 1924*da0073e9SAndroid Build Coastguard Worker 1925*da0073e9SAndroid Build Coastguard Worker @unittest.expectedFailure 1926*da0073e9SAndroid Build Coastguard Worker def test_freeze_interface_within_object(self): 1927*da0073e9SAndroid Build Coastguard Worker # I don't think there's any way to create a plain python object that 1928*da0073e9SAndroid Build Coastguard Worker # contains a torch.nn.Module inside it, but just in case... 
I'm not 1929*da0073e9SAndroid Build Coastguard Worker # sure freezing would handle this case correctly, so marking as xfail 1930*da0073e9SAndroid Build Coastguard Worker # so that if this ever _does_ start working someone will need to 1931*da0073e9SAndroid Build Coastguard Worker # investigate to make sure this is handled correctly. 1932*da0073e9SAndroid Build Coastguard Worker class MyIface(torch.nn.Module): 1933*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: torch.Tensor) -> torch.Tensor: 1934*da0073e9SAndroid Build Coastguard Worker pass 1935*da0073e9SAndroid Build Coastguard Worker 1936*da0073e9SAndroid Build Coastguard Worker class MyImpl(torch.nn.Module): 1937*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: torch.Tensor) -> torch.Tensor: 1938*da0073e9SAndroid Build Coastguard Worker return inp.sin() 1939*da0073e9SAndroid Build Coastguard Worker 1940*da0073e9SAndroid Build Coastguard Worker class MyObject: 1941*da0073e9SAndroid Build Coastguard Worker impl: MyIface 1942*da0073e9SAndroid Build Coastguard Worker 1943*da0073e9SAndroid Build Coastguard Worker def run(self, x): 1944*da0073e9SAndroid Build Coastguard Worker return self.impl(x) 1945*da0073e9SAndroid Build Coastguard Worker 1946*da0073e9SAndroid Build Coastguard Worker class WrapperModule(torch.nn.Module): 1947*da0073e9SAndroid Build Coastguard Worker impl: MyObject 1948*da0073e9SAndroid Build Coastguard Worker 1949*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1950*da0073e9SAndroid Build Coastguard Worker super().__init__() 1951*da0073e9SAndroid Build Coastguard Worker self.impl = MyObject() 1952*da0073e9SAndroid Build Coastguard Worker self.impl.impl = MyImpl() 1953*da0073e9SAndroid Build Coastguard Worker 1954*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor) -> torch.Tensor: 1955*da0073e9SAndroid Build Coastguard Worker return self.impl(x) 1956*da0073e9SAndroid Build Coastguard Worker 1957*da0073e9SAndroid Build 
Coastguard Worker unfrozen = WrapperModule() 1958*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(unfrozen) 1959*da0073e9SAndroid Build Coastguard Worker m.eval() 1960*da0073e9SAndroid Build Coastguard Worker m_frozen = torch.jit.freeze(m) 1961*da0073e9SAndroid Build Coastguard Worker 1962*da0073e9SAndroid Build Coastguard Worker x = torch.rand((2, 2)) 1963*da0073e9SAndroid Build Coastguard Worker expected = unfrozen(x) 1964*da0073e9SAndroid Build Coastguard Worker actual = m_frozen(x) 1965*da0073e9SAndroid Build Coastguard Worker self.expectEqual(expected, actual) 1966*da0073e9SAndroid Build Coastguard Worker 1967*da0073e9SAndroid Build Coastguard Worker def test_freeze_non_module_class_getattr(self): 1968*da0073e9SAndroid Build Coastguard Worker class BoxCoder: 1969*da0073e9SAndroid Build Coastguard Worker def __init__(self, bbox_xform_clip): 1970*da0073e9SAndroid Build Coastguard Worker # type: (float) -> None 1971*da0073e9SAndroid Build Coastguard Worker self.bbox_xform_clip = bbox_xform_clip 1972*da0073e9SAndroid Build Coastguard Worker 1973*da0073e9SAndroid Build Coastguard Worker def decode(self, input): 1974*da0073e9SAndroid Build Coastguard Worker return input * self.bbox_xform_clip 1975*da0073e9SAndroid Build Coastguard Worker 1976*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 1977*da0073e9SAndroid Build Coastguard Worker __annotations__ = { 1978*da0073e9SAndroid Build Coastguard Worker "box_coder": BoxCoder, 1979*da0073e9SAndroid Build Coastguard Worker } 1980*da0073e9SAndroid Build Coastguard Worker 1981*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1982*da0073e9SAndroid Build Coastguard Worker super().__init__() 1983*da0073e9SAndroid Build Coastguard Worker self.box_coder = BoxCoder(50.0) 1984*da0073e9SAndroid Build Coastguard Worker 1985*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 1986*da0073e9SAndroid Build Coastguard Worker return self.box_coder.decode(input) 
1987*da0073e9SAndroid Build Coastguard Worker 1988*da0073e9SAndroid Build Coastguard Worker model = MyModule() 1989*da0073e9SAndroid Build Coastguard Worker model.eval() 1990*da0073e9SAndroid Build Coastguard Worker script_model = torch.jit.freeze(torch.jit.script(model)) 1991*da0073e9SAndroid Build Coastguard Worker inp = torch.randn([4, 4]) 1992*da0073e9SAndroid Build Coastguard Worker output_eager = model(inp) 1993*da0073e9SAndroid Build Coastguard Worker self.assertEqual(model(inp), script_model(inp)) 1994*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("GetAttr").run(script_model.graph) 1995*da0073e9SAndroid Build Coastguard Worker 1996*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_with_tupleoutput_submodule(self): 1997*da0073e9SAndroid Build Coastguard Worker class SubModule(nn.Module): 1998*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1999*da0073e9SAndroid Build Coastguard Worker return (x + 1, x + 2) 2000*da0073e9SAndroid Build Coastguard Worker 2001*da0073e9SAndroid Build Coastguard Worker class TestModule(nn.Module): 2002*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2003*da0073e9SAndroid Build Coastguard Worker super().__init__() 2004*da0073e9SAndroid Build Coastguard Worker self.sub = SubModule() 2005*da0073e9SAndroid Build Coastguard Worker 2006*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2007*da0073e9SAndroid Build Coastguard Worker y1, y2 = self.sub(x) 2008*da0073e9SAndroid Build Coastguard Worker return y1 + y2 2009*da0073e9SAndroid Build Coastguard Worker 2010*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(TestModule()) 2011*da0073e9SAndroid Build Coastguard Worker m = m.eval() 2012*da0073e9SAndroid Build Coastguard Worker mf = torch.jit.freeze(m) 2013*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(2, 2) 2014*da0073e9SAndroid Build Coastguard Worker expected = m.forward(inp) 2015*da0073e9SAndroid Build Coastguard Worker output = 
mf.forward(inp) 2016*da0073e9SAndroid Build Coastguard Worker # Check if prim::TupleConstruct and prim::TupleUnpack 2017*da0073e9SAndroid Build Coastguard Worker # Don't exist in frozen graph 2018*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("prim::TupleConstruct").run(mf.graph) 2019*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("prim::TupleUnpack").run(mf.graph) 2020*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output, expected) 2021*da0073e9SAndroid Build Coastguard Worker 2022*da0073e9SAndroid Build Coastguard Worker def test_freeze_module_with_call_method(self): 2023*da0073e9SAndroid Build Coastguard Worker class Mod(nn.Module): 2024*da0073e9SAndroid Build Coastguard Worker def __init__(self, val): 2025*da0073e9SAndroid Build Coastguard Worker super().__init__() 2026*da0073e9SAndroid Build Coastguard Worker self.param = nn.Parameter(val) 2027*da0073e9SAndroid Build Coastguard Worker 2028*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2029*da0073e9SAndroid Build Coastguard Worker # this method will change during freezing 2030*da0073e9SAndroid Build Coastguard Worker return x + self.param 2031*da0073e9SAndroid Build Coastguard Worker 2032*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 2033*da0073e9SAndroid Build Coastguard Worker def make_prediction(self, x): 2034*da0073e9SAndroid Build Coastguard Worker y = x + x 2035*da0073e9SAndroid Build Coastguard Worker return self.forward(y) 2036*da0073e9SAndroid Build Coastguard Worker 2037*da0073e9SAndroid Build Coastguard Worker param = torch.rand([2, 2]) 2038*da0073e9SAndroid Build Coastguard Worker x = torch.rand([2, 2]) 2039*da0073e9SAndroid Build Coastguard Worker 2040*da0073e9SAndroid Build Coastguard Worker unscripted_mod = Mod(param) 2041*da0073e9SAndroid Build Coastguard Worker mod = torch.jit.script(unscripted_mod) 2042*da0073e9SAndroid Build Coastguard Worker mod.eval() 2043*da0073e9SAndroid Build Coastguard Worker mod = 
torch.jit.freeze(mod, preserved_attrs=["make_prediction"]) 2044*da0073e9SAndroid Build Coastguard Worker 2045*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2046*da0073e9SAndroid Build Coastguard Worker mod.forward(x), unscripted_mod.forward(x), atol=1e-5, rtol=1e-5 2047*da0073e9SAndroid Build Coastguard Worker ) 2048*da0073e9SAndroid Build Coastguard Worker 2049*da0073e9SAndroid Build Coastguard Worker 2050*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo("somehow causing hanging during python shutdown") 2051*da0073e9SAndroid Build Coastguard Workerclass TestFrozenOptimizations(JitTestCase): 2052*da0073e9SAndroid Build Coastguard Worker def setUp(self): 2053*da0073e9SAndroid Build Coastguard Worker super().setUp() 2054*da0073e9SAndroid Build Coastguard Worker self.default_dtype = torch.get_default_dtype() 2055*da0073e9SAndroid Build Coastguard Worker torch.set_default_dtype(torch.double) 2056*da0073e9SAndroid Build Coastguard Worker 2057*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 2058*da0073e9SAndroid Build Coastguard Worker torch.set_default_dtype(self.default_dtype) 2059*da0073e9SAndroid Build Coastguard Worker super().tearDown() 2060*da0073e9SAndroid Build Coastguard Worker 2061*da0073e9SAndroid Build Coastguard Worker def test_conv_bn_folding(self): 2062*da0073e9SAndroid Build Coastguard Worker conv_bias = [True, False] 2063*da0073e9SAndroid Build Coastguard Worker module_pairs = [ 2064*da0073e9SAndroid Build Coastguard Worker (nn.Conv1d, nn.BatchNorm1d), 2065*da0073e9SAndroid Build Coastguard Worker (nn.Conv2d, nn.BatchNorm2d), 2066*da0073e9SAndroid Build Coastguard Worker (nn.Conv3d, nn.BatchNorm3d), 2067*da0073e9SAndroid Build Coastguard Worker ] 2068*da0073e9SAndroid Build Coastguard Worker use_tracing = [True, False] 2069*da0073e9SAndroid Build Coastguard Worker bn_running_stats = [True, False] 2070*da0073e9SAndroid Build Coastguard Worker 2071*da0073e9SAndroid Build Coastguard Worker for use_bias, modules, tracing, 
track_stats in product( 2072*da0073e9SAndroid Build Coastguard Worker conv_bias, module_pairs, use_tracing, bn_running_stats 2073*da0073e9SAndroid Build Coastguard Worker ): 2074*da0073e9SAndroid Build Coastguard Worker 2075*da0073e9SAndroid Build Coastguard Worker class ConvBN(torch.nn.Module): 2076*da0073e9SAndroid Build Coastguard Worker def __init__(self, in_channels, out_channels, **kwargs): 2077*da0073e9SAndroid Build Coastguard Worker super().__init__() 2078*da0073e9SAndroid Build Coastguard Worker self.conv = modules[0]( 2079*da0073e9SAndroid Build Coastguard Worker in_channels, out_channels, bias=use_bias, **kwargs 2080*da0073e9SAndroid Build Coastguard Worker ) 2081*da0073e9SAndroid Build Coastguard Worker self.bn = modules[1]( 2082*da0073e9SAndroid Build Coastguard Worker out_channels, eps=0.001, track_running_stats=track_stats 2083*da0073e9SAndroid Build Coastguard Worker ) 2084*da0073e9SAndroid Build Coastguard Worker 2085*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2086*da0073e9SAndroid Build Coastguard Worker x = self.conv(x) 2087*da0073e9SAndroid Build Coastguard Worker return self.bn(x) 2088*da0073e9SAndroid Build Coastguard Worker 2089*da0073e9SAndroid Build Coastguard Worker mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).eval() 2090*da0073e9SAndroid Build Coastguard Worker inps = [4, 3, 4] 2091*da0073e9SAndroid Build Coastguard Worker if modules[0] == nn.Conv2d: 2092*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2093*da0073e9SAndroid Build Coastguard Worker if modules[0] == nn.Conv3d: 2094*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2095*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2096*da0073e9SAndroid Build Coastguard Worker 2097*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(inps) 2098*da0073e9SAndroid Build Coastguard Worker 2099*da0073e9SAndroid Build Coastguard Worker if tracing: 2100*da0073e9SAndroid Build Coastguard Worker scripted_mod = 
torch.jit.trace(mod_eager, (inp)) 2101*da0073e9SAndroid Build Coastguard Worker else: 2102*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.script(mod_eager) 2103*da0073e9SAndroid Build Coastguard Worker 2104*da0073e9SAndroid Build Coastguard Worker self.run_pass("inline", scripted_mod.graph) 2105*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", scripted_mod.graph) 2106*da0073e9SAndroid Build Coastguard Worker self.run_pass("constant_propagation", scripted_mod.graph) 2107*da0073e9SAndroid Build Coastguard Worker 2108*da0073e9SAndroid Build Coastguard Worker FileCheck().check("conv").check("batch").run(scripted_mod.graph) 2109*da0073e9SAndroid Build Coastguard Worker # successfully no-ops with non-const inputs 2110*da0073e9SAndroid Build Coastguard Worker self.run_pass("fold_frozen_conv_bn", scripted_mod.graph) 2111*da0073e9SAndroid Build Coastguard Worker FileCheck().check("conv").check("aten::batch_norm").run(scripted_mod.graph) 2112*da0073e9SAndroid Build Coastguard Worker 2113*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.freeze(scripted_mod) 2114*da0073e9SAndroid Build Coastguard Worker self.run_pass("fold_frozen_conv_bn", scripted_mod.graph) 2115*da0073e9SAndroid Build Coastguard Worker if track_stats: 2116*da0073e9SAndroid Build Coastguard Worker FileCheck().check("conv").check_not("aten::batch_norm").run( 2117*da0073e9SAndroid Build Coastguard Worker scripted_mod.graph 2118*da0073e9SAndroid Build Coastguard Worker ) 2119*da0073e9SAndroid Build Coastguard Worker else: 2120*da0073e9SAndroid Build Coastguard Worker FileCheck().check("conv").check("aten::batch_norm").run( 2121*da0073e9SAndroid Build Coastguard Worker scripted_mod.graph 2122*da0073e9SAndroid Build Coastguard Worker ) 2123*da0073e9SAndroid Build Coastguard Worker 2124*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod_eager(inp), scripted_mod(inp)) 2125*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod_eager(inp), 
scripted_mod(inp)) 2126*da0073e9SAndroid Build Coastguard Worker 2127*da0073e9SAndroid Build Coastguard Worker def test_conv_bn_folding_not_forward(self): 2128*da0073e9SAndroid Build Coastguard Worker class ConvBN(torch.nn.Module): 2129*da0073e9SAndroid Build Coastguard Worker def __init__(self, in_channels, out_channels, **kwargs): 2130*da0073e9SAndroid Build Coastguard Worker super().__init__() 2131*da0073e9SAndroid Build Coastguard Worker self.conv = torch.nn.Conv2d( 2132*da0073e9SAndroid Build Coastguard Worker in_channels, out_channels, bias=True, **kwargs 2133*da0073e9SAndroid Build Coastguard Worker ) 2134*da0073e9SAndroid Build Coastguard Worker self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001) 2135*da0073e9SAndroid Build Coastguard Worker self.amt = 3.2 2136*da0073e9SAndroid Build Coastguard Worker 2137*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2138*da0073e9SAndroid Build Coastguard Worker x = self.conv(x) 2139*da0073e9SAndroid Build Coastguard Worker return self.bn(x) 2140*da0073e9SAndroid Build Coastguard Worker 2141*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 2142*da0073e9SAndroid Build Coastguard Worker def make_prediction(self, x): 2143*da0073e9SAndroid Build Coastguard Worker return self.forward(x) + self.amt 2144*da0073e9SAndroid Build Coastguard Worker 2145*da0073e9SAndroid Build Coastguard Worker mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).eval() 2146*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.script(mod_eager) 2147*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_inline(scripted_mod.make_prediction.graph) 2148*da0073e9SAndroid Build Coastguard Worker FileCheck().check("conv").check("aten::batch_norm").run( 2149*da0073e9SAndroid Build Coastguard Worker scripted_mod.make_prediction.graph 2150*da0073e9SAndroid Build Coastguard Worker ) 2151*da0073e9SAndroid Build Coastguard Worker 2152*da0073e9SAndroid Build Coastguard Worker # _jit_pass_optimize_frozen_graph should 
not be called on non-method attributes (e.g. "amt") 2153*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.freeze( 2154*da0073e9SAndroid Build Coastguard Worker scripted_mod, preserved_attrs=["make_prediction", "amt"] 2155*da0073e9SAndroid Build Coastguard Worker ) 2156*da0073e9SAndroid Build Coastguard Worker FileCheck().check("conv").check_not("aten::batch_norm").run( 2157*da0073e9SAndroid Build Coastguard Worker scripted_mod.make_prediction.graph 2158*da0073e9SAndroid Build Coastguard Worker ) 2159*da0073e9SAndroid Build Coastguard Worker 2160*da0073e9SAndroid Build Coastguard Worker # During freezing this creates tensors constants that are attached to the frozen graph, 2161*da0073e9SAndroid Build Coastguard Worker # which is then kept alive by the compilation unit (which causes a leak) 2162*da0073e9SAndroid Build Coastguard Worker @skipCUDAMemoryLeakCheckIf(True) 2163*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") 2164*da0073e9SAndroid Build Coastguard Worker def test_conv_bn_folding_autocast_scenario_cuda(self): 2165*da0073e9SAndroid Build Coastguard Worker # CUDA conv takes input tensors which must all be the same dtype, 2166*da0073e9SAndroid Build Coastguard Worker # which can cause issues if folding produces inputs of different dtypes. 
2167*da0073e9SAndroid Build Coastguard Worker 2168*da0073e9SAndroid Build Coastguard Worker class ConvBN(torch.nn.Module): 2169*da0073e9SAndroid Build Coastguard Worker def __init__(self, in_channels, out_channels, **kwargs): 2170*da0073e9SAndroid Build Coastguard Worker super().__init__() 2171*da0073e9SAndroid Build Coastguard Worker self.conv = torch.nn.Conv2d( 2172*da0073e9SAndroid Build Coastguard Worker in_channels, out_channels, bias=False, dtype=torch.half, **kwargs 2173*da0073e9SAndroid Build Coastguard Worker ) 2174*da0073e9SAndroid Build Coastguard Worker self.bn = torch.nn.BatchNorm2d( 2175*da0073e9SAndroid Build Coastguard Worker out_channels, eps=0.001, dtype=torch.float 2176*da0073e9SAndroid Build Coastguard Worker ) 2177*da0073e9SAndroid Build Coastguard Worker 2178*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2179*da0073e9SAndroid Build Coastguard Worker return self.bn(self.conv(x)) 2180*da0073e9SAndroid Build Coastguard Worker 2181*da0073e9SAndroid Build Coastguard Worker mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).cuda().eval() 2182*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.script(mod_eager) 2183*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.freeze(scripted_mod) 2184*da0073e9SAndroid Build Coastguard Worker FileCheck().check("conv").check_not("aten::batch_norm").run(scripted_mod.graph) 2185*da0073e9SAndroid Build Coastguard Worker conv_node = scripted_mod.graph.findNode("aten::conv2d", True) 2186*da0073e9SAndroid Build Coastguard Worker self.assertTrue(conv_node is not None) 2187*da0073e9SAndroid Build Coastguard Worker bias_input = conv_node.namedInput("bias") 2188*da0073e9SAndroid Build Coastguard Worker self.assertTrue(bias_input is not None) 2189*da0073e9SAndroid Build Coastguard Worker self.assertTrue(bias_input.type().dtype() == torch.half) 2190*da0073e9SAndroid Build Coastguard Worker 2191*da0073e9SAndroid Build Coastguard Worker x = torch.rand((3, 3, 32, 32), 
dtype=torch.half).cuda() 2192*da0073e9SAndroid Build Coastguard Worker 2193*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2) 2194*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2) 2195*da0073e9SAndroid Build Coastguard Worker 2196*da0073e9SAndroid Build Coastguard Worker def test_conv_add_folding(self): 2197*da0073e9SAndroid Build Coastguard Worker @torch.no_grad() 2198*da0073e9SAndroid Build Coastguard Worker def test_conv_fusion( 2199*da0073e9SAndroid Build Coastguard Worker use_bias, module, tracing, op, scalar, add_tensor, expect_success 2200*da0073e9SAndroid Build Coastguard Worker ): 2201*da0073e9SAndroid Build Coastguard Worker class ConvOp(torch.nn.Module): 2202*da0073e9SAndroid Build Coastguard Worker __constants__ = ["use_scalar"] 2203*da0073e9SAndroid Build Coastguard Worker 2204*da0073e9SAndroid Build Coastguard Worker def __init__(self, in_channels, out_channels, tensor=None, **kwargs): 2205*da0073e9SAndroid Build Coastguard Worker super().__init__() 2206*da0073e9SAndroid Build Coastguard Worker self.conv = module( 2207*da0073e9SAndroid Build Coastguard Worker in_channels, out_channels, bias=use_bias, **kwargs 2208*da0073e9SAndroid Build Coastguard Worker ) 2209*da0073e9SAndroid Build Coastguard Worker self.conv2 = module( 2210*da0073e9SAndroid Build Coastguard Worker in_channels, out_channels, bias=use_bias, **kwargs 2211*da0073e9SAndroid Build Coastguard Worker ) 2212*da0073e9SAndroid Build Coastguard Worker self.use_scalar = scalar 2213*da0073e9SAndroid Build Coastguard Worker tensor_size = [1 for _ in range(self.conv.weight.ndim)] 2214*da0073e9SAndroid Build Coastguard Worker tensor_size[1] = self.conv.weight.size(0) 2215*da0073e9SAndroid Build Coastguard Worker self.tensor = ( 2216*da0073e9SAndroid Build Coastguard Worker add_tensor 2217*da0073e9SAndroid Build Coastguard Worker if add_tensor is not None 2218*da0073e9SAndroid 
Build Coastguard Worker else torch.rand(tensor_size) 2219*da0073e9SAndroid Build Coastguard Worker ) 2220*da0073e9SAndroid Build Coastguard Worker self.op = op 2221*da0073e9SAndroid Build Coastguard Worker 2222*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2223*da0073e9SAndroid Build Coastguard Worker x = self.conv(x) 2224*da0073e9SAndroid Build Coastguard Worker if self.use_scalar: 2225*da0073e9SAndroid Build Coastguard Worker return self.op(x, 2.0) 2226*da0073e9SAndroid Build Coastguard Worker else: 2227*da0073e9SAndroid Build Coastguard Worker return self.op(x, self.tensor) 2228*da0073e9SAndroid Build Coastguard Worker 2229*da0073e9SAndroid Build Coastguard Worker mod_eager = ConvOp(3, 32, kernel_size=3, stride=2).eval() 2230*da0073e9SAndroid Build Coastguard Worker 2231*da0073e9SAndroid Build Coastguard Worker inps = [4, 3, 4] 2232*da0073e9SAndroid Build Coastguard Worker if module == nn.Conv2d: 2233*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2234*da0073e9SAndroid Build Coastguard Worker if module == nn.Conv3d: 2235*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2236*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2237*da0073e9SAndroid Build Coastguard Worker 2238*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(inps) 2239*da0073e9SAndroid Build Coastguard Worker 2240*da0073e9SAndroid Build Coastguard Worker if tracing: 2241*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.trace(mod_eager, (inp,)) 2242*da0073e9SAndroid Build Coastguard Worker else: 2243*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.script(mod_eager) 2244*da0073e9SAndroid Build Coastguard Worker 2245*da0073e9SAndroid Build Coastguard Worker self.run_pass("inline", scripted_mod.graph) 2246*da0073e9SAndroid Build Coastguard Worker op_str = "aten::" + op.__name__ 2247*da0073e9SAndroid Build Coastguard Worker 2248*da0073e9SAndroid Build Coastguard Worker 
FileCheck().check("conv").check(op_str).run(scripted_mod.graph) 2249*da0073e9SAndroid Build Coastguard Worker # successively no-ops with non-const inputs 2250*da0073e9SAndroid Build Coastguard Worker self.run_pass("fold_frozen_conv_mul_or_div", scripted_mod.graph) 2251*da0073e9SAndroid Build Coastguard Worker self.run_pass("fold_frozen_conv_add_or_sub", scripted_mod.graph) 2252*da0073e9SAndroid Build Coastguard Worker FileCheck().check("conv").check(op_str).run(scripted_mod.graph) 2253*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.freeze(scripted_mod) 2254*da0073e9SAndroid Build Coastguard Worker self.run_pass("fold_frozen_conv_mul_or_div", scripted_mod.graph) 2255*da0073e9SAndroid Build Coastguard Worker self.run_pass("fold_frozen_conv_add_or_sub", scripted_mod.graph) 2256*da0073e9SAndroid Build Coastguard Worker 2257*da0073e9SAndroid Build Coastguard Worker if expect_success: 2258*da0073e9SAndroid Build Coastguard Worker FileCheck().check("conv").check_not(op_str).run(scripted_mod.graph) 2259*da0073e9SAndroid Build Coastguard Worker else: 2260*da0073e9SAndroid Build Coastguard Worker FileCheck().check("conv").check(op_str).run(scripted_mod.graph) 2261*da0073e9SAndroid Build Coastguard Worker 2262*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod_eager(inp), scripted_mod(inp)) 2263*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod_eager(inp), scripted_mod(inp)) 2264*da0073e9SAndroid Build Coastguard Worker 2265*da0073e9SAndroid Build Coastguard Worker conv_bias = [True, False] 2266*da0073e9SAndroid Build Coastguard Worker modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d] 2267*da0073e9SAndroid Build Coastguard Worker use_tracing = [False, True] 2268*da0073e9SAndroid Build Coastguard Worker use_scalar = [False, True] 2269*da0073e9SAndroid Build Coastguard Worker ops = [torch.add, torch.sub, torch.mul, torch.div] 2270*da0073e9SAndroid Build Coastguard Worker 2271*da0073e9SAndroid Build Coastguard Worker for use_bias, module, 
tracing, pytorch_op, scalar in product( 2272*da0073e9SAndroid Build Coastguard Worker conv_bias, modules, use_tracing, ops, use_scalar 2273*da0073e9SAndroid Build Coastguard Worker ): 2274*da0073e9SAndroid Build Coastguard Worker test_conv_fusion( 2275*da0073e9SAndroid Build Coastguard Worker use_bias, 2276*da0073e9SAndroid Build Coastguard Worker module, 2277*da0073e9SAndroid Build Coastguard Worker tracing, 2278*da0073e9SAndroid Build Coastguard Worker pytorch_op, 2279*da0073e9SAndroid Build Coastguard Worker scalar, 2280*da0073e9SAndroid Build Coastguard Worker add_tensor=None, 2281*da0073e9SAndroid Build Coastguard Worker expect_success=True, 2282*da0073e9SAndroid Build Coastguard Worker ) 2283*da0073e9SAndroid Build Coastguard Worker 2284*da0073e9SAndroid Build Coastguard Worker for use_bias, pytorch_op in product(conv_bias, ops): 2285*da0073e9SAndroid Build Coastguard Worker # broadcasting add 2286*da0073e9SAndroid Build Coastguard Worker test_conv_fusion( 2287*da0073e9SAndroid Build Coastguard Worker use_bias, 2288*da0073e9SAndroid Build Coastguard Worker nn.Conv2d, 2289*da0073e9SAndroid Build Coastguard Worker False, 2290*da0073e9SAndroid Build Coastguard Worker pytorch_op, 2291*da0073e9SAndroid Build Coastguard Worker False, 2292*da0073e9SAndroid Build Coastguard Worker add_tensor=torch.rand(32, 1, 32), 2293*da0073e9SAndroid Build Coastguard Worker expect_success=False, 2294*da0073e9SAndroid Build Coastguard Worker ) 2295*da0073e9SAndroid Build Coastguard Worker 2296*da0073e9SAndroid Build Coastguard Worker # broadcasting add 2297*da0073e9SAndroid Build Coastguard Worker test_conv_fusion( 2298*da0073e9SAndroid Build Coastguard Worker use_bias, 2299*da0073e9SAndroid Build Coastguard Worker nn.Conv2d, 2300*da0073e9SAndroid Build Coastguard Worker False, 2301*da0073e9SAndroid Build Coastguard Worker pytorch_op, 2302*da0073e9SAndroid Build Coastguard Worker False, 2303*da0073e9SAndroid Build Coastguard Worker add_tensor=torch.rand(1, 1), 2304*da0073e9SAndroid 
Build Coastguard Worker expect_success=True, 2305*da0073e9SAndroid Build Coastguard Worker ) 2306*da0073e9SAndroid Build Coastguard Worker 2307*da0073e9SAndroid Build Coastguard Worker # add with different dtype 2308*da0073e9SAndroid Build Coastguard Worker test_conv_fusion( 2309*da0073e9SAndroid Build Coastguard Worker use_bias, 2310*da0073e9SAndroid Build Coastguard Worker nn.Conv2d, 2311*da0073e9SAndroid Build Coastguard Worker False, 2312*da0073e9SAndroid Build Coastguard Worker pytorch_op, 2313*da0073e9SAndroid Build Coastguard Worker False, 2314*da0073e9SAndroid Build Coastguard Worker add_tensor=torch.tensor([2]).to(torch.int), 2315*da0073e9SAndroid Build Coastguard Worker expect_success=True, 2316*da0073e9SAndroid Build Coastguard Worker ) 2317*da0073e9SAndroid Build Coastguard Worker 2318*da0073e9SAndroid Build Coastguard Worker def test_conv_mul_add_bn(self): 2319*da0073e9SAndroid Build Coastguard Worker class Conv_Mul_Add_Bn(nn.Module): 2320*da0073e9SAndroid Build Coastguard Worker def __init__(self, in_channels, out_channels, **kwargs): 2321*da0073e9SAndroid Build Coastguard Worker super().__init__() 2322*da0073e9SAndroid Build Coastguard Worker self.conv = nn.Conv2d(in_channels, out_channels, **kwargs) 2323*da0073e9SAndroid Build Coastguard Worker self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 2324*da0073e9SAndroid Build Coastguard Worker self.tensor1 = torch.tensor(2.2) 2325*da0073e9SAndroid Build Coastguard Worker self.tensor2 = torch.tensor(2) 2326*da0073e9SAndroid Build Coastguard Worker 2327*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2328*da0073e9SAndroid Build Coastguard Worker return self.bn( 2329*da0073e9SAndroid Build Coastguard Worker torch.add(torch.mul(self.conv(x), self.tensor1), self.tensor2) 2330*da0073e9SAndroid Build Coastguard Worker ) 2331*da0073e9SAndroid Build Coastguard Worker 2332*da0073e9SAndroid Build Coastguard Worker input = torch.randn(8, 3, 64, 64) 2333*da0073e9SAndroid Build Coastguard Worker model = 
Conv_Mul_Add_Bn(3, 32, kernel_size=3, stride=1).eval() 2334*da0073e9SAndroid Build Coastguard Worker 2335*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 2336*da0073e9SAndroid Build Coastguard Worker result = model(input) 2337*da0073e9SAndroid Build Coastguard Worker traced_model = torch.jit.trace(model, input).eval() 2338*da0073e9SAndroid Build Coastguard Worker traced_model = torch.jit.freeze(traced_model) 2339*da0073e9SAndroid Build Coastguard Worker tresult = traced_model(input) 2340*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, tresult) 2341*da0073e9SAndroid Build Coastguard Worker FileCheck().check("conv").check_not("aten::batch_norm").run( 2342*da0073e9SAndroid Build Coastguard Worker traced_model.graph 2343*da0073e9SAndroid Build Coastguard Worker ) 2344*da0073e9SAndroid Build Coastguard Worker FileCheck().check("conv").check_not("aten::add").run(traced_model.graph) 2345*da0073e9SAndroid Build Coastguard Worker 2346*da0073e9SAndroid Build Coastguard Worker def test_linear_bn_folding(self): 2347*da0073e9SAndroid Build Coastguard Worker module_pairs = [ 2348*da0073e9SAndroid Build Coastguard Worker (nn.Linear, nn.BatchNorm1d), 2349*da0073e9SAndroid Build Coastguard Worker (nn.Linear, nn.BatchNorm2d), 2350*da0073e9SAndroid Build Coastguard Worker (nn.Linear, nn.BatchNorm3d), 2351*da0073e9SAndroid Build Coastguard Worker ] 2352*da0073e9SAndroid Build Coastguard Worker use_tracing = [True, False] 2353*da0073e9SAndroid Build Coastguard Worker bn_running_stats = [True, False] 2354*da0073e9SAndroid Build Coastguard Worker 2355*da0073e9SAndroid Build Coastguard Worker for modules, tracing, track_stats in product( 2356*da0073e9SAndroid Build Coastguard Worker module_pairs, use_tracing, bn_running_stats 2357*da0073e9SAndroid Build Coastguard Worker ): 2358*da0073e9SAndroid Build Coastguard Worker 2359*da0073e9SAndroid Build Coastguard Worker class LinearBN(torch.nn.Module): 2360*da0073e9SAndroid Build Coastguard Worker def 
__init__(self, in_features, out_features): 2361*da0073e9SAndroid Build Coastguard Worker super().__init__() 2362*da0073e9SAndroid Build Coastguard Worker self.linear = modules[0](in_features, out_features) 2363*da0073e9SAndroid Build Coastguard Worker self.bn = modules[1]( 2364*da0073e9SAndroid Build Coastguard Worker out_features, eps=0.001, track_running_stats=track_stats 2365*da0073e9SAndroid Build Coastguard Worker ) 2366*da0073e9SAndroid Build Coastguard Worker 2367*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2368*da0073e9SAndroid Build Coastguard Worker x = self.linear(x) 2369*da0073e9SAndroid Build Coastguard Worker return self.bn(x) 2370*da0073e9SAndroid Build Coastguard Worker 2371*da0073e9SAndroid Build Coastguard Worker mod_eager = LinearBN(32, 32).eval() 2372*da0073e9SAndroid Build Coastguard Worker 2373*da0073e9SAndroid Build Coastguard Worker inps = [3, 32] 2374*da0073e9SAndroid Build Coastguard Worker if modules[1] == nn.BatchNorm2d: 2375*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2376*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2377*da0073e9SAndroid Build Coastguard Worker if modules[1] == nn.BatchNorm3d: 2378*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2379*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2380*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2381*da0073e9SAndroid Build Coastguard Worker 2382*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(inps) 2383*da0073e9SAndroid Build Coastguard Worker 2384*da0073e9SAndroid Build Coastguard Worker if tracing: 2385*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.trace(mod_eager, (inp)) 2386*da0073e9SAndroid Build Coastguard Worker else: 2387*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.script(mod_eager) 2388*da0073e9SAndroid Build Coastguard Worker 2389*da0073e9SAndroid Build Coastguard Worker self.run_pass("inline", scripted_mod.graph) 
2390*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", scripted_mod.graph) 2391*da0073e9SAndroid Build Coastguard Worker self.run_pass("constant_propagation", scripted_mod.graph) 2392*da0073e9SAndroid Build Coastguard Worker 2393*da0073e9SAndroid Build Coastguard Worker FileCheck().check("linear").check("batch").run(scripted_mod.graph) 2394*da0073e9SAndroid Build Coastguard Worker # successfully no-ops with non-const inputs 2395*da0073e9SAndroid Build Coastguard Worker self.run_pass("fold_frozen_linear_bn", scripted_mod.graph) 2396*da0073e9SAndroid Build Coastguard Worker FileCheck().check("linear").check("aten::batch_norm").run( 2397*da0073e9SAndroid Build Coastguard Worker scripted_mod.graph 2398*da0073e9SAndroid Build Coastguard Worker ) 2399*da0073e9SAndroid Build Coastguard Worker 2400*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.freeze(scripted_mod) 2401*da0073e9SAndroid Build Coastguard Worker self.run_pass("fold_frozen_linear_bn", scripted_mod.graph) 2402*da0073e9SAndroid Build Coastguard Worker if track_stats: 2403*da0073e9SAndroid Build Coastguard Worker FileCheck().check("linear").check_not("aten::batch_norm").run( 2404*da0073e9SAndroid Build Coastguard Worker scripted_mod.graph 2405*da0073e9SAndroid Build Coastguard Worker ) 2406*da0073e9SAndroid Build Coastguard Worker else: 2407*da0073e9SAndroid Build Coastguard Worker FileCheck().check("linear").check("aten::batch_norm").run( 2408*da0073e9SAndroid Build Coastguard Worker scripted_mod.graph 2409*da0073e9SAndroid Build Coastguard Worker ) 2410*da0073e9SAndroid Build Coastguard Worker 2411*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod_eager(inp), scripted_mod(inp)) 2412*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod_eager(inp), scripted_mod(inp)) 2413*da0073e9SAndroid Build Coastguard Worker 2414*da0073e9SAndroid Build Coastguard Worker def test_bn_not_broadcast_with_linear(self): 2415*da0073e9SAndroid Build Coastguard Worker 
module_pairs = [ 2416*da0073e9SAndroid Build Coastguard Worker (nn.Linear, nn.BatchNorm1d), 2417*da0073e9SAndroid Build Coastguard Worker (nn.Linear, nn.BatchNorm2d), 2418*da0073e9SAndroid Build Coastguard Worker (nn.Linear, nn.BatchNorm3d), 2419*da0073e9SAndroid Build Coastguard Worker ] 2420*da0073e9SAndroid Build Coastguard Worker use_tracing = [True, False] 2421*da0073e9SAndroid Build Coastguard Worker linear_in = 3 2422*da0073e9SAndroid Build Coastguard Worker # (linear_out, bn_in) 2423*da0073e9SAndroid Build Coastguard Worker # case 1: linear_out < bn_in 2424*da0073e9SAndroid Build Coastguard Worker # case 2: linear_out > bn_in 2425*da0073e9SAndroid Build Coastguard Worker # case 3: linear_out != bn_in && linear_out = 1 2426*da0073e9SAndroid Build Coastguard Worker dims = [(2, 4), (4, 2), (1, 2)] 2427*da0073e9SAndroid Build Coastguard Worker 2428*da0073e9SAndroid Build Coastguard Worker for modules, tracing, dim in product(module_pairs, use_tracing, dims): 2429*da0073e9SAndroid Build Coastguard Worker linear_out, bn_in = dim[0], dim[1] 2430*da0073e9SAndroid Build Coastguard Worker 2431*da0073e9SAndroid Build Coastguard Worker linear = modules[0](linear_in, linear_out) 2432*da0073e9SAndroid Build Coastguard Worker bn = modules[1](bn_in) 2433*da0073e9SAndroid Build Coastguard Worker mod_eager = nn.Sequential(linear, bn).eval() 2434*da0073e9SAndroid Build Coastguard Worker 2435*da0073e9SAndroid Build Coastguard Worker N, C = 3, bn_in 2436*da0073e9SAndroid Build Coastguard Worker input_shape = [N, C] 2437*da0073e9SAndroid Build Coastguard Worker if modules[1] == nn.BatchNorm1d: 2438*da0073e9SAndroid Build Coastguard Worker H = linear_in 2439*da0073e9SAndroid Build Coastguard Worker input_shape.append(H) 2440*da0073e9SAndroid Build Coastguard Worker elif modules[1] == nn.BatchNorm2d: 2441*da0073e9SAndroid Build Coastguard Worker H, W = 4, linear_in 2442*da0073e9SAndroid Build Coastguard Worker input_shape.append(H) 2443*da0073e9SAndroid Build Coastguard Worker 
input_shape.append(W) 2444*da0073e9SAndroid Build Coastguard Worker elif modules[1] == nn.BatchNorm3d: 2445*da0073e9SAndroid Build Coastguard Worker D, H, W = 4, 4, linear_in 2446*da0073e9SAndroid Build Coastguard Worker input_shape.append(D) 2447*da0073e9SAndroid Build Coastguard Worker input_shape.append(H) 2448*da0073e9SAndroid Build Coastguard Worker input_shape.append(W) 2449*da0073e9SAndroid Build Coastguard Worker 2450*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(input_shape) 2451*da0073e9SAndroid Build Coastguard Worker 2452*da0073e9SAndroid Build Coastguard Worker if tracing: 2453*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.trace(mod_eager, (inp)) 2454*da0073e9SAndroid Build Coastguard Worker else: 2455*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.script(mod_eager) 2456*da0073e9SAndroid Build Coastguard Worker 2457*da0073e9SAndroid Build Coastguard Worker self.run_pass("inline", scripted_mod.graph) 2458*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", scripted_mod.graph) 2459*da0073e9SAndroid Build Coastguard Worker self.run_pass("constant_propagation", scripted_mod.graph) 2460*da0073e9SAndroid Build Coastguard Worker 2461*da0073e9SAndroid Build Coastguard Worker FileCheck().check("linear").check("batch").run(scripted_mod.graph) 2462*da0073e9SAndroid Build Coastguard Worker self.run_pass("fold_frozen_linear_bn", scripted_mod.graph) 2463*da0073e9SAndroid Build Coastguard Worker FileCheck().check("linear").check("aten::batch_norm").run( 2464*da0073e9SAndroid Build Coastguard Worker scripted_mod.graph 2465*da0073e9SAndroid Build Coastguard Worker ) 2466*da0073e9SAndroid Build Coastguard Worker 2467*da0073e9SAndroid Build Coastguard Worker frozen_mod = torch.jit.freeze(scripted_mod) 2468*da0073e9SAndroid Build Coastguard Worker self.run_pass("fold_frozen_linear_bn", frozen_mod.graph) 2469*da0073e9SAndroid Build Coastguard Worker # successfully skipped folding 2470*da0073e9SAndroid Build 
Coastguard Worker FileCheck().check("linear").check("aten::batch_norm").run(frozen_mod.graph) 2471*da0073e9SAndroid Build Coastguard Worker 2472*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod_eager(inp), frozen_mod(inp)) 2473*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod_eager(inp), frozen_mod(inp)) 2474*da0073e9SAndroid Build Coastguard Worker 2475*da0073e9SAndroid Build Coastguard Worker # successfully failed folding 2476*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2477*da0073e9SAndroid Build Coastguard Worker AssertionError, 2478*da0073e9SAndroid Build Coastguard Worker "To fuse, linear.out_features == bn.num_features or bn.num_features == 1", 2479*da0073e9SAndroid Build Coastguard Worker ): 2480*da0073e9SAndroid Build Coastguard Worker nn.utils.fusion.fuse_linear_bn_eval(linear, bn) 2481*da0073e9SAndroid Build Coastguard Worker 2482*da0073e9SAndroid Build Coastguard Worker @skipCUDAMemoryLeakCheckIf(True) 2483*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") 2484*da0073e9SAndroid Build Coastguard Worker def test_linear_bn_folding_autocast_scenario_cuda(self): 2485*da0073e9SAndroid Build Coastguard Worker module_pairs = [ 2486*da0073e9SAndroid Build Coastguard Worker (nn.Linear, nn.BatchNorm1d), 2487*da0073e9SAndroid Build Coastguard Worker (nn.Linear, nn.BatchNorm2d), 2488*da0073e9SAndroid Build Coastguard Worker (nn.Linear, nn.BatchNorm3d), 2489*da0073e9SAndroid Build Coastguard Worker ] 2490*da0073e9SAndroid Build Coastguard Worker use_tracing = [True, False] 2491*da0073e9SAndroid Build Coastguard Worker bn_running_stats = [True, False] 2492*da0073e9SAndroid Build Coastguard Worker 2493*da0073e9SAndroid Build Coastguard Worker for modules, tracing, track_stats in product( 2494*da0073e9SAndroid Build Coastguard Worker module_pairs, use_tracing, bn_running_stats 2495*da0073e9SAndroid Build Coastguard Worker ): 2496*da0073e9SAndroid Build 
Coastguard Worker 2497*da0073e9SAndroid Build Coastguard Worker class LinearBN(torch.nn.Module): 2498*da0073e9SAndroid Build Coastguard Worker def __init__(self, in_features, out_features): 2499*da0073e9SAndroid Build Coastguard Worker super().__init__() 2500*da0073e9SAndroid Build Coastguard Worker self.linear = modules[0]( 2501*da0073e9SAndroid Build Coastguard Worker in_features, out_features, bias=False, dtype=torch.half 2502*da0073e9SAndroid Build Coastguard Worker ) 2503*da0073e9SAndroid Build Coastguard Worker self.bn = modules[1](out_features, eps=0.001, dtype=torch.float) 2504*da0073e9SAndroid Build Coastguard Worker 2505*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2506*da0073e9SAndroid Build Coastguard Worker x = self.linear(x) 2507*da0073e9SAndroid Build Coastguard Worker return self.bn(x) 2508*da0073e9SAndroid Build Coastguard Worker 2509*da0073e9SAndroid Build Coastguard Worker mod_eager = LinearBN(32, 32).cuda().eval() 2510*da0073e9SAndroid Build Coastguard Worker 2511*da0073e9SAndroid Build Coastguard Worker inps = [3, 32] 2512*da0073e9SAndroid Build Coastguard Worker if modules[1] == nn.BatchNorm2d: 2513*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2514*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2515*da0073e9SAndroid Build Coastguard Worker if modules[1] == nn.BatchNorm3d: 2516*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2517*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2518*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2519*da0073e9SAndroid Build Coastguard Worker 2520*da0073e9SAndroid Build Coastguard Worker x = torch.rand(inps, dtype=torch.half).cuda() 2521*da0073e9SAndroid Build Coastguard Worker 2522*da0073e9SAndroid Build Coastguard Worker if tracing: 2523*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.trace(mod_eager, (x)) 2524*da0073e9SAndroid Build Coastguard Worker else: 2525*da0073e9SAndroid Build Coastguard 
Worker scripted_mod = torch.jit.script(mod_eager) 2526*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.freeze(scripted_mod) 2527*da0073e9SAndroid Build Coastguard Worker FileCheck().check("linear").check_not("aten::batch_norm").run( 2528*da0073e9SAndroid Build Coastguard Worker scripted_mod.graph 2529*da0073e9SAndroid Build Coastguard Worker ) 2530*da0073e9SAndroid Build Coastguard Worker lin_node = scripted_mod.graph.findNode("aten::linear", True) 2531*da0073e9SAndroid Build Coastguard Worker self.assertTrue(lin_node is not None) 2532*da0073e9SAndroid Build Coastguard Worker weight_input = lin_node.namedInput("weight") 2533*da0073e9SAndroid Build Coastguard Worker bias_input = lin_node.namedInput("bias") 2534*da0073e9SAndroid Build Coastguard Worker self.assertTrue(bias_input is not None) 2535*da0073e9SAndroid Build Coastguard Worker self.assertTrue(weight_input.type().dtype() == torch.half) 2536*da0073e9SAndroid Build Coastguard Worker self.assertTrue(bias_input.type().dtype() == torch.half) 2537*da0073e9SAndroid Build Coastguard Worker 2538*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2) 2539*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2) 2540*da0073e9SAndroid Build Coastguard Worker 2541*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") 2542*da0073e9SAndroid Build Coastguard Worker def test_linear_concat(self): 2543*da0073e9SAndroid Build Coastguard Worker out_dimms = [[5, 10], [1, 5]] 2544*da0073e9SAndroid Build Coastguard Worker 2545*da0073e9SAndroid Build Coastguard Worker for w1_dim, w2_dim in out_dimms: 2546*da0073e9SAndroid Build Coastguard Worker 2547*da0073e9SAndroid Build Coastguard Worker class ModMultLinear(nn.Module): 2548*da0073e9SAndroid Build Coastguard Worker def __init__(self, w1_dim, w2_dim): 2549*da0073e9SAndroid Build Coastguard 
Worker super().__init__() 2550*da0073e9SAndroid Build Coastguard Worker self.w1 = nn.Parameter(torch.rand([w1_dim, 5])) 2551*da0073e9SAndroid Build Coastguard Worker self.b1 = nn.Parameter(torch.rand([w1_dim])) 2552*da0073e9SAndroid Build Coastguard Worker self.w2 = nn.Parameter(torch.rand([w2_dim, 5])) 2553*da0073e9SAndroid Build Coastguard Worker self.b2 = nn.Parameter(torch.rand([w2_dim])) 2554*da0073e9SAndroid Build Coastguard Worker 2555*da0073e9SAndroid Build Coastguard Worker def forward(self, in_tensor1): 2556*da0073e9SAndroid Build Coastguard Worker res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1) 2557*da0073e9SAndroid Build Coastguard Worker res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b2) 2558*da0073e9SAndroid Build Coastguard Worker return res1, res2 2559*da0073e9SAndroid Build Coastguard Worker 2560*da0073e9SAndroid Build Coastguard Worker mod_eager = ModMultLinear(w1_dim, w2_dim).eval() 2561*da0073e9SAndroid Build Coastguard Worker 2562*da0073e9SAndroid Build Coastguard Worker test_val1 = torch.rand([50, 5]) 2563*da0073e9SAndroid Build Coastguard Worker self.check_linear_optimizations(mod_eager, 2, 1, (test_val1,)) 2564*da0073e9SAndroid Build Coastguard Worker 2565*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") 2566*da0073e9SAndroid Build Coastguard Worker def test_linear_concat_complex(self): 2567*da0073e9SAndroid Build Coastguard Worker """ 2568*da0073e9SAndroid Build Coastguard Worker Testing that the interleaving of multiple optimizations does not 2569*da0073e9SAndroid Build Coastguard Worker cause errors, and gets optimized as expected 2570*da0073e9SAndroid Build Coastguard Worker """ 2571*da0073e9SAndroid Build Coastguard Worker 2572*da0073e9SAndroid Build Coastguard Worker class ModMultLinear(nn.Module): 2573*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2574*da0073e9SAndroid Build Coastguard Worker super().__init__() 
2575*da0073e9SAndroid Build Coastguard Worker w1_dim = 5 2576*da0073e9SAndroid Build Coastguard Worker w2_dim = 10 2577*da0073e9SAndroid Build Coastguard Worker self.w1 = nn.Parameter(torch.rand([w1_dim, 5])) 2578*da0073e9SAndroid Build Coastguard Worker self.b1 = nn.Parameter(torch.rand([w1_dim])) 2579*da0073e9SAndroid Build Coastguard Worker self.w2 = nn.Parameter(torch.rand([w2_dim, 5])) 2580*da0073e9SAndroid Build Coastguard Worker self.b2 = nn.Parameter(torch.rand([w2_dim])) 2581*da0073e9SAndroid Build Coastguard Worker 2582*da0073e9SAndroid Build Coastguard Worker def forward(self, in_tensor1): 2583*da0073e9SAndroid Build Coastguard Worker res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1) 2584*da0073e9SAndroid Build Coastguard Worker res3 = torch._C._nn.linear(res1, self.w2, self.b2) 2585*da0073e9SAndroid Build Coastguard Worker res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b2) 2586*da0073e9SAndroid Build Coastguard Worker res4 = torch._C._nn.linear(res1, self.w1, self.b1) 2587*da0073e9SAndroid Build Coastguard Worker return res2, res3, res4 2588*da0073e9SAndroid Build Coastguard Worker 2589*da0073e9SAndroid Build Coastguard Worker mod_eager = ModMultLinear().eval() 2590*da0073e9SAndroid Build Coastguard Worker test_val1 = torch.rand([50, 5]) 2591*da0073e9SAndroid Build Coastguard Worker self.check_linear_optimizations(mod_eager, 4, 2, (test_val1,)) 2592*da0073e9SAndroid Build Coastguard Worker 2593*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") 2594*da0073e9SAndroid Build Coastguard Worker def test_linear_concat_different_input(self): 2595*da0073e9SAndroid Build Coastguard Worker """ 2596*da0073e9SAndroid Build Coastguard Worker There should be no change to the graph due to the optimization pass 2597*da0073e9SAndroid Build Coastguard Worker due to the two input tensors being different 2598*da0073e9SAndroid Build Coastguard Worker """ 2599*da0073e9SAndroid Build Coastguard 
Worker 2600*da0073e9SAndroid Build Coastguard Worker # Freezing requires that the graph be a module 2601*da0073e9SAndroid Build Coastguard Worker class ModMultLinear(nn.Module): 2602*da0073e9SAndroid Build Coastguard Worker def __init__(self, w1_dim, w2_dim): 2603*da0073e9SAndroid Build Coastguard Worker super().__init__() 2604*da0073e9SAndroid Build Coastguard Worker self.w1 = nn.Parameter(torch.rand([w1_dim, 5])) 2605*da0073e9SAndroid Build Coastguard Worker self.b1 = nn.Parameter(torch.rand([w1_dim])) 2606*da0073e9SAndroid Build Coastguard Worker self.w2 = nn.Parameter(torch.rand([w2_dim, 5])) 2607*da0073e9SAndroid Build Coastguard Worker self.b2 = nn.Parameter(torch.rand([w2_dim])) 2608*da0073e9SAndroid Build Coastguard Worker 2609*da0073e9SAndroid Build Coastguard Worker def forward(self, in_tensor1, in_tensor2): 2610*da0073e9SAndroid Build Coastguard Worker res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1) 2611*da0073e9SAndroid Build Coastguard Worker res2 = torch._C._nn.linear(in_tensor2, self.w2, self.b2) 2612*da0073e9SAndroid Build Coastguard Worker return res1, res2 2613*da0073e9SAndroid Build Coastguard Worker 2614*da0073e9SAndroid Build Coastguard Worker mod_eager = ModMultLinear(5, 5).eval() 2615*da0073e9SAndroid Build Coastguard Worker test_val1 = torch.rand([50, 5]) 2616*da0073e9SAndroid Build Coastguard Worker test_val2 = torch.rand([50, 5]) 2617*da0073e9SAndroid Build Coastguard Worker self.check_linear_optimizations(mod_eager, 2, 2, (test_val1, test_val2)) 2618*da0073e9SAndroid Build Coastguard Worker 2619*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU") 2620*da0073e9SAndroid Build Coastguard Worker def test_linear_multiple_blocks(self): 2621*da0073e9SAndroid Build Coastguard Worker class ModMultLinear(nn.Module): 2622*da0073e9SAndroid Build Coastguard Worker def __init__(self, w1_dim, w2_dim): 2623*da0073e9SAndroid Build Coastguard Worker super().__init__() 
2624*da0073e9SAndroid Build Coastguard Worker self.w1 = nn.Parameter(torch.rand([w1_dim, 5])) 2625*da0073e9SAndroid Build Coastguard Worker self.b1 = nn.Parameter(torch.rand([w1_dim])) 2626*da0073e9SAndroid Build Coastguard Worker self.w2 = nn.Parameter(torch.rand([w2_dim, 5])) 2627*da0073e9SAndroid Build Coastguard Worker self.b2 = nn.Parameter(torch.rand([w2_dim])) 2628*da0073e9SAndroid Build Coastguard Worker 2629*da0073e9SAndroid Build Coastguard Worker def forward(self, in_tensor1, in_tensor2, cond: bool): 2630*da0073e9SAndroid Build Coastguard Worker res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1) 2631*da0073e9SAndroid Build Coastguard Worker if cond: 2632*da0073e9SAndroid Build Coastguard Worker res3 = torch._C._nn.linear(in_tensor2, self.w2, self.b2) 2633*da0073e9SAndroid Build Coastguard Worker res4 = torch._C._nn.linear(in_tensor1, self.w2, self.b1) 2634*da0073e9SAndroid Build Coastguard Worker else: 2635*da0073e9SAndroid Build Coastguard Worker raise AssertionError 2636*da0073e9SAndroid Build Coastguard Worker res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b1) 2637*da0073e9SAndroid Build Coastguard Worker return res1, res2, res3, res4 2638*da0073e9SAndroid Build Coastguard Worker 2639*da0073e9SAndroid Build Coastguard Worker mod_eager = ModMultLinear(5, 5).eval() 2640*da0073e9SAndroid Build Coastguard Worker test_val1 = torch.rand([50, 5]) 2641*da0073e9SAndroid Build Coastguard Worker test_val2 = torch.rand([50, 5]) 2642*da0073e9SAndroid Build Coastguard Worker self.check_linear_optimizations(mod_eager, 4, 3, (test_val1, test_val2, True)) 2643*da0073e9SAndroid Build Coastguard Worker 2644*da0073e9SAndroid Build Coastguard Worker def check_linear_optimizations( 2645*da0073e9SAndroid Build Coastguard Worker self, eager_mod, orig_linears, new_linears, test_vals 2646*da0073e9SAndroid Build Coastguard Worker ): 2647*da0073e9SAndroid Build Coastguard Worker for is_cuda in [False, True]: 2648*da0073e9SAndroid Build Coastguard Worker if is_cuda: 
2649*da0073e9SAndroid Build Coastguard Worker mod_to_device = eager_mod.cuda() 2650*da0073e9SAndroid Build Coastguard Worker test_vals_to_device = [ 2651*da0073e9SAndroid Build Coastguard Worker t.cuda() if isinstance(t, torch.Tensor) else t for t in test_vals 2652*da0073e9SAndroid Build Coastguard Worker ] 2653*da0073e9SAndroid Build Coastguard Worker else: 2654*da0073e9SAndroid Build Coastguard Worker mod_to_device = eager_mod 2655*da0073e9SAndroid Build Coastguard Worker test_vals_to_device = test_vals 2656*da0073e9SAndroid Build Coastguard Worker 2657*da0073e9SAndroid Build Coastguard Worker script_mod = torch.jit.script(mod_to_device) 2658*da0073e9SAndroid Build Coastguard Worker op_graph = script_mod.graph 2659*da0073e9SAndroid Build Coastguard Worker 2660*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::linear", orig_linears, exactly=True).run( 2661*da0073e9SAndroid Build Coastguard Worker op_graph 2662*da0073e9SAndroid Build Coastguard Worker ) 2663*da0073e9SAndroid Build Coastguard Worker # successively no-ops with non-const inputs 2664*da0073e9SAndroid Build Coastguard Worker self.run_pass("concat_frozen_linear", op_graph) 2665*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::linear", orig_linears, exactly=True).run( 2666*da0073e9SAndroid Build Coastguard Worker op_graph 2667*da0073e9SAndroid Build Coastguard Worker ) 2668*da0073e9SAndroid Build Coastguard Worker 2669*da0073e9SAndroid Build Coastguard Worker script_mod = torch.jit.freeze(script_mod) 2670*da0073e9SAndroid Build Coastguard Worker op_graph = script_mod.graph 2671*da0073e9SAndroid Build Coastguard Worker self.run_pass("concat_frozen_linear", op_graph) 2672*da0073e9SAndroid Build Coastguard Worker if is_cuda: 2673*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::linear", new_linears, exactly=True).run( 2674*da0073e9SAndroid Build Coastguard Worker op_graph 2675*da0073e9SAndroid Build Coastguard Worker ) 2676*da0073e9SAndroid 
Build Coastguard Worker else: 2677*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::linear", orig_linears, exactly=True).run( 2678*da0073e9SAndroid Build Coastguard Worker op_graph 2679*da0073e9SAndroid Build Coastguard Worker ) 2680*da0073e9SAndroid Build Coastguard Worker 2681*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2682*da0073e9SAndroid Build Coastguard Worker mod_to_device(*test_vals_to_device), script_mod(*test_vals_to_device) 2683*da0073e9SAndroid Build Coastguard Worker ) 2684*da0073e9SAndroid Build Coastguard Worker 2685*da0073e9SAndroid Build Coastguard Worker def test_optimize_freeze_module(self): 2686*da0073e9SAndroid Build Coastguard Worker in_channels, out_channels = 3, 32 2687*da0073e9SAndroid Build Coastguard Worker conv = torch.nn.Conv2d( 2688*da0073e9SAndroid Build Coastguard Worker in_channels, out_channels, kernel_size=3, stride=2, bias=True 2689*da0073e9SAndroid Build Coastguard Worker ) 2690*da0073e9SAndroid Build Coastguard Worker bn = torch.nn.BatchNorm2d(out_channels, eps=0.001) 2691*da0073e9SAndroid Build Coastguard Worker mod = torch.nn.Sequential(conv, bn) 2692*da0073e9SAndroid Build Coastguard Worker # set optimize to False here, by default freezing runs run_frozen_optimizations 2693*da0073e9SAndroid Build Coastguard Worker frozen_mod = torch.jit.freeze( 2694*da0073e9SAndroid Build Coastguard Worker torch.jit.script(mod.eval()), optimize_numerics=False 2695*da0073e9SAndroid Build Coastguard Worker ) 2696*da0073e9SAndroid Build Coastguard Worker # inspect frozen mod 2697*da0073e9SAndroid Build Coastguard Worker FileCheck().check("batch_norm").run(frozen_mod.graph) 2698*da0073e9SAndroid Build Coastguard Worker torch.jit.run_frozen_optimizations(frozen_mod) 2699*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("batch_norm").run(frozen_mod.graph) 2700*da0073e9SAndroid Build Coastguard Worker 2701*da0073e9SAndroid Build Coastguard Worker # run_frozen_optimizations should be run 
2702*da0073e9SAndroid Build Coastguard Worker frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval())) 2703*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("batch_norm").run(frozen_mod.graph) 2704*da0073e9SAndroid Build Coastguard Worker 2705*da0073e9SAndroid Build Coastguard Worker def test_freeze_remove_dropout(self): 2706*da0073e9SAndroid Build Coastguard Worker class Net(nn.Module): 2707*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2708*da0073e9SAndroid Build Coastguard Worker super().__init__() 2709*da0073e9SAndroid Build Coastguard Worker self.dropout = nn.Dropout(0.5) 2710*da0073e9SAndroid Build Coastguard Worker 2711*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2712*da0073e9SAndroid Build Coastguard Worker return self.dropout(x) 2713*da0073e9SAndroid Build Coastguard Worker 2714*da0073e9SAndroid Build Coastguard Worker mod = torch.jit.script(Net()) 2715*da0073e9SAndroid Build Coastguard Worker # inspect mod 2716*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_inline(mod.graph) 2717*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::dropout").run(mod.graph) 2718*da0073e9SAndroid Build Coastguard Worker frozen_mod = torch.jit.freeze(mod.eval()) 2719*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::dropout").run(frozen_mod.graph) 2720*da0073e9SAndroid Build Coastguard Worker 2721*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2) 2722*da0073e9SAndroid Build Coastguard Worker output_s = mod.forward(input) 2723*da0073e9SAndroid Build Coastguard Worker output_f = frozen_mod.forward(input) 2724*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_s, output_f) 2725*da0073e9SAndroid Build Coastguard Worker 2726*da0073e9SAndroid Build Coastguard Worker def test_freeze_remove_feature_dropout(self): 2727*da0073e9SAndroid Build Coastguard Worker class Net(nn.Module): 2728*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 
2729*da0073e9SAndroid Build Coastguard Worker super().__init__() 2730*da0073e9SAndroid Build Coastguard Worker self.dropout = nn.Dropout2d(0.5) 2731*da0073e9SAndroid Build Coastguard Worker 2732*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2733*da0073e9SAndroid Build Coastguard Worker return self.dropout(x) 2734*da0073e9SAndroid Build Coastguard Worker 2735*da0073e9SAndroid Build Coastguard Worker mod = torch.jit.script(Net().eval()) 2736*da0073e9SAndroid Build Coastguard Worker # inspect mod 2737*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_inline(mod.graph) 2738*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::feature_dropout").run(mod.graph) 2739*da0073e9SAndroid Build Coastguard Worker frozen_mod = torch.jit.freeze(mod) 2740*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::feature_dropout").run(frozen_mod.graph) 2741*da0073e9SAndroid Build Coastguard Worker 2742*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 2, 1, 1) 2743*da0073e9SAndroid Build Coastguard Worker output_s = mod.forward(input) 2744*da0073e9SAndroid Build Coastguard Worker output_f = frozen_mod.forward(input) 2745*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_s, output_f) 2746*da0073e9SAndroid Build Coastguard Worker 2747*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 2748*da0073e9SAndroid Build Coastguard Worker not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 2749*da0073e9SAndroid Build Coastguard Worker ) 2750*da0073e9SAndroid Build Coastguard Worker def test_freeze_mkdlnn(self): 2751*da0073e9SAndroid Build Coastguard Worker conv = torch.nn.Conv2d(3, 32, kernel_size=3, stride=2).eval().float() 2752*da0073e9SAndroid Build Coastguard Worker convmkl = mkldnn_utils.to_mkldnn(conv) 2753*da0073e9SAndroid Build Coastguard Worker out = torch.jit.freeze(torch.jit.script(convmkl.eval())) 2754*da0073e9SAndroid Build Coastguard Worker inp = torch.rand([4, 3, 4, 
4]).float() 2755*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out(inp.to_mkldnn()).to_dense(), conv(inp)) 2756*da0073e9SAndroid Build Coastguard Worker 2757*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 2758*da0073e9SAndroid Build Coastguard Worker not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 2759*da0073e9SAndroid Build Coastguard Worker ) 2760*da0073e9SAndroid Build Coastguard Worker def test_conv_to_mkldnn(self): 2761*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 2762*da0073e9SAndroid Build Coastguard Worker for module, trace in product([nn.Conv2d, nn.Conv3d], [False, True]): 2763*da0073e9SAndroid Build Coastguard Worker mod = module(3, 32, kernel_size=3, stride=2).eval() 2764*da0073e9SAndroid Build Coastguard Worker inps = [4, 3, 4] 2765*da0073e9SAndroid Build Coastguard Worker if module == nn.Conv2d: 2766*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2767*da0073e9SAndroid Build Coastguard Worker if module == nn.Conv3d: 2768*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2769*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 2770*da0073e9SAndroid Build Coastguard Worker 2771*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(inps) 2772*da0073e9SAndroid Build Coastguard Worker if trace: 2773*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.script(mod) 2774*da0073e9SAndroid Build Coastguard Worker else: 2775*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.trace(mod, (inp,)) 2776*da0073e9SAndroid Build Coastguard Worker 2777*da0073e9SAndroid Build Coastguard Worker self.run_pass("inline", scripted_mod.graph) 2778*da0073e9SAndroid Build Coastguard Worker 2779*da0073e9SAndroid Build Coastguard Worker FileCheck().check("conv").run(scripted_mod.graph) 2780*da0073e9SAndroid Build Coastguard Worker # successfully no-ops with non-const inputs 2781*da0073e9SAndroid Build Coastguard Worker 
self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph) 2782*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("to_mkldnn").run(scripted_mod.graph) 2783*da0073e9SAndroid Build Coastguard Worker 2784*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.freeze(scripted_mod) 2785*da0073e9SAndroid Build Coastguard Worker self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph) 2786*da0073e9SAndroid Build Coastguard Worker FileCheck().check("to_mkldnn").check("prim::mkldnn_convolution").check( 2787*da0073e9SAndroid Build Coastguard Worker "to_dense" 2788*da0073e9SAndroid Build Coastguard Worker ).run(scripted_mod.graph) 2789*da0073e9SAndroid Build Coastguard Worker 2790*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod(inp), scripted_mod(inp)) 2791*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod(inp), scripted_mod(inp)) 2792*da0073e9SAndroid Build Coastguard Worker 2793*da0073e9SAndroid Build Coastguard Worker def test_linear_transpose(self): 2794*da0073e9SAndroid Build Coastguard Worker class ModLinear(torch.nn.Module): 2795*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2796*da0073e9SAndroid Build Coastguard Worker super().__init__() 2797*da0073e9SAndroid Build Coastguard Worker self.bias = torch.nn.Parameter(torch.rand(30)) 2798*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter(torch.rand([30, 20])) 2799*da0073e9SAndroid Build Coastguard Worker 2800*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2801*da0073e9SAndroid Build Coastguard Worker return torch._C._nn.linear(x, self.weight, self.bias) 2802*da0073e9SAndroid Build Coastguard Worker 2803*da0073e9SAndroid Build Coastguard Worker mod_eager = ModLinear().eval() 2804*da0073e9SAndroid Build Coastguard Worker test_val = torch.rand([50, 20]) 2805*da0073e9SAndroid Build Coastguard Worker self.check_linear_optimizations_2( 2806*da0073e9SAndroid Build Coastguard Worker mod_eager, 1, 0, 
"transpose_frozen_linear", (test_val,) 2807*da0073e9SAndroid Build Coastguard Worker ) 2808*da0073e9SAndroid Build Coastguard Worker 2809*da0073e9SAndroid Build Coastguard Worker def test_linear_non_constant_weight(self): 2810*da0073e9SAndroid Build Coastguard Worker class ModLinear(torch.nn.Module): 2811*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2812*da0073e9SAndroid Build Coastguard Worker super().__init__() 2813*da0073e9SAndroid Build Coastguard Worker self.bias = torch.nn.Parameter(torch.rand(30)) 2814*da0073e9SAndroid Build Coastguard Worker 2815*da0073e9SAndroid Build Coastguard Worker def forward(self, x, weight): 2816*da0073e9SAndroid Build Coastguard Worker return torch._C._nn.linear(x, weight, self.bias) 2817*da0073e9SAndroid Build Coastguard Worker 2818*da0073e9SAndroid Build Coastguard Worker mod_eager = ModLinear().eval() 2819*da0073e9SAndroid Build Coastguard Worker test_val = torch.rand([50, 20]) 2820*da0073e9SAndroid Build Coastguard Worker test_weight = torch.rand([30, 20]) 2821*da0073e9SAndroid Build Coastguard Worker self.check_linear_optimizations_2( 2822*da0073e9SAndroid Build Coastguard Worker mod_eager, 1, 1, "transpose_frozen_linear", (test_val, test_weight) 2823*da0073e9SAndroid Build Coastguard Worker ) 2824*da0073e9SAndroid Build Coastguard Worker 2825*da0073e9SAndroid Build Coastguard Worker def check_linear_optimizations_2( 2826*da0073e9SAndroid Build Coastguard Worker self, eager_mod, orig_linears, new_linears, opt_pass, test_vals 2827*da0073e9SAndroid Build Coastguard Worker ): 2828*da0073e9SAndroid Build Coastguard Worker # TODO: merge with check_linear_optimizations once both diffs land 2829*da0073e9SAndroid Build Coastguard Worker mod_to_device = eager_mod 2830*da0073e9SAndroid Build Coastguard Worker test_vals_to_device = test_vals 2831*da0073e9SAndroid Build Coastguard Worker 2832*da0073e9SAndroid Build Coastguard Worker script_mod = torch.jit.script(mod_to_device) 2833*da0073e9SAndroid Build Coastguard 
Worker op_graph = script_mod.graph 2834*da0073e9SAndroid Build Coastguard Worker 2835*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::linear", orig_linears, exactly=True).run( 2836*da0073e9SAndroid Build Coastguard Worker op_graph 2837*da0073e9SAndroid Build Coastguard Worker ) 2838*da0073e9SAndroid Build Coastguard Worker # successively no-ops with non-const inputs 2839*da0073e9SAndroid Build Coastguard Worker self.run_pass(opt_pass, op_graph) 2840*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::linear", orig_linears, exactly=True).run( 2841*da0073e9SAndroid Build Coastguard Worker op_graph 2842*da0073e9SAndroid Build Coastguard Worker ) 2843*da0073e9SAndroid Build Coastguard Worker 2844*da0073e9SAndroid Build Coastguard Worker script_mod = torch.jit.freeze(script_mod) 2845*da0073e9SAndroid Build Coastguard Worker op_graph = script_mod.graph 2846*da0073e9SAndroid Build Coastguard Worker self.run_pass(opt_pass, op_graph) 2847*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::linear", new_linears, exactly=True).run(op_graph) 2848*da0073e9SAndroid Build Coastguard Worker 2849*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2850*da0073e9SAndroid Build Coastguard Worker mod_to_device(*test_vals_to_device), script_mod(*test_vals_to_device) 2851*da0073e9SAndroid Build Coastguard Worker ) 2852*da0073e9SAndroid Build Coastguard Worker 2853*da0073e9SAndroid Build Coastguard Worker @staticmethod 2854*da0073e9SAndroid Build Coastguard Worker def conv(): 2855*da0073e9SAndroid Build Coastguard Worker # Generic composable conv for testing purposes 2856*da0073e9SAndroid Build Coastguard Worker return nn.Conv2d(8, 8, 1) 2857*da0073e9SAndroid Build Coastguard Worker 2858*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 2859*da0073e9SAndroid Build Coastguard Worker not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 2860*da0073e9SAndroid Build Coastguard Worker ) 
2861*da0073e9SAndroid Build Coastguard Worker def test_collapse_adjacent_conversions(self): 2862*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 2863*da0073e9SAndroid Build Coastguard Worker mod = nn.Sequential(self.conv(), self.conv()).eval() 2864*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.script(mod) 2865*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.freeze(scripted_mod) 2866*da0073e9SAndroid Build Coastguard Worker self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph) 2867*da0073e9SAndroid Build Coastguard Worker FileCheck().check("to_mkldnn").check("prim::mkldnn_convolution").check( 2868*da0073e9SAndroid Build Coastguard Worker "prim::mkldnn_convolution" 2869*da0073e9SAndroid Build Coastguard Worker ).check("to_dense").run(scripted_mod.graph) 2870*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("to_mkldnn", 1, exactly=True).run( 2871*da0073e9SAndroid Build Coastguard Worker scripted_mod.graph 2872*da0073e9SAndroid Build Coastguard Worker ) 2873*da0073e9SAndroid Build Coastguard Worker 2874*da0073e9SAndroid Build Coastguard Worker inp = torch.rand([1, 8, 8, 8]) 2875*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_mod(inp), mod(inp)) 2876*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_mod(inp), mod(inp)) 2877*da0073e9SAndroid Build Coastguard Worker 2878*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 2879*da0073e9SAndroid Build Coastguard Worker not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 2880*da0073e9SAndroid Build Coastguard Worker ) 2881*da0073e9SAndroid Build Coastguard Worker def test_mkldnn_fuser_broadcasting(self): 2882*da0073e9SAndroid Build Coastguard Worker class Add(nn.Module): 2883*da0073e9SAndroid Build Coastguard Worker def __init__(self, tensor): 2884*da0073e9SAndroid Build Coastguard Worker super().__init__() 2885*da0073e9SAndroid Build Coastguard Worker self.tensor = 
tensor 2886*da0073e9SAndroid Build Coastguard Worker 2887*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2888*da0073e9SAndroid Build Coastguard Worker return x + self.tensor 2889*da0073e9SAndroid Build Coastguard Worker 2890*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 2891*da0073e9SAndroid Build Coastguard Worker for add_inp in [8], [8, 8, 1]: 2892*da0073e9SAndroid Build Coastguard Worker mod = nn.Sequential(self.conv(), Add(torch.rand(add_inp))).eval() 2893*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.script(mod) 2894*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.freeze(scripted_mod) 2895*da0073e9SAndroid Build Coastguard Worker self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph) 2896*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::BroadcastMKLDNNTensors").run( 2897*da0073e9SAndroid Build Coastguard Worker scripted_mod.graph 2898*da0073e9SAndroid Build Coastguard Worker ) 2899*da0073e9SAndroid Build Coastguard Worker inp = torch.rand([1, 8, 8, 8]) 2900*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_mod(inp), mod(inp)) 2901*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_mod(inp), mod(inp)) 2902*da0073e9SAndroid Build Coastguard Worker 2903*da0073e9SAndroid Build Coastguard Worker # for good measure, check that broadcasting does not work without this op 2904*da0073e9SAndroid Build Coastguard Worker # so we can remove the op if it ever gets supported 2905*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, ""): 2906*da0073e9SAndroid Build Coastguard Worker ( 2907*da0073e9SAndroid Build Coastguard Worker torch.rand([1, 8, 8, 8]).to_mkldnn() 2908*da0073e9SAndroid Build Coastguard Worker + torch.rand(add_inp).to_mkldnn() 2909*da0073e9SAndroid Build Coastguard Worker ) 2910*da0073e9SAndroid Build Coastguard Worker 2911*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 
2912*da0073e9SAndroid Build Coastguard Worker not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 2913*da0073e9SAndroid Build Coastguard Worker ) 2914*da0073e9SAndroid Build Coastguard Worker def test_mkldnn_inplace_removal(self): 2915*da0073e9SAndroid Build Coastguard Worker class AddMul(nn.Module): 2916*da0073e9SAndroid Build Coastguard Worker def __init__(self, tensor): 2917*da0073e9SAndroid Build Coastguard Worker super().__init__() 2918*da0073e9SAndroid Build Coastguard Worker self.tensor = tensor 2919*da0073e9SAndroid Build Coastguard Worker 2920*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2921*da0073e9SAndroid Build Coastguard Worker return x.add_(self.tensor).div_(self.tensor) - 4 2922*da0073e9SAndroid Build Coastguard Worker 2923*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 2924*da0073e9SAndroid Build Coastguard Worker mod = nn.Sequential(self.conv(), AddMul(torch.rand([8]))).eval() 2925*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.script(mod) 2926*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.freeze(scripted_mod) 2927*da0073e9SAndroid Build Coastguard Worker self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph) 2928*da0073e9SAndroid Build Coastguard Worker # add gets uninplaced and reinplaced 2929*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::to_mkldnn").check("aten::add_").check( 2930*da0073e9SAndroid Build Coastguard Worker "aten::div_" 2931*da0073e9SAndroid Build Coastguard Worker ).run(scripted_mod.graph) 2932*da0073e9SAndroid Build Coastguard Worker inp = torch.rand([1, 8, 8, 8]) 2933*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_mod(inp), mod(inp)) 2934*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_mod(inp), mod(inp)) 2935*da0073e9SAndroid Build Coastguard Worker 2936*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 2937*da0073e9SAndroid Build Coastguard 
Worker not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 2938*da0073e9SAndroid Build Coastguard Worker ) 2939*da0073e9SAndroid Build Coastguard Worker @skipIfNoTorchVision 2940*da0073e9SAndroid Build Coastguard Worker def test_maxpool_mkldnn(self): 2941*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 2942*da0073e9SAndroid Build Coastguard Worker model = torchvision.models.resnet18() 2943*da0073e9SAndroid Build Coastguard Worker sub_model = torch.nn.Sequential( 2944*da0073e9SAndroid Build Coastguard Worker model.conv1, model.bn1, model.relu, model.maxpool 2945*da0073e9SAndroid Build Coastguard Worker ) 2946*da0073e9SAndroid Build Coastguard Worker mod = torch.jit.freeze(torch.jit.script(sub_model.eval())) 2947*da0073e9SAndroid Build Coastguard Worker ( 2948*da0073e9SAndroid Build Coastguard Worker N, 2949*da0073e9SAndroid Build Coastguard Worker C, 2950*da0073e9SAndroid Build Coastguard Worker H, 2951*da0073e9SAndroid Build Coastguard Worker W, 2952*da0073e9SAndroid Build Coastguard Worker ) = ( 2953*da0073e9SAndroid Build Coastguard Worker 10, 2954*da0073e9SAndroid Build Coastguard Worker 3, 2955*da0073e9SAndroid Build Coastguard Worker 224, 2956*da0073e9SAndroid Build Coastguard Worker 224, 2957*da0073e9SAndroid Build Coastguard Worker ) 2958*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(N, C, H, W) 2959*da0073e9SAndroid Build Coastguard Worker self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) 2960*da0073e9SAndroid Build Coastguard Worker FileCheck().check("max_pool").check("to_dense").run(mod.graph) 2961*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("to_dense", 1, exactly=True).run(mod.graph) 2962*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod(inp), sub_model(inp)) 2963*da0073e9SAndroid Build Coastguard Worker 2964*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(torch.backends.mkldnn.is_available(), "Testing no mkldnn") 2965*da0073e9SAndroid Build 
Coastguard Worker def test_conv_to_mkldnn_no_mkldnn(self): 2966*da0073e9SAndroid Build Coastguard Worker # test no error when mkldnn not available 2967*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 2968*da0073e9SAndroid Build Coastguard Worker mod = torch.jit.script(nn.Conv2d(3, 32, kernel_size=3, stride=2).eval()) 2969*da0073e9SAndroid Build Coastguard Worker frozen = torch.jit.freeze(mod) 2970*da0073e9SAndroid Build Coastguard Worker self.run_pass("convert_frozen_ops_to_mkldnn", frozen.graph) 2971*da0073e9SAndroid Build Coastguard Worker inp = torch.rand([4, 3, 4, 4]) 2972*da0073e9SAndroid Build Coastguard Worker self.assertEqual(frozen(inp), mod(inp)) 2973*da0073e9SAndroid Build Coastguard Worker 2974*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not (TEST_CUDNN or TEST_WITH_ROCM), "requires CUDNN") 2975*da0073e9SAndroid Build Coastguard Worker def test_freeze_conv_relu_fusion(self): 2976*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 2977*da0073e9SAndroid Build Coastguard Worker conv_bias = [True, False] 2978*da0073e9SAndroid Build Coastguard Worker conv_ops = [nn.Conv2d, nn.Conv3d] 2979*da0073e9SAndroid Build Coastguard Worker use_add_z = [True, False] 2980*da0073e9SAndroid Build Coastguard Worker use_tracing = [True, False] 2981*da0073e9SAndroid Build Coastguard Worker for use_bias, conv, add_z, tracing in product( 2982*da0073e9SAndroid Build Coastguard Worker conv_bias, conv_ops, use_add_z, use_tracing 2983*da0073e9SAndroid Build Coastguard Worker ): 2984*da0073e9SAndroid Build Coastguard Worker 2985*da0073e9SAndroid Build Coastguard Worker class Net(nn.Module): 2986*da0073e9SAndroid Build Coastguard Worker def __init__(self, in_channels, out_channels, **kwargs): 2987*da0073e9SAndroid Build Coastguard Worker super().__init__() 2988*da0073e9SAndroid Build Coastguard Worker self.conv = conv( 2989*da0073e9SAndroid Build Coastguard Worker in_channels, out_channels, bias=use_bias, **kwargs 
2990*da0073e9SAndroid Build Coastguard Worker ) 2991*da0073e9SAndroid Build Coastguard Worker self.relu = nn.ReLU(inplace=True) 2992*da0073e9SAndroid Build Coastguard Worker self.add_z = add_z 2993*da0073e9SAndroid Build Coastguard Worker 2994*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2995*da0073e9SAndroid Build Coastguard Worker z = self.conv(x) 2996*da0073e9SAndroid Build Coastguard Worker out = self.conv(x) 2997*da0073e9SAndroid Build Coastguard Worker if self.add_z: 2998*da0073e9SAndroid Build Coastguard Worker out += z 2999*da0073e9SAndroid Build Coastguard Worker out = self.relu(out) 3000*da0073e9SAndroid Build Coastguard Worker return out 3001*da0073e9SAndroid Build Coastguard Worker 3002*da0073e9SAndroid Build Coastguard Worker mod_eager = Net(3, 6, kernel_size=3, stride=2).eval().cuda() 3003*da0073e9SAndroid Build Coastguard Worker 3004*da0073e9SAndroid Build Coastguard Worker inps = [5, 3, 4, 4] 3005*da0073e9SAndroid Build Coastguard Worker if conv == nn.Conv3d: 3006*da0073e9SAndroid Build Coastguard Worker inps.append(inps[-1]) 3007*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(inps).cuda() 3008*da0073e9SAndroid Build Coastguard Worker 3009*da0073e9SAndroid Build Coastguard Worker if tracing: 3010*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.trace(mod_eager, (inp)) 3011*da0073e9SAndroid Build Coastguard Worker else: 3012*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.script(mod_eager) 3013*da0073e9SAndroid Build Coastguard Worker 3014*da0073e9SAndroid Build Coastguard Worker frozen_mod = torch.jit.optimize_for_inference(scripted_mod) 3015*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_ROCM: 3016*da0073e9SAndroid Build Coastguard Worker if add_z: 3017*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::miopen_convolution_add_relu").run( 3018*da0073e9SAndroid Build Coastguard Worker frozen_mod.graph 3019*da0073e9SAndroid Build Coastguard Worker ) 
3020*da0073e9SAndroid Build Coastguard Worker else: 3021*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::miopen_convolution_relu").run( 3022*da0073e9SAndroid Build Coastguard Worker frozen_mod.graph 3023*da0073e9SAndroid Build Coastguard Worker ) 3024*da0073e9SAndroid Build Coastguard Worker else: 3025*da0073e9SAndroid Build Coastguard Worker if add_z: 3026*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::cudnn_convolution_add_relu").run( 3027*da0073e9SAndroid Build Coastguard Worker frozen_mod.graph 3028*da0073e9SAndroid Build Coastguard Worker ) 3029*da0073e9SAndroid Build Coastguard Worker else: 3030*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::cudnn_convolution_relu").run( 3031*da0073e9SAndroid Build Coastguard Worker frozen_mod.graph 3032*da0073e9SAndroid Build Coastguard Worker ) 3033*da0073e9SAndroid Build Coastguard Worker 3034*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod_eager(inp), frozen_mod(inp)) 3035*da0073e9SAndroid Build Coastguard Worker 3036*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not (TEST_CUDNN or TEST_WITH_ROCM), "requires CUDNN") 3037*da0073e9SAndroid Build Coastguard Worker def test_freeze_conv_relu_fusion_not_forward(self): 3038*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 3039*da0073e9SAndroid Build Coastguard Worker 3040*da0073e9SAndroid Build Coastguard Worker class Net(nn.Module): 3041*da0073e9SAndroid Build Coastguard Worker def __init__(self, in_channels, out_channels, **kwargs): 3042*da0073e9SAndroid Build Coastguard Worker super().__init__() 3043*da0073e9SAndroid Build Coastguard Worker self.conv = nn.Conv2d( 3044*da0073e9SAndroid Build Coastguard Worker in_channels, out_channels, bias=None, **kwargs 3045*da0073e9SAndroid Build Coastguard Worker ) 3046*da0073e9SAndroid Build Coastguard Worker self.relu = nn.ReLU(inplace=True) 3047*da0073e9SAndroid Build Coastguard Worker 3048*da0073e9SAndroid Build Coastguard 
Worker def forward(self, x): 3049*da0073e9SAndroid Build Coastguard Worker z = self.conv(x) 3050*da0073e9SAndroid Build Coastguard Worker out = self.conv(x) 3051*da0073e9SAndroid Build Coastguard Worker out = self.relu(out) 3052*da0073e9SAndroid Build Coastguard Worker return out 3053*da0073e9SAndroid Build Coastguard Worker 3054*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 3055*da0073e9SAndroid Build Coastguard Worker def make_prediction(self, x): 3056*da0073e9SAndroid Build Coastguard Worker return self.forward(x) 3057*da0073e9SAndroid Build Coastguard Worker 3058*da0073e9SAndroid Build Coastguard Worker mod_eager = Net(3, 6, kernel_size=3, stride=2).eval().cuda() 3059*da0073e9SAndroid Build Coastguard Worker 3060*da0073e9SAndroid Build Coastguard Worker inps = [5, 3, 4, 4] 3061*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(inps).cuda() 3062*da0073e9SAndroid Build Coastguard Worker 3063*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.script(mod_eager) 3064*da0073e9SAndroid Build Coastguard Worker 3065*da0073e9SAndroid Build Coastguard Worker frozen_mod = torch.jit.freeze( 3066*da0073e9SAndroid Build Coastguard Worker scripted_mod, preserved_attrs=["make_prediction"] 3067*da0073e9SAndroid Build Coastguard Worker ) 3068*da0073e9SAndroid Build Coastguard Worker optimized_mod = torch.jit.optimize_for_inference( 3069*da0073e9SAndroid Build Coastguard Worker frozen_mod, other_methods=["make_prediction"] 3070*da0073e9SAndroid Build Coastguard Worker ) 3071*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_ROCM: 3072*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::miopen_convolution_relu").run( 3073*da0073e9SAndroid Build Coastguard Worker optimized_mod.make_prediction.graph 3074*da0073e9SAndroid Build Coastguard Worker ) 3075*da0073e9SAndroid Build Coastguard Worker else: 3076*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::cudnn_convolution_relu").run( 3077*da0073e9SAndroid Build 
Coastguard Worker optimized_mod.make_prediction.graph 3078*da0073e9SAndroid Build Coastguard Worker ) 3079*da0073e9SAndroid Build Coastguard Worker 3080*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3081*da0073e9SAndroid Build Coastguard Worker mod_eager.make_prediction(inp), optimized_mod.make_prediction(inp) 3082*da0073e9SAndroid Build Coastguard Worker ) 3083*da0073e9SAndroid Build Coastguard Worker 3084*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 3085*da0073e9SAndroid Build Coastguard Worker not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 3086*da0073e9SAndroid Build Coastguard Worker ) 3087*da0073e9SAndroid Build Coastguard Worker def test_numel_less_than_size_with_padding(self): 3088*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 3089*da0073e9SAndroid Build Coastguard Worker 3090*da0073e9SAndroid Build Coastguard Worker class MyModule(nn.Module): 3091*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3092*da0073e9SAndroid Build Coastguard Worker super().__init__() 3093*da0073e9SAndroid Build Coastguard Worker self.conv1 = nn.Conv2d( 3094*da0073e9SAndroid Build Coastguard Worker 1, 3095*da0073e9SAndroid Build Coastguard Worker 2, 3096*da0073e9SAndroid Build Coastguard Worker kernel_size=(2, 4), 3097*da0073e9SAndroid Build Coastguard Worker stride=2, 3098*da0073e9SAndroid Build Coastguard Worker padding=2, 3099*da0073e9SAndroid Build Coastguard Worker dilation=(2, 1), 3100*da0073e9SAndroid Build Coastguard Worker ) 3101*da0073e9SAndroid Build Coastguard Worker 3102*da0073e9SAndroid Build Coastguard Worker def forward(self, i0): 3103*da0073e9SAndroid Build Coastguard Worker x = self.conv1(i0) 3104*da0073e9SAndroid Build Coastguard Worker o0 = torch.max(x, i0) 3105*da0073e9SAndroid Build Coastguard Worker o1 = torch.clip(x, -1.5, 1.5) 3106*da0073e9SAndroid Build Coastguard Worker return o0, o1 3107*da0073e9SAndroid Build Coastguard Worker 3108*da0073e9SAndroid Build 
Coastguard Worker i0 = torch.zeros((1, 1, 1, 2), dtype=torch.float32) 3109*da0073e9SAndroid Build Coastguard Worker mod = MyModule() 3110*da0073e9SAndroid Build Coastguard Worker out = mod(i0) 3111*da0073e9SAndroid Build Coastguard Worker 3112*da0073e9SAndroid Build Coastguard Worker exported = torch.jit.trace(mod, [i0]) 3113*da0073e9SAndroid Build Coastguard Worker exported = torch.jit.optimize_for_inference(exported) 3114*da0073e9SAndroid Build Coastguard Worker 3115*da0073e9SAndroid Build Coastguard Worker eout = exported(i0) 3116*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all(torch.allclose(x, y) for x, y in zip(out, eout))) 3117*da0073e9SAndroid Build Coastguard Worker 3118*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 3119*da0073e9SAndroid Build Coastguard Worker not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 3120*da0073e9SAndroid Build Coastguard Worker ) 3121*da0073e9SAndroid Build Coastguard Worker def test_incompatible_perf_formats(self): 3122*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 3123*da0073e9SAndroid Build Coastguard Worker 3124*da0073e9SAndroid Build Coastguard Worker class Mod(nn.Module): 3125*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3126*da0073e9SAndroid Build Coastguard Worker super().__init__() 3127*da0073e9SAndroid Build Coastguard Worker self.conv = torch.nn.Conv2d(3, 64, 3, 2) 3128*da0073e9SAndroid Build Coastguard Worker self.max_pool = torch.nn.MaxPool2d(111, 111) 3129*da0073e9SAndroid Build Coastguard Worker 3130*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3131*da0073e9SAndroid Build Coastguard Worker a = self.conv(x) 3132*da0073e9SAndroid Build Coastguard Worker b = self.max_pool(a) 3133*da0073e9SAndroid Build Coastguard Worker return a + b 3134*da0073e9SAndroid Build Coastguard Worker 3135*da0073e9SAndroid Build Coastguard Worker model = Mod() 3136*da0073e9SAndroid Build Coastguard Worker model.eval() 
3137*da0073e9SAndroid Build Coastguard Worker mod = torch.jit.freeze(torch.jit.script(model)) 3138*da0073e9SAndroid Build Coastguard Worker ( 3139*da0073e9SAndroid Build Coastguard Worker N, 3140*da0073e9SAndroid Build Coastguard Worker C, 3141*da0073e9SAndroid Build Coastguard Worker H, 3142*da0073e9SAndroid Build Coastguard Worker W, 3143*da0073e9SAndroid Build Coastguard Worker ) = ( 3144*da0073e9SAndroid Build Coastguard Worker 10, 3145*da0073e9SAndroid Build Coastguard Worker 3, 3146*da0073e9SAndroid Build Coastguard Worker 224, 3147*da0073e9SAndroid Build Coastguard Worker 224, 3148*da0073e9SAndroid Build Coastguard Worker ) 3149*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(N, C, H, W) 3150*da0073e9SAndroid Build Coastguard Worker self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) 3151*da0073e9SAndroid Build Coastguard Worker self.assertEqual(model(inp), mod(inp)) 3152*da0073e9SAndroid Build Coastguard Worker 3153*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 3154*da0073e9SAndroid Build Coastguard Worker not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 3155*da0073e9SAndroid Build Coastguard Worker ) 3156*da0073e9SAndroid Build Coastguard Worker def test_pool2d_batchnorm(self): 3157*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 3158*da0073e9SAndroid Build Coastguard Worker pooling_layers = [ 3159*da0073e9SAndroid Build Coastguard Worker torch.nn.AdaptiveAvgPool2d(4), 3160*da0073e9SAndroid Build Coastguard Worker # torch.nn.AdaptiveMaxPool2d(4), # return tuples 3161*da0073e9SAndroid Build Coastguard Worker torch.nn.MaxPool2d(4), 3162*da0073e9SAndroid Build Coastguard Worker torch.nn.AvgPool2d(4), 3163*da0073e9SAndroid Build Coastguard Worker torch.nn.BatchNorm2d(64).eval(), 3164*da0073e9SAndroid Build Coastguard Worker ] 3165*da0073e9SAndroid Build Coastguard Worker 3166*da0073e9SAndroid Build Coastguard Worker for pl in pooling_layers: 3167*da0073e9SAndroid Build Coastguard 
Worker sub_model = torch.nn.Sequential( 3168*da0073e9SAndroid Build Coastguard Worker torch.nn.Conv2d(3, 64, 2, 2), 3169*da0073e9SAndroid Build Coastguard Worker torch.nn.ReLU(), 3170*da0073e9SAndroid Build Coastguard Worker pl, 3171*da0073e9SAndroid Build Coastguard Worker torch.nn.Hardswish(), 3172*da0073e9SAndroid Build Coastguard Worker ) 3173*da0073e9SAndroid Build Coastguard Worker sub_model.eval() 3174*da0073e9SAndroid Build Coastguard Worker mod = torch.jit.freeze(torch.jit.script(sub_model)) 3175*da0073e9SAndroid Build Coastguard Worker ( 3176*da0073e9SAndroid Build Coastguard Worker N, 3177*da0073e9SAndroid Build Coastguard Worker C, 3178*da0073e9SAndroid Build Coastguard Worker H, 3179*da0073e9SAndroid Build Coastguard Worker W, 3180*da0073e9SAndroid Build Coastguard Worker ) = ( 3181*da0073e9SAndroid Build Coastguard Worker 10, 3182*da0073e9SAndroid Build Coastguard Worker 3, 3183*da0073e9SAndroid Build Coastguard Worker 224, 3184*da0073e9SAndroid Build Coastguard Worker 224, 3185*da0073e9SAndroid Build Coastguard Worker ) 3186*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(N, C, H, W) 3187*da0073e9SAndroid Build Coastguard Worker # these two passes needed to remove 3188*da0073e9SAndroid Build Coastguard Worker # a size check in BatchNorm2d 3189*da0073e9SAndroid Build Coastguard Worker removeExceptions(mod.graph) 3190*da0073e9SAndroid Build Coastguard Worker self.run_pass("dce", mod.graph) 3191*da0073e9SAndroid Build Coastguard Worker self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) 3192*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::to_dense").check_next("return").run(mod.graph) 3193*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sub_model(inp), mod(inp)) 3194*da0073e9SAndroid Build Coastguard Worker 3195*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 3196*da0073e9SAndroid Build Coastguard Worker not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 3197*da0073e9SAndroid 
Build Coastguard Worker ) 3198*da0073e9SAndroid Build Coastguard Worker def test_pool3d_batchnorm(self): 3199*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 3200*da0073e9SAndroid Build Coastguard Worker pooling_layers = [ 3201*da0073e9SAndroid Build Coastguard Worker torch.nn.MaxPool3d(4), 3202*da0073e9SAndroid Build Coastguard Worker # torch.nn.AdaptiveAvgPool3d(4), # no ideep bindings 3203*da0073e9SAndroid Build Coastguard Worker # torch.nn.AdaptiveMaxPool3d(4), # return tuples 3204*da0073e9SAndroid Build Coastguard Worker torch.nn.AvgPool3d(4), 3205*da0073e9SAndroid Build Coastguard Worker torch.nn.BatchNorm3d(64).eval(), 3206*da0073e9SAndroid Build Coastguard Worker ] 3207*da0073e9SAndroid Build Coastguard Worker 3208*da0073e9SAndroid Build Coastguard Worker for pl in pooling_layers: 3209*da0073e9SAndroid Build Coastguard Worker sub_model = torch.nn.Sequential( 3210*da0073e9SAndroid Build Coastguard Worker torch.nn.Conv3d(3, 64, 2, 2), 3211*da0073e9SAndroid Build Coastguard Worker torch.nn.ReLU(), 3212*da0073e9SAndroid Build Coastguard Worker pl, 3213*da0073e9SAndroid Build Coastguard Worker torch.nn.Hardswish(), 3214*da0073e9SAndroid Build Coastguard Worker ) 3215*da0073e9SAndroid Build Coastguard Worker sub_model.eval() 3216*da0073e9SAndroid Build Coastguard Worker mod = torch.jit.freeze(torch.jit.script(sub_model)) 3217*da0073e9SAndroid Build Coastguard Worker N, C, H, W, D = 10, 3, 64, 64, 64 3218*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(N, C, D, H, W) 3219*da0073e9SAndroid Build Coastguard Worker # these two passes needed to remove 3220*da0073e9SAndroid Build Coastguard Worker # a size check in BatchNorm2d 3221*da0073e9SAndroid Build Coastguard Worker removeExceptions(mod.graph) 3222*da0073e9SAndroid Build Coastguard Worker self.run_pass("dce", mod.graph) 3223*da0073e9SAndroid Build Coastguard Worker self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) 3224*da0073e9SAndroid Build Coastguard Worker 
FileCheck().check("aten::to_dense").check_next("return").run(mod.graph) 3225*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sub_model(inp), mod(inp)) 3226*da0073e9SAndroid Build Coastguard Worker 3227*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 3228*da0073e9SAndroid Build Coastguard Worker not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 3229*da0073e9SAndroid Build Coastguard Worker ) 3230*da0073e9SAndroid Build Coastguard Worker @skipIfNoTorchVision 3231*da0073e9SAndroid Build Coastguard Worker def test_conv_hardswish(self): 3232*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 3233*da0073e9SAndroid Build Coastguard Worker 3234*da0073e9SAndroid Build Coastguard Worker class Clamp(torch.nn.Module): 3235*da0073e9SAndroid Build Coastguard Worker def __init__(self, min_val, max_val, **kwargs): 3236*da0073e9SAndroid Build Coastguard Worker super().__init__() 3237*da0073e9SAndroid Build Coastguard Worker self.min_val = min_val 3238*da0073e9SAndroid Build Coastguard Worker self.max_val = max_val 3239*da0073e9SAndroid Build Coastguard Worker 3240*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3241*da0073e9SAndroid Build Coastguard Worker return torch.clamp(x, self.min_val, self.max_val) 3242*da0073e9SAndroid Build Coastguard Worker 3243*da0073e9SAndroid Build Coastguard Worker ( 3244*da0073e9SAndroid Build Coastguard Worker N, 3245*da0073e9SAndroid Build Coastguard Worker C, 3246*da0073e9SAndroid Build Coastguard Worker H, 3247*da0073e9SAndroid Build Coastguard Worker W, 3248*da0073e9SAndroid Build Coastguard Worker ) = ( 3249*da0073e9SAndroid Build Coastguard Worker 10, 3250*da0073e9SAndroid Build Coastguard Worker 3, 3251*da0073e9SAndroid Build Coastguard Worker 224, 3252*da0073e9SAndroid Build Coastguard Worker 224, 3253*da0073e9SAndroid Build Coastguard Worker ) 3254*da0073e9SAndroid Build Coastguard Worker activations = [ 3255*da0073e9SAndroid Build Coastguard Worker 
torch.nn.Hardswish(), 3256*da0073e9SAndroid Build Coastguard Worker torch.nn.Hardsigmoid(), 3257*da0073e9SAndroid Build Coastguard Worker torch.nn.ReLU6(), 3258*da0073e9SAndroid Build Coastguard Worker torch.nn.Tanh(), 3259*da0073e9SAndroid Build Coastguard Worker torch.nn.Hardtanh(0.0, 6.0), 3260*da0073e9SAndroid Build Coastguard Worker torch.nn.Hardtanh(1.0, 100.0), 3261*da0073e9SAndroid Build Coastguard Worker torch.nn.Hardtanh(-100.0, -1.0), 3262*da0073e9SAndroid Build Coastguard Worker torch.nn.GELU(), 3263*da0073e9SAndroid Build Coastguard Worker Clamp(-100.0, -1.0), 3264*da0073e9SAndroid Build Coastguard Worker Clamp(1.0, 100.0), 3265*da0073e9SAndroid Build Coastguard Worker Clamp(0.0, 6.0), 3266*da0073e9SAndroid Build Coastguard Worker Clamp(-1.0, 0.0), 3267*da0073e9SAndroid Build Coastguard Worker ] 3268*da0073e9SAndroid Build Coastguard Worker 3269*da0073e9SAndroid Build Coastguard Worker model = torchvision.models.resnet18() 3270*da0073e9SAndroid Build Coastguard Worker for activation in activations: 3271*da0073e9SAndroid Build Coastguard Worker sub_model = torch.nn.Sequential(model.conv1, activation) 3272*da0073e9SAndroid Build Coastguard Worker sub_model.eval() 3273*da0073e9SAndroid Build Coastguard Worker mod = torch.jit.freeze(torch.jit.script(sub_model)) 3274*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(N, C, H, W) 3275*da0073e9SAndroid Build Coastguard Worker self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) 3276*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::to_dense", 1, exactly=True).run( 3277*da0073e9SAndroid Build Coastguard Worker mod.graph 3278*da0073e9SAndroid Build Coastguard Worker ) 3279*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sub_model(inp), mod(inp)) 3280*da0073e9SAndroid Build Coastguard Worker 3281*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 3282*da0073e9SAndroid Build Coastguard Worker not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 
3283*da0073e9SAndroid Build Coastguard Worker ) 3284*da0073e9SAndroid Build Coastguard Worker def test_hardswish_hardsigmoid(self): 3285*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 3286*da0073e9SAndroid Build Coastguard Worker op_map = { 3287*da0073e9SAndroid Build Coastguard Worker "prim::MKLDNNHardSwish": F.hardswish, 3288*da0073e9SAndroid Build Coastguard Worker "prim::MKLDNNHardSigmoid": F.hardsigmoid, 3289*da0073e9SAndroid Build Coastguard Worker } 3290*da0073e9SAndroid Build Coastguard Worker 3291*da0073e9SAndroid Build Coastguard Worker input_sizes = ([0], [1], [3], [1, 3, 8, 8]) 3292*da0073e9SAndroid Build Coastguard Worker for mkldnn_opname, aten_op in op_map.items(): 3293*da0073e9SAndroid Build Coastguard Worker for size in input_sizes: 3294*da0073e9SAndroid Build Coastguard Worker for inplace in (True, False): 3295*da0073e9SAndroid Build Coastguard Worker inplace_str = "_" if inplace else "" 3296*da0073e9SAndroid Build Coastguard Worker inplace_tgt = "%34" if inplace else "%35" 3297*da0073e9SAndroid Build Coastguard Worker graph_str = f"""graph(%input.1 : Tensor): 3298*da0073e9SAndroid Build Coastguard Worker %33 : None = prim::Constant() 3299*da0073e9SAndroid Build Coastguard Worker %34 : Tensor = aten::to_mkldnn(%input.1, %33) 3300*da0073e9SAndroid Build Coastguard Worker %35 : Tensor = {mkldnn_opname}{inplace_str}(%34) 3301*da0073e9SAndroid Build Coastguard Worker return ({inplace_tgt}) 3302*da0073e9SAndroid Build Coastguard Worker """ 3303*da0073e9SAndroid Build Coastguard Worker g = torch._C.parse_ir(graph_str) 3304*da0073e9SAndroid Build Coastguard Worker m = self.createFunctionFromGraph(g) 3305*da0073e9SAndroid Build Coastguard Worker x = torch.rand(size) 3306*da0073e9SAndroid Build Coastguard Worker # `inplace=False` is intentional, otherwise we modify the input 3307*da0073e9SAndroid Build Coastguard Worker # and we aren't testing aten impls anyways 3308*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(aten_op(x, inplace=False), m(x).to_dense()) 3309*da0073e9SAndroid Build Coastguard Worker 3310*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 3311*da0073e9SAndroid Build Coastguard Worker not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 3312*da0073e9SAndroid Build Coastguard Worker ) 3313*da0073e9SAndroid Build Coastguard Worker def test_scalar_mul(self): 3314*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 3315*da0073e9SAndroid Build Coastguard Worker 3316*da0073e9SAndroid Build Coastguard Worker class Mod(nn.Module): 3317*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3318*da0073e9SAndroid Build Coastguard Worker super().__init__() 3319*da0073e9SAndroid Build Coastguard Worker self.mod = nn.Conv2d(8, 8, 1, padding=1) 3320*da0073e9SAndroid Build Coastguard Worker 3321*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3322*da0073e9SAndroid Build Coastguard Worker a1 = self.mod(x) * 4 3323*da0073e9SAndroid Build Coastguard Worker return a1 * 4 + a1 * 5.0 3324*da0073e9SAndroid Build Coastguard Worker 3325*da0073e9SAndroid Build Coastguard Worker mod = Mod().eval() 3326*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.freeze(torch.jit.script(mod)) 3327*da0073e9SAndroid Build Coastguard Worker optimized = torch.jit.optimize_for_inference(scripted) 3328*da0073e9SAndroid Build Coastguard Worker inp = torch.rand([1, 8, 8, 8]) 3329*da0073e9SAndroid Build Coastguard Worker # a1 cant be inplaced for first use, can for second 3330*da0073e9SAndroid Build Coastguard Worker FileCheck().check("ScalarMul(").check("ScalarMul_").run(optimized.graph) 3331*da0073e9SAndroid Build Coastguard Worker self.assertEqual(optimized(inp), mod(inp)) 3332*da0073e9SAndroid Build Coastguard Worker 3333*da0073e9SAndroid Build Coastguard Worker def test_remove_detach(self): 3334*da0073e9SAndroid Build Coastguard Worker class Mod(nn.Module): 3335*da0073e9SAndroid Build 
Coastguard Worker def forward(self, x): 3336*da0073e9SAndroid Build Coastguard Worker y = x.detach() 3337*da0073e9SAndroid Build Coastguard Worker return y * y 3338*da0073e9SAndroid Build Coastguard Worker 3339*da0073e9SAndroid Build Coastguard Worker mod = Mod().eval() 3340*da0073e9SAndroid Build Coastguard Worker frozen_mod = torch.jit.freeze(torch.jit.script(mod)) 3341*da0073e9SAndroid Build Coastguard Worker inp = torch.randn((2, 2)) 3342*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::detach").run(frozen_mod.graph) 3343*da0073e9SAndroid Build Coastguard Worker self.assertEqual(frozen_mod(inp), mod(inp)) 3344*da0073e9SAndroid Build Coastguard Worker 3345*da0073e9SAndroid Build Coastguard Worker def test_remove_detach_not_applied(self): 3346*da0073e9SAndroid Build Coastguard Worker class Mod(nn.Module): 3347*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3348*da0073e9SAndroid Build Coastguard Worker y = x.detach() 3349*da0073e9SAndroid Build Coastguard Worker return x is y 3350*da0073e9SAndroid Build Coastguard Worker 3351*da0073e9SAndroid Build Coastguard Worker mod = Mod().eval() 3352*da0073e9SAndroid Build Coastguard Worker frozen_mod = torch.jit.freeze(torch.jit.script(mod)) 3353*da0073e9SAndroid Build Coastguard Worker inp = torch.randn((2, 2)) 3354*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::detach").run(frozen_mod.graph) 3355*da0073e9SAndroid Build Coastguard Worker self.assertEqual(frozen_mod(inp), mod(inp)) 3356*da0073e9SAndroid Build Coastguard Worker 3357*da0073e9SAndroid Build Coastguard Worker 3358*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo("somehow causing hanging during python shutdown") 3359*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled") 3360*da0073e9SAndroid Build Coastguard Workerclass TestMKLDNNReinplacing(JitTestCase): 3361*da0073e9SAndroid Build Coastguard Worker def setUp(self): 
3362*da0073e9SAndroid Build Coastguard Worker super().setUp() 3363*da0073e9SAndroid Build Coastguard Worker self.default_dtype = torch.get_default_dtype() 3364*da0073e9SAndroid Build Coastguard Worker torch.set_default_dtype(torch.float) 3365*da0073e9SAndroid Build Coastguard Worker 3366*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 3367*da0073e9SAndroid Build Coastguard Worker super().tearDown() 3368*da0073e9SAndroid Build Coastguard Worker torch.set_default_dtype(self.default_dtype) 3369*da0073e9SAndroid Build Coastguard Worker 3370*da0073e9SAndroid Build Coastguard Worker def getConv(self): 3371*da0073e9SAndroid Build Coastguard Worker return nn.Conv2d(3, 32, kernel_size=3, stride=2).eval() 3372*da0073e9SAndroid Build Coastguard Worker 3373*da0073e9SAndroid Build Coastguard Worker def getInput(self): 3374*da0073e9SAndroid Build Coastguard Worker return torch.rand([4, 3, 4, 4]) 3375*da0073e9SAndroid Build Coastguard Worker 3376*da0073e9SAndroid Build Coastguard Worker def freezeAndConvert(self, mod): 3377*da0073e9SAndroid Build Coastguard Worker mod = torch.jit.freeze(torch.jit.script(mod.eval())) 3378*da0073e9SAndroid Build Coastguard Worker self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph) 3379*da0073e9SAndroid Build Coastguard Worker return mod 3380*da0073e9SAndroid Build Coastguard Worker 3381*da0073e9SAndroid Build Coastguard Worker def checkResults(self, mod1, mod2): 3382*da0073e9SAndroid Build Coastguard Worker inp = self.getInput() 3383*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod1(inp), mod2(inp)) 3384*da0073e9SAndroid Build Coastguard Worker 3385*da0073e9SAndroid Build Coastguard Worker def test_successful(self): 3386*da0073e9SAndroid Build Coastguard Worker # simple conv-relu 3387*da0073e9SAndroid Build Coastguard Worker 3388*da0073e9SAndroid Build Coastguard Worker mod_eager = nn.Sequential(self.getConv(), nn.Hardswish(), nn.ReLU()) 3389*da0073e9SAndroid Build Coastguard Worker mod = 
self.freezeAndConvert(mod_eager) 3390*da0073e9SAndroid Build Coastguard Worker FileCheck().check("mkldnn_convolution").check_next( 3391*da0073e9SAndroid Build Coastguard Worker "prim::MKLDNNHardSwish_" 3392*da0073e9SAndroid Build Coastguard Worker ).check_next("aten::relu_").run(mod.graph) 3393*da0073e9SAndroid Build Coastguard Worker self.checkResults(mod_eager, mod) 3394*da0073e9SAndroid Build Coastguard Worker 3395*da0073e9SAndroid Build Coastguard Worker def test_merge_liveness(self): 3396*da0073e9SAndroid Build Coastguard Worker class Mod(nn.Module): 3397*da0073e9SAndroid Build Coastguard Worker def __init__(self, tensor): 3398*da0073e9SAndroid Build Coastguard Worker super().__init__() 3399*da0073e9SAndroid Build Coastguard Worker self.tensor = tensor 3400*da0073e9SAndroid Build Coastguard Worker 3401*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3402*da0073e9SAndroid Build Coastguard Worker # this mul can be inplaced since x is dead after this use 3403*da0073e9SAndroid Build Coastguard Worker temporary = x * self.tensor 3404*da0073e9SAndroid Build Coastguard Worker # temporary livespan is the return node, 3405*da0073e9SAndroid Build Coastguard Worker # add can not be inplaced 3406*da0073e9SAndroid Build Coastguard Worker return temporary + temporary, temporary 3407*da0073e9SAndroid Build Coastguard Worker 3408*da0073e9SAndroid Build Coastguard Worker mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1]))) 3409*da0073e9SAndroid Build Coastguard Worker mod = self.freezeAndConvert(mod_eager) 3410*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::mul_").check_not("aten::add_").run(mod.graph) 3411*da0073e9SAndroid Build Coastguard Worker self.checkResults(mod_eager, mod) 3412*da0073e9SAndroid Build Coastguard Worker 3413*da0073e9SAndroid Build Coastguard Worker def test_always_alive_values(self): 3414*da0073e9SAndroid Build Coastguard Worker class Mod(nn.Module): 3415*da0073e9SAndroid Build Coastguard Worker def 
__init__(self, tensor): 3416*da0073e9SAndroid Build Coastguard Worker super().__init__() 3417*da0073e9SAndroid Build Coastguard Worker self.tensor = tensor 3418*da0073e9SAndroid Build Coastguard Worker 3419*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3420*da0073e9SAndroid Build Coastguard Worker # x can't be inplaced because its a return value, 3421*da0073e9SAndroid Build Coastguard Worker # check that the inplacing pass doesnt try to inplace 3422*da0073e9SAndroid Build Coastguard Worker # self.tensor because its always alive 3423*da0073e9SAndroid Build Coastguard Worker return x * self.tensor, x 3424*da0073e9SAndroid Build Coastguard Worker 3425*da0073e9SAndroid Build Coastguard Worker mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1]))) 3426*da0073e9SAndroid Build Coastguard Worker mod = self.freezeAndConvert(mod_eager) 3427*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::mul_").run(mod.graph) 3428*da0073e9SAndroid Build Coastguard Worker self.checkResults(mod_eager, mod) 3429*da0073e9SAndroid Build Coastguard Worker 3430*da0073e9SAndroid Build Coastguard Worker conv = self.getConv() 3431*da0073e9SAndroid Build Coastguard Worker 3432*da0073e9SAndroid Build Coastguard Worker class Mod(nn.Module): 3433*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3434*da0073e9SAndroid Build Coastguard Worker super().__init__() 3435*da0073e9SAndroid Build Coastguard Worker self.tensor = torch.rand([4, 32, 1, 1]) 3436*da0073e9SAndroid Build Coastguard Worker self.conv = conv 3437*da0073e9SAndroid Build Coastguard Worker 3438*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3439*da0073e9SAndroid Build Coastguard Worker # the shapes dont add up on this just testing a particular pattern 3440*da0073e9SAndroid Build Coastguard Worker conv_output = self.conv(x) 3441*da0073e9SAndroid Build Coastguard Worker return conv_output, self.conv(torch.add(x, x)) 3442*da0073e9SAndroid Build Coastguard 
Worker 3443*da0073e9SAndroid Build Coastguard Worker mod = self.freezeAndConvert(Mod()) 3444*da0073e9SAndroid Build Coastguard Worker # x is an input to the graph, and so it should not be inplaced 3445*da0073e9SAndroid Build Coastguard Worker # in the torch.add(x, x) call 3446*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::add_").run(mod.graph) 3447*da0073e9SAndroid Build Coastguard Worker 3448*da0073e9SAndroid Build Coastguard Worker def test_switch_inputs_to_inplace(self): 3449*da0073e9SAndroid Build Coastguard Worker class Mod(nn.Module): 3450*da0073e9SAndroid Build Coastguard Worker def __init__(self, tensor): 3451*da0073e9SAndroid Build Coastguard Worker super().__init__() 3452*da0073e9SAndroid Build Coastguard Worker self.tensor = tensor 3453*da0073e9SAndroid Build Coastguard Worker 3454*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3455*da0073e9SAndroid Build Coastguard Worker # self.tensor cannot be inplaced, however x can, 3456*da0073e9SAndroid Build Coastguard Worker # and bc add is commutative we can reverse inputs to add_ 3457*da0073e9SAndroid Build Coastguard Worker return self.tensor + x 3458*da0073e9SAndroid Build Coastguard Worker 3459*da0073e9SAndroid Build Coastguard Worker mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1]))) 3460*da0073e9SAndroid Build Coastguard Worker mod = self.freezeAndConvert(mod_eager) 3461*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::add_").run(mod.graph) 3462*da0073e9SAndroid Build Coastguard Worker self.checkResults(mod_eager, mod) 3463